Skip to content

Commit c5f1aa6

Browse files
committed
[WIP] Support fireworks login
1 parent 3f7b4c3 commit c5f1aa6

File tree

7 files changed

+452
-30
lines changed

7 files changed

+452
-30
lines changed

README.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,28 @@ Install with pip:
7070
pip install eval-protocol
7171
```
7272

73+
## Fireworks Login (REST)
74+
75+
Use the CLI to sign in without gRPC.
76+
77+
```
78+
# API key flow
79+
eval-protocol login --api-key YOUR_KEY --account-id YOUR_ACCOUNT_ID --validate
80+
81+
# OAuth2 device flow (like firectl)
82+
eval-protocol login --oauth --issuer https://YOUR_ISSUER --client-id YOUR_PUBLIC_CLIENT_ID \
83+
--account-id YOUR_ACCOUNT_ID --open-browser
84+
```
85+
86+
- Omit `--api-key` to be prompted securely.
87+
- Omit `--account-id` to save only the key; you can add it later.
88+
- Add `--api-base https://api.fireworks.ai` for a custom base, if needed.
89+
- For OAuth2, you can also set env vars: `FIREWORKS_OIDC_ISSUER`, `FIREWORKS_OAUTH_CLIENT_ID`, `FIREWORKS_OAUTH_SCOPE`.
90+
91+
Credentials are stored at `~/.fireworks/auth.ini` with 600 permissions and are read automatically by the SDK.
92+
93+
Note: Model/LLM calls still require a Fireworks API key. OAuth login alone does not enable LLM calls yet; ensure `FIREWORKS_API_KEY` is set or saved via `eval-protocol login --api-key ...`.
94+
7395
## License
7496

7597
[MIT](LICENSE)

eval_protocol/auth.py

Lines changed: 143 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
import logging
33
import os
44
from pathlib import Path
5+
import time
56
from typing import Dict, Optional # Added Dict
67

8+
import requests
9+
710
logger = logging.getLogger(__name__)
811

912
FIREWORKS_CONFIG_DIR = Path.home() / ".fireworks"
@@ -36,7 +39,19 @@ def _parse_simple_auth_file(file_path: Path) -> Dict[str, str]:
3639
):
3740
value = value[1:-1]
3841

39-
if key in ["api_key", "account_id"] and value:
42+
if key in [
43+
"api_key",
44+
"account_id",
45+
"api_base",
46+
# OAuth2-related keys
47+
"issuer",
48+
"client_id",
49+
"access_token",
50+
"refresh_token",
51+
"expires_at",
52+
"scope",
53+
"token_type",
54+
] and value:
4055
creds[key] = value
4156
except Exception as e:
4257
logger.warning(f"Error during simple parsing of {file_path}: {e}")
@@ -142,15 +157,135 @@ def get_fireworks_api_base() -> str:
142157
"""
143158
Retrieves the Fireworks API base URL.
144159
145-
The base URL is sourced from the FIREWORKS_API_BASE environment variable.
146-
If not set, it defaults to "https://api.fireworks.ai".
160+
The base URL is sourced in the following order:
161+
1. FIREWORKS_API_BASE environment variable.
162+
2. 'api_base' from the [fireworks] section of ~/.fireworks/auth.ini (or simple key=val).
163+
3. Defaults to "https://api.fireworks.ai".
147164
148165
Returns:
149166
The API base URL.
150167
"""
151-
api_base = os.environ.get("FIREWORKS_API_BASE", "https://api.fireworks.ai")
152-
if os.environ.get("FIREWORKS_API_BASE"):
168+
env_api_base = os.environ.get("FIREWORKS_API_BASE")
169+
if env_api_base:
153170
logger.debug("Using FIREWORKS_API_BASE from environment variable.")
154-
else:
155-
logger.debug(f"FIREWORKS_API_BASE not set in environment, defaulting to {api_base}.")
156-
return api_base
171+
return env_api_base
172+
173+
file_api_base = _get_credential_from_config_file("api_base")
174+
if file_api_base:
175+
logger.debug("Using api_base from auth.ini configuration.")
176+
return file_api_base
177+
178+
default_base = "https://api.fireworks.ai"
179+
logger.debug(f"FIREWORKS_API_BASE not set; defaulting to {default_base}.")
180+
return default_base
181+
182+
183+
def _get_from_env_or_file(key_name: str) -> Optional[str]:
184+
# 1. Check env
185+
env_val = os.environ.get(key_name.upper())
186+
if env_val:
187+
return env_val
188+
# 2. Check config file
189+
return _get_credential_from_config_file(key_name.lower())
190+
191+
192+
def _write_auth_config(updates: Dict[str, str]) -> None:
193+
"""Merge-write simple key=value pairs into AUTH_INI_FILE preserving existing values."""
194+
FIREWORKS_CONFIG_DIR.mkdir(parents=True, exist_ok=True)
195+
existing = _parse_simple_auth_file(AUTH_INI_FILE)
196+
existing.update({k: v for k, v in updates.items() if v is not None})
197+
lines = [f"{k}={v}" for k, v in existing.items()]
198+
AUTH_INI_FILE.write_text("\n".join(lines) + "\n")
199+
try:
200+
os.chmod(AUTH_INI_FILE, 0o600)
201+
except Exception:
202+
pass
203+
204+
205+
def _discover_oidc(issuer: str) -> Dict[str, str]:
206+
"""Fetch OIDC discovery doc. Returns empty dict on failure."""
207+
try:
208+
url = issuer.rstrip("/") + "/.well-known/openid-configuration"
209+
resp = requests.get(url, timeout=10)
210+
if resp.ok:
211+
return resp.json()
212+
except Exception:
213+
return {}
214+
return {}
215+
216+
217+
def _refresh_oauth_token_if_needed() -> Optional[str]:
218+
"""Refresh OAuth access token if expired and refresh token available. Returns current/new token or None."""
219+
cfg = _parse_simple_auth_file(AUTH_INI_FILE)
220+
access_token = cfg.get("access_token")
221+
refresh_token = cfg.get("refresh_token")
222+
expires_at_str = cfg.get("expires_at")
223+
issuer = cfg.get("issuer") or os.environ.get("FIREWORKS_OIDC_ISSUER")
224+
client_id = cfg.get("client_id") or os.environ.get("FIREWORKS_OAUTH_CLIENT_ID")
225+
226+
# If we have no expiry, just return access token (best effort)
227+
if not refresh_token or not issuer or not client_id:
228+
return access_token
229+
230+
now = int(time.time())
231+
try:
232+
expires_at = int(expires_at_str) if expires_at_str else None
233+
except ValueError:
234+
expires_at = None
235+
236+
# If not expired (with 60s buffer), return current token
237+
if access_token and expires_at and expires_at - 60 > now:
238+
return access_token
239+
240+
# Attempt refresh
241+
discovery = _discover_oidc(issuer)
242+
token_endpoint = discovery.get("token_endpoint") or issuer.rstrip("/") + "/oauth/token"
243+
data = {
244+
"grant_type": "refresh_token",
245+
"refresh_token": refresh_token,
246+
"client_id": client_id,
247+
}
248+
try:
249+
resp = requests.post(token_endpoint, data=data, timeout=15)
250+
if not resp.ok:
251+
logger.warning(f"OAuth token refresh failed: {resp.status_code} {resp.text[:200]}")
252+
return access_token
253+
tok = resp.json()
254+
new_access = tok.get("access_token")
255+
new_refresh = tok.get("refresh_token") or refresh_token
256+
expires_in = tok.get("expires_in")
257+
new_expires_at = str(now + int(expires_in)) if expires_in else expires_at_str
258+
_write_auth_config(
259+
{
260+
"access_token": new_access,
261+
"refresh_token": new_refresh,
262+
"expires_at": new_expires_at,
263+
"token_type": tok.get("token_type") or cfg.get("token_type") or "Bearer",
264+
"scope": tok.get("scope") or cfg.get("scope") or "",
265+
}
266+
)
267+
return new_access or access_token
268+
except Exception as e:
269+
logger.debug(f"Exception during oauth refresh: {e}")
270+
return access_token
271+
272+
273+
def get_auth_bearer() -> Optional[str]:
274+
"""Return a bearer token to use in Authorization.
275+
276+
Priority:
277+
1. FIREWORKS_ACCESS_TOKEN env
278+
2. FIREWORKS_API_KEY env
279+
3. Refreshed OAuth access_token from auth.ini (if present)
280+
4. api_key from auth.ini
281+
"""
282+
env_access = os.environ.get("FIREWORKS_ACCESS_TOKEN")
283+
if env_access:
284+
return env_access
285+
env_key = os.environ.get("FIREWORKS_API_KEY")
286+
if env_key:
287+
return env_key
288+
refreshed = _refresh_oauth_token_if_needed()
289+
if refreshed:
290+
return refreshed
291+
return _get_credential_from_config_file("api_key")

eval_protocol/cli.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from .cli_commands.logs import logs_command
2929
from .cli_commands.preview import preview_command
3030
from .cli_commands.run_eval_cmd import hydra_cli_entry_point
31+
from .cli_commands.login import login_command
3132

3233

3334
def parse_args(args=None):
@@ -37,6 +38,30 @@ def parse_args(args=None):
3738

3839
subparsers = parser.add_subparsers(dest="command", help="Command to run")
3940

41+
# Login command
42+
login_parser = subparsers.add_parser(
43+
"login", help="Sign in to Fireworks via API key or OAuth2 device flow"
44+
)
45+
# API key flow
46+
login_parser.add_argument("--api-key", help="Fireworks API key (prompted if not provided)")
47+
# OAuth2 flow toggles
48+
login_parser.add_argument("--oauth", action="store_true", help="Use OAuth2 device flow (like firectl)")
49+
login_parser.add_argument("--issuer", help="OIDC issuer URL (e.g., https://auth.fireworks.ai)")
50+
login_parser.add_argument("--client-id", help="OAuth2 public client id for device flow")
51+
login_parser.add_argument(
52+
"--scope",
53+
help="OAuth2 scopes (default: 'openid offline_access email profile')",
54+
)
55+
login_parser.add_argument(
56+
"--open-browser", action="store_true", help="Attempt to open the verification URL in a browser"
57+
)
58+
# Common options
59+
login_parser.add_argument("--account-id", help="Fireworks Account ID to associate with this login")
60+
login_parser.add_argument("--api-base", help="Custom API base (defaults to https://api.fireworks.ai)")
61+
vgroup = login_parser.add_mutually_exclusive_group()
62+
vgroup.add_argument("--validate", action="store_true", help="Validate account with a test API call (API key flow)")
63+
vgroup.add_argument("--no-validate", action="store_true", help="Do not validate; just write the file")
64+
4065
# Preview command
4166
preview_parser = subparsers.add_parser("preview", help="Preview an evaluator with sample data")
4267
preview_parser.add_argument(
@@ -338,6 +363,10 @@ def main():
338363

339364
if args.command == "preview":
340365
return preview_command(args)
366+
elif args.command == "login":
367+
# translate mutually exclusive group into a single boolean
368+
setattr(args, "validate", bool(getattr(args, "validate", False) and not getattr(args, "no_validate", False)))
369+
return login_command(args)
341370
elif args.command == "deploy":
342371
return deploy_command(args)
343372
elif args.command == "deploy-mcp":

eval_protocol/cli_commands/common.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import os
88
from typing import Any, Dict, Iterator, List, Optional
99

10+
from eval_protocol.auth import get_auth_bearer
11+
1012
logger = logging.getLogger(__name__)
1113

1214

@@ -42,13 +44,22 @@ def setup_logging(verbose=False, debug=False):
4244

4345

4446
def check_environment():
45-
"""Check if required environment variables are set for general commands."""
46-
if not os.environ.get("FIREWORKS_API_KEY"):
47-
logger.warning("FIREWORKS_API_KEY environment variable is not set.")
48-
logger.warning("This is required for API calls. Set this variable before running the command.")
49-
logger.warning("Example: FIREWORKS_API_KEY=$DEV_FIREWORKS_API_KEY reward-kit [command]")
50-
return False
51-
return True
47+
"""Check if credentials are available for non-LLM API calls.
48+
49+
Accepts either FIREWORKS_API_KEY or an OAuth bearer (FIREWORKS_ACCESS_TOKEN or tokens in auth.ini).
50+
LLM calls elsewhere still explicitly require FIREWORKS_API_KEY.
51+
"""
52+
if os.environ.get("FIREWORKS_API_KEY"):
53+
return True
54+
bearer = get_auth_bearer()
55+
if bearer:
56+
if not os.environ.get("FIREWORKS_API_KEY"):
57+
logger.info(
58+
"Using OAuth bearer for non-LLM API calls. Note: LLM/model calls still require FIREWORKS_API_KEY."
59+
)
60+
return True
61+
logger.warning("No credentials found. Set FIREWORKS_API_KEY or login via OAuth: eval-protocol login --oauth ...")
62+
return False
5263

5364

5465
def check_agent_environment(test_mode=False):

0 commit comments

Comments
 (0)