|
2 | 2 | import logging |
3 | 3 | import os |
4 | 4 | from pathlib import Path |
| 5 | +import time |
5 | 6 | from typing import Dict, Optional # Added Dict |
6 | 7 |
|
| 8 | +import requests |
| 9 | + |
7 | 10 | logger = logging.getLogger(__name__) |
8 | 11 |
|
9 | 12 | FIREWORKS_CONFIG_DIR = Path.home() / ".fireworks" |
@@ -36,7 +39,19 @@ def _parse_simple_auth_file(file_path: Path) -> Dict[str, str]: |
36 | 39 | ): |
37 | 40 | value = value[1:-1] |
38 | 41 |
|
39 | | - if key in ["api_key", "account_id"] and value: |
| 42 | + if key in [ |
| 43 | + "api_key", |
| 44 | + "account_id", |
| 45 | + "api_base", |
| 46 | + # OAuth2-related keys |
| 47 | + "issuer", |
| 48 | + "client_id", |
| 49 | + "access_token", |
| 50 | + "refresh_token", |
| 51 | + "expires_at", |
| 52 | + "scope", |
| 53 | + "token_type", |
| 54 | + ] and value: |
40 | 55 | creds[key] = value |
41 | 56 | except Exception as e: |
42 | 57 | logger.warning(f"Error during simple parsing of {file_path}: {e}") |
@@ -142,15 +157,135 @@ def get_fireworks_api_base() -> str: |
142 | 157 | """ |
143 | 158 | Retrieves the Fireworks API base URL. |
144 | 159 |
|
145 | | - The base URL is sourced from the FIREWORKS_API_BASE environment variable. |
146 | | - If not set, it defaults to "https://api.fireworks.ai". |
| 160 | + The base URL is sourced in the following order: |
| 161 | + 1. FIREWORKS_API_BASE environment variable. |
| 162 | + 2. 'api_base' from the [fireworks] section of ~/.fireworks/auth.ini (or simple key=val). |
| 163 | + 3. Defaults to "https://api.fireworks.ai". |
147 | 164 |
|
148 | 165 | Returns: |
149 | 166 | The API base URL. |
150 | 167 | """ |
151 | | - api_base = os.environ.get("FIREWORKS_API_BASE", "https://api.fireworks.ai") |
152 | | - if os.environ.get("FIREWORKS_API_BASE"): |
| 168 | + env_api_base = os.environ.get("FIREWORKS_API_BASE") |
| 169 | + if env_api_base: |
153 | 170 | logger.debug("Using FIREWORKS_API_BASE from environment variable.") |
154 | | - else: |
155 | | - logger.debug(f"FIREWORKS_API_BASE not set in environment, defaulting to {api_base}.") |
156 | | - return api_base |
| 171 | + return env_api_base |
| 172 | + |
| 173 | + file_api_base = _get_credential_from_config_file("api_base") |
| 174 | + if file_api_base: |
| 175 | + logger.debug("Using api_base from auth.ini configuration.") |
| 176 | + return file_api_base |
| 177 | + |
| 178 | + default_base = "https://api.fireworks.ai" |
| 179 | + logger.debug(f"FIREWORKS_API_BASE not set; defaulting to {default_base}.") |
| 180 | + return default_base |
| 181 | + |
| 182 | + |
| 183 | +def _get_from_env_or_file(key_name: str) -> Optional[str]: |
| 184 | + # 1. Check env |
| 185 | + env_val = os.environ.get(key_name.upper()) |
| 186 | + if env_val: |
| 187 | + return env_val |
| 188 | + # 2. Check config file |
| 189 | + return _get_credential_from_config_file(key_name.lower()) |
| 190 | + |
| 191 | + |
| 192 | +def _write_auth_config(updates: Dict[str, str]) -> None: |
| 193 | + """Merge-write simple key=value pairs into AUTH_INI_FILE preserving existing values.""" |
| 194 | + FIREWORKS_CONFIG_DIR.mkdir(parents=True, exist_ok=True) |
| 195 | + existing = _parse_simple_auth_file(AUTH_INI_FILE) |
| 196 | + existing.update({k: v for k, v in updates.items() if v is not None}) |
| 197 | + lines = [f"{k}={v}" for k, v in existing.items()] |
| 198 | + AUTH_INI_FILE.write_text("\n".join(lines) + "\n") |
| 199 | + try: |
| 200 | + os.chmod(AUTH_INI_FILE, 0o600) |
| 201 | + except Exception: |
| 202 | + pass |
| 203 | + |
| 204 | + |
| 205 | +def _discover_oidc(issuer: str) -> Dict[str, str]: |
| 206 | + """Fetch OIDC discovery doc. Returns empty dict on failure.""" |
| 207 | + try: |
| 208 | + url = issuer.rstrip("/") + "/.well-known/openid-configuration" |
| 209 | + resp = requests.get(url, timeout=10) |
| 210 | + if resp.ok: |
| 211 | + return resp.json() |
| 212 | + except Exception: |
| 213 | + return {} |
| 214 | + return {} |
| 215 | + |
| 216 | + |
| 217 | +def _refresh_oauth_token_if_needed() -> Optional[str]: |
| 218 | + """Refresh OAuth access token if expired and refresh token available. Returns current/new token or None.""" |
| 219 | + cfg = _parse_simple_auth_file(AUTH_INI_FILE) |
| 220 | + access_token = cfg.get("access_token") |
| 221 | + refresh_token = cfg.get("refresh_token") |
| 222 | + expires_at_str = cfg.get("expires_at") |
| 223 | + issuer = cfg.get("issuer") or os.environ.get("FIREWORKS_OIDC_ISSUER") |
| 224 | + client_id = cfg.get("client_id") or os.environ.get("FIREWORKS_OAUTH_CLIENT_ID") |
| 225 | + |
| 226 | + # If we have no expiry, just return access token (best effort) |
| 227 | + if not refresh_token or not issuer or not client_id: |
| 228 | + return access_token |
| 229 | + |
| 230 | + now = int(time.time()) |
| 231 | + try: |
| 232 | + expires_at = int(expires_at_str) if expires_at_str else None |
| 233 | + except ValueError: |
| 234 | + expires_at = None |
| 235 | + |
| 236 | + # If not expired (with 60s buffer), return current token |
| 237 | + if access_token and expires_at and expires_at - 60 > now: |
| 238 | + return access_token |
| 239 | + |
| 240 | + # Attempt refresh |
| 241 | + discovery = _discover_oidc(issuer) |
| 242 | + token_endpoint = discovery.get("token_endpoint") or issuer.rstrip("/") + "/oauth/token" |
| 243 | + data = { |
| 244 | + "grant_type": "refresh_token", |
| 245 | + "refresh_token": refresh_token, |
| 246 | + "client_id": client_id, |
| 247 | + } |
| 248 | + try: |
| 249 | + resp = requests.post(token_endpoint, data=data, timeout=15) |
| 250 | + if not resp.ok: |
| 251 | + logger.warning(f"OAuth token refresh failed: {resp.status_code} {resp.text[:200]}") |
| 252 | + return access_token |
| 253 | + tok = resp.json() |
| 254 | + new_access = tok.get("access_token") |
| 255 | + new_refresh = tok.get("refresh_token") or refresh_token |
| 256 | + expires_in = tok.get("expires_in") |
| 257 | + new_expires_at = str(now + int(expires_in)) if expires_in else expires_at_str |
| 258 | + _write_auth_config( |
| 259 | + { |
| 260 | + "access_token": new_access, |
| 261 | + "refresh_token": new_refresh, |
| 262 | + "expires_at": new_expires_at, |
| 263 | + "token_type": tok.get("token_type") or cfg.get("token_type") or "Bearer", |
| 264 | + "scope": tok.get("scope") or cfg.get("scope") or "", |
| 265 | + } |
| 266 | + ) |
| 267 | + return new_access or access_token |
| 268 | + except Exception as e: |
| 269 | + logger.debug(f"Exception during oauth refresh: {e}") |
| 270 | + return access_token |
| 271 | + |
| 272 | + |
| 273 | +def get_auth_bearer() -> Optional[str]: |
| 274 | + """Return a bearer token to use in Authorization. |
| 275 | +
|
| 276 | + Priority: |
| 277 | + 1. FIREWORKS_ACCESS_TOKEN env |
| 278 | + 2. FIREWORKS_API_KEY env |
| 279 | + 3. Refreshed OAuth access_token from auth.ini (if present) |
| 280 | + 4. api_key from auth.ini |
| 281 | + """ |
| 282 | + env_access = os.environ.get("FIREWORKS_ACCESS_TOKEN") |
| 283 | + if env_access: |
| 284 | + return env_access |
| 285 | + env_key = os.environ.get("FIREWORKS_API_KEY") |
| 286 | + if env_key: |
| 287 | + return env_key |
| 288 | + refreshed = _refresh_oauth_token_if_needed() |
| 289 | + if refreshed: |
| 290 | + return refreshed |
| 291 | + return _get_credential_from_config_file("api_key") |
0 commit comments