Security error. Please try again.
" + ) + elif not auth_code_future.done(): + auth_code_future.set_result(query_params["code"][0]) + writer.write( + b"You can close this window.
" + ) + else: + error = query_params.get("error", ["Unknown error"])[0] + if not auth_code_future.done(): + auth_code_future.set_exception(Exception(f"OAuth failed: {error}")) + writer.write( + f"Error: {error}
".encode() + ) + + await writer.drain() + except Exception as e: + lib_logger.error(f"Error in OAuth callback handler: {e}") + finally: + writer.close() + + try: + server = await asyncio.start_server( + handle_callback, "127.0.0.1", self.callback_port + ) + + from urllib.parse import urlencode + + redirect_uri = f"http://localhost:{self.callback_port}{self.CALLBACK_PATH}" + + auth_params = { + "response_type": "code", + "client_id": self.CLIENT_ID, + "redirect_uri": redirect_uri, + "scope": " ".join(self.OAUTH_SCOPES), + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "state": state, + "id_token_add_organizations": "true", + "codex_cli_simplified_flow": "true", + } + + auth_url = f"{self.AUTH_URL}?" + urlencode(auth_params) + + if is_headless: + auth_panel_text = Text.from_markup( + "Running in headless environment (no GUI detected).\n" + "Please open the URL below in a browser on another machine to authorize:\n" + ) + else: + auth_panel_text = Text.from_markup( + "1. Your browser will now open to log in and authorize the application.\n" + "2. If it doesn't open automatically, please open the URL below manually." + ) + + console.print( + Panel( + auth_panel_text, + title=f"{self.ENV_PREFIX} OAuth Setup for [bold yellow]{display_name}[/bold yellow]", + style="bold blue", + ) + ) + + escaped_url = rich_escape(auth_url) + console.print(f"[bold]URL:[/bold] [link={auth_url}]{escaped_url}[/link]\n") + + if not is_headless: + try: + webbrowser.open(auth_url) + lib_logger.info("Browser opened successfully for OAuth flow") + except Exception as e: + lib_logger.warning( + f"Failed to open browser automatically: {e}. Please open the URL manually." + ) + + with console.status( + "[bold green]Waiting for you to complete authentication in the browser...[/bold green]", + spinner="dots", + ): + auth_code = await asyncio.wait_for(auth_code_future, timeout=310) + + except asyncio.TimeoutError: + raise Exception("OAuth flow timed out. 
Please try again.") + finally: + if server: + server.close() + await server.wait_closed() + + lib_logger.info("Exchanging authorization code for tokens...") + + async with httpx.AsyncClient() as client: + redirect_uri = f"http://localhost:{self.callback_port}{self.CALLBACK_PATH}" + + response = await client.post( + self.TOKEN_URL, + data={ + "grant_type": "authorization_code", + "code": auth_code.strip(), + "client_id": self.CLIENT_ID, + "code_verifier": code_verifier, + "redirect_uri": redirect_uri, + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + response.raise_for_status() + token_data = response.json() + + # Build credentials + new_creds = { + "access_token": token_data.get("access_token"), + "refresh_token": token_data.get("refresh_token"), + "id_token": token_data.get("id_token"), + "expiry_date": time.time() + token_data.get("expires_in", 3600), + } + + # Parse ID token for claims + id_token_claims = _parse_jwt_claims(token_data.get("id_token", "")) or {} + access_token_claims = _parse_jwt_claims(token_data.get("access_token", "")) or {} + + # Extract account ID and email + auth_claims = id_token_claims.get("https://api.openai.com/auth", {}) + account_id = auth_claims.get("chatgpt_account_id", "") + org_id = id_token_claims.get("organization_id") + project_id = id_token_claims.get("project_id") + + email = id_token_claims.get("email", "") + plan_type = access_token_claims.get("chatgpt_plan_type", "") + + new_creds["account_id"] = account_id + + # Try to exchange for API key if we have org and project + api_key = None + if org_id and project_id: + try: + api_key = await self._exchange_for_api_key( + client, token_data.get("id_token", "") + ) + new_creds["api_key"] = api_key + except Exception as e: + lib_logger.warning(f"API key exchange failed: {e}") + + new_creds["_proxy_metadata"] = { + "email": email, + "account_id": account_id, + "org_id": org_id, + "project_id": project_id, + "plan_type": plan_type, + "last_check_timestamp": 
time.time(), + } + + if path: + await self._save_credentials(path, new_creds) + + lib_logger.info( + f"{self.ENV_PREFIX} OAuth initialized successfully for '{display_name}'." + ) + + return new_creds + + async def _exchange_for_api_key( + self, client: httpx.AsyncClient, id_token: str + ) -> Optional[str]: + """ + Exchange ID token for an OpenAI API key. + + Uses the token exchange grant type to get a persistent API key. + """ + import datetime + + today = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d") + + response = await client.post( + self.TOKEN_URL, + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "client_id": self.CLIENT_ID, + "requested_token": "openai-api-key", + "subject_token": id_token, + "subject_token_type": "urn:ietf:params:oauth:token-type:id_token", + "name": f"LLM-API-Key-Proxy [auto-generated] ({today})", + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + response.raise_for_status() + exchange_data = response.json() + + return exchange_data.get("access_token") + + async def initialize_token( + self, + creds_or_path: Union[Dict[str, Any], str], + force_interactive: bool = False, + ) -> Dict[str, Any]: + """Initialize OAuth token, triggering interactive OAuth flow if needed.""" + path = creds_or_path if isinstance(creds_or_path, str) else None + + if isinstance(creds_or_path, dict): + display_name = creds_or_path.get("_proxy_metadata", {}).get( + "display_name", "in-memory object" + ) + else: + display_name = Path(path).name if path else "in-memory object" + + lib_logger.debug(f"Initializing {self.ENV_PREFIX} token for '{display_name}'...") + + try: + creds = ( + await self._load_credentials(creds_or_path) if path else creds_or_path + ) + reason = "" + + if force_interactive: + reason = "re-authentication was explicitly requested" + elif not creds.get("refresh_token") and not creds.get("api_key"): + reason = "refresh token and API key are missing" + elif 
self._is_token_expired(creds) and not creds.get("api_key"): + reason = "token is expired" + + if reason: + if reason == "token is expired" and creds.get("refresh_token"): + try: + return await self._refresh_token(path, creds) + except Exception as e: + lib_logger.warning( + f"Automatic token refresh for '{display_name}' failed: {e}. Proceeding to interactive login." + ) + + lib_logger.warning( + f"{self.ENV_PREFIX} OAuth token for '{display_name}' needs setup: {reason}." + ) + + coordinator = get_reauth_coordinator() + + async def _do_interactive_oauth(): + return await self._perform_interactive_oauth(path, creds, display_name) + + return await coordinator.execute_reauth( + credential_path=path or display_name, + provider_name=self.ENV_PREFIX, + reauth_func=_do_interactive_oauth, + timeout=300.0, + ) + + lib_logger.info(f"{self.ENV_PREFIX} OAuth token at '{display_name}' is valid.") + return creds + + except Exception as e: + raise ValueError( + f"Failed to initialize {self.ENV_PREFIX} OAuth for '{path}': {e}" + ) + + async def get_auth_header(self, credential_path: str) -> Dict[str, str]: + """Get auth header with graceful degradation if refresh fails.""" + try: + creds = await self._load_credentials(credential_path) + + # Prefer API key if available + if creds.get("api_key"): + return {"Authorization": f"Bearer {creds['api_key']}"} + + # Fall back to access token + if self._is_token_expired(creds): + try: + creds = await self._refresh_token(credential_path, creds) + except Exception as e: + cached = self._credentials_cache.get(credential_path) + if cached and (cached.get("access_token") or cached.get("api_key")): + lib_logger.warning( + f"Token refresh failed for {Path(credential_path).name}: {e}. " + "Using cached token." 
+ ) + creds = cached + else: + raise + + token = creds.get("api_key") or creds.get("access_token") + return {"Authorization": f"Bearer {token}"} + + except Exception as e: + cached = self._credentials_cache.get(credential_path) + if cached and (cached.get("access_token") or cached.get("api_key")): + lib_logger.error( + f"Credential load failed for {credential_path}: {e}. Using stale cached token." + ) + token = cached.get("api_key") or cached.get("access_token") + return {"Authorization": f"Bearer {token}"} + raise + + async def get_account_id(self, credential_path: str) -> Optional[str]: + """Get the ChatGPT account ID for a credential.""" + creds = await self._load_credentials(credential_path) + return creds.get("account_id") or creds.get("_proxy_metadata", {}).get("account_id") + + async def proactively_refresh(self, credential_path: str): + """Proactively refresh a credential by queueing it for refresh.""" + creds = await self._load_credentials(credential_path) + if self._is_token_expired(creds) and not creds.get("api_key"): + await self._queue_refresh(credential_path, force=False, needs_reauth=False) + + # ========================================================================= + # CREDENTIAL MANAGEMENT METHODS + # ========================================================================= + + def _get_provider_file_prefix(self) -> str: + """Get the file name prefix for this provider's credential files.""" + return self.ENV_PREFIX.lower() + + def _get_oauth_base_dir(self) -> Path: + """Get the base directory for OAuth credential files.""" + return Path.cwd() / "oauth_creds" + + def _find_existing_credential_by_email( + self, email: str, base_dir: Optional[Path] = None + ) -> Optional[Path]: + """Find an existing credential file for the given email.""" + if base_dir is None: + base_dir = self._get_oauth_base_dir() + + prefix = self._get_provider_file_prefix() + pattern = str(base_dir / f"{prefix}_oauth_*.json") + + for cred_file in glob(pattern): + try: + with 
open(cred_file, "r") as f: + creds = json.load(f) + existing_email = creds.get("_proxy_metadata", {}).get("email") + if existing_email == email: + return Path(cred_file) + except Exception: + continue + + return None + + def _get_next_credential_number(self, base_dir: Optional[Path] = None) -> int: + """Get the next available credential number.""" + if base_dir is None: + base_dir = self._get_oauth_base_dir() + + prefix = self._get_provider_file_prefix() + pattern = str(base_dir / f"{prefix}_oauth_*.json") + + existing_numbers = [] + for cred_file in glob(pattern): + match = re.search(r"_oauth_(\d+)\.json$", cred_file) + if match: + existing_numbers.append(int(match.group(1))) + + if not existing_numbers: + return 1 + return max(existing_numbers) + 1 + + def _build_credential_path( + self, base_dir: Optional[Path] = None, number: Optional[int] = None + ) -> Path: + """Build a path for a new credential file.""" + if base_dir is None: + base_dir = self._get_oauth_base_dir() + + if number is None: + number = self._get_next_credential_number(base_dir) + + prefix = self._get_provider_file_prefix() + filename = f"{prefix}_oauth_{number}.json" + return base_dir / filename + + async def setup_credential( + self, base_dir: Optional[Path] = None + ) -> CredentialSetupResult: + """Complete credential setup flow: OAuth -> save -> discovery.""" + if base_dir is None: + base_dir = self._get_oauth_base_dir() + + base_dir.mkdir(exist_ok=True) + + try: + temp_creds = { + "_proxy_metadata": {"display_name": f"new {self.ENV_PREFIX} credential"} + } + new_creds = await self.initialize_token(temp_creds) + + email = new_creds.get("_proxy_metadata", {}).get("email") + + if not email: + return CredentialSetupResult( + success=False, error="Could not retrieve email from OAuth response" + ) + + existing_path = self._find_existing_credential_by_email(email, base_dir) + is_update = existing_path is not None + + if is_update: + file_path = existing_path + else: + file_path = 
self._build_credential_path(base_dir) + + await self._save_credentials(str(file_path), new_creds) + + account_id = new_creds.get("account_id") or new_creds.get( + "_proxy_metadata", {} + ).get("account_id") + + return CredentialSetupResult( + success=True, + file_path=str(file_path), + email=email, + account_id=account_id, + is_update=is_update, + credentials=new_creds, + ) + + except Exception as e: + lib_logger.error(f"Credential setup failed: {e}") + return CredentialSetupResult(success=False, error=str(e)) + + def list_credentials(self, base_dir: Optional[Path] = None) -> List[Dict[str, Any]]: + """List all credential files for this provider.""" + if base_dir is None: + base_dir = self._get_oauth_base_dir() + + prefix = self._get_provider_file_prefix() + pattern = str(base_dir / f"{prefix}_oauth_*.json") + + credentials = [] + for cred_file in sorted(glob(pattern)): + try: + with open(cred_file, "r") as f: + creds = json.load(f) + + metadata = creds.get("_proxy_metadata", {}) + + match = re.search(r"_oauth_(\d+)\.json$", cred_file) + number = int(match.group(1)) if match else 0 + + credentials.append({ + "file_path": cred_file, + "email": metadata.get("email", "unknown"), + "account_id": creds.get("account_id") or metadata.get("account_id"), + "number": number, + }) + except Exception: + continue + + return credentials diff --git a/src/rotator_library/providers/qwen_auth_base.py b/src/rotator_library/providers/qwen_auth_base.py index f31ead3c..146a274b 100644 --- a/src/rotator_library/providers/qwen_auth_base.py +++ b/src/rotator_library/providers/qwen_auth_base.py @@ -342,11 +342,17 @@ async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any] if not force and cached_creds and not self._is_token_expired(cached_creds): return cached_creds - # [ROTATING TOKEN FIX] Always read fresh from disk before refresh. + # [ROTATING TOKEN FIX] Read fresh credentials before refresh. 
# Qwen uses rotating refresh tokens - each refresh invalidates the previous token. # If we use a stale cached token, refresh will fail with HTTP 400. - # Reading fresh from disk ensures we have the latest token. - await self._read_creds_from_file(path) + if not path.startswith("env://"): + # For file paths, read fresh from disk to pick up tokens that may have + # been updated by another process or a previous refresh cycle. + await self._read_creds_from_file(path) + # For env:// paths, the in-memory cache is the single source of truth. + # _save_credentials updates the cache after each refresh, so the cache + # always holds the latest rotating tokens. Re-reading from static env vars + # would discard the rotated refresh_token and break subsequent refreshes. creds_from_file = self._credentials_cache[path] lib_logger.debug(f"Refreshing Qwen OAuth token for '{Path(path).name}'...") @@ -524,15 +530,22 @@ async def get_api_details(self, credential_identifier: str) -> Tuple[str, str]: """ Returns the API base URL and access token. - Supports both credential types: - - OAuth: credential_identifier is a file path to JSON credentials - - API Key: credential_identifier is the API key string itself + Supports three credential types: + - OAuth file: credential_identifier is a file path to JSON credentials + - env:// virtual path: credential_identifier is "env://provider/index" + - Direct API key: credential_identifier is the API key string itself """ - # Detect credential type - if os.path.isfile(credential_identifier): - # OAuth credential: file path to JSON + try: + is_oauth = credential_identifier.startswith("env://") or os.path.isfile( + credential_identifier + ) + except (OSError, ValueError): + # os.path.isfile can raise on invalid path strings (e.g. 
very long API keys) + is_oauth = False + + if is_oauth: lib_logger.debug( - f"Using OAuth credentials from file: {credential_identifier}" + f"Using OAuth credentials from: {credential_identifier}" ) creds = await self._load_credentials(credential_identifier) diff --git a/src/rotator_library/providers/utilities/__init__.py b/src/rotator_library/providers/utilities/__init__.py index 7efe9f25..c0314831 100644 --- a/src/rotator_library/providers/utilities/__init__.py +++ b/src/rotator_library/providers/utilities/__init__.py @@ -4,6 +4,7 @@ # Utilities for provider implementations from .base_quota_tracker import BaseQuotaTracker from .antigravity_quota_tracker import AntigravityQuotaTracker +from .anthropic_quota_tracker import AnthropicQuotaTracker from .gemini_cli_quota_tracker import GeminiCliQuotaTracker # Shared utilities for Gemini-based providers @@ -38,6 +39,7 @@ # Quota trackers "BaseQuotaTracker", "AntigravityQuotaTracker", + "AnthropicQuotaTracker", "GeminiCliQuotaTracker", # Shared utilities "env_bool", diff --git a/src/rotator_library/providers/utilities/anthropic_quota_tracker.py b/src/rotator_library/providers/utilities/anthropic_quota_tracker.py new file mode 100644 index 00000000..4e2762b4 --- /dev/null +++ b/src/rotator_library/providers/utilities/anthropic_quota_tracker.py @@ -0,0 +1,494 @@ +# src/rotator_library/providers/utilities/anthropic_quota_tracker.py +""" +Anthropic Quota Tracking Mixin + +Provides quota tracking functionality for the Anthropic provider by: +1. Fetching utilization data from the /api/oauth/usage endpoint +2. Caching quota snapshots per credential +3. Pushing quota data to UsageManager for TUI and /quota-stats display + +Anthropic OAuth Usage API Response: +{ + "five_hour": { "utilization": 23.0, "resets_at": "ISO8601" }, + "seven_day": { "utilization": 15.0, "resets_at": "ISO8601" } | null, + ... 
+} + +Required from provider: + - self._credentials_cache: Dict[str, Dict[str, Any]] + - self.get_anthropic_auth_header(credential_path) -> Dict[str, str] +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +import httpx + +if TYPE_CHECKING: + from ...usage import UsageManager + +lib_logger = logging.getLogger("rotator_library") + + +# ============================================================================= +# CONFIGURATION +# ============================================================================= + +ANTHROPIC_USAGE_URL = "https://api.anthropic.com/api/oauth/usage" +ANTHROPIC_BETA_HEADER = "oauth-2025-04-20" + +# Stale threshold - snapshots older than this are considered stale (10 minutes) +QUOTA_STALE_THRESHOLD_SECONDS = 600 + + +# ============================================================================= +# HELPERS +# ============================================================================= + + +def _get_credential_identifier(credential_path: str) -> str: + """Extract a short identifier from a credential path.""" + if credential_path.startswith("env://"): + return credential_path + return Path(credential_path).name + + +def _parse_iso_timestamp(iso_string: str) -> Optional[float]: + """Parse an ISO 8601 timestamp to Unix timestamp in seconds.""" + try: + dt = datetime.fromisoformat(iso_string.replace("Z", "+00:00")) + return dt.timestamp() + except (ValueError, TypeError): + return None + + + + +# ============================================================================= +# DATA CLASSES +# ============================================================================= + + +@dataclass +class AnthropicQuotaWindow: + """A single quota window (e.g., 5-hour or 7-day).""" + + utilization: float # Percentage used (0-100) + 
resets_at: Optional[float] = None # Unix timestamp + + @property + def remaining_percent(self) -> float: + """Remaining quota as percentage (0-100).""" + return max(0.0, 100.0 - self.utilization) + + @property + def is_exhausted(self) -> bool: + """Check if quota is fully used.""" + return self.utilization >= 100.0 + + +@dataclass +class AnthropicQuotaSnapshot: + """Complete quota snapshot for an Anthropic credential.""" + + credential_path: str + identifier: str + + # From /api/oauth/usage endpoint + five_hour: Optional[AnthropicQuotaWindow] = None + seven_day: Optional[AnthropicQuotaWindow] = None + + fetched_at: float = field(default_factory=time.time) + status: str = "success" # "success", "error", "no_data" + error: Optional[str] = None + + @property + def is_stale(self) -> bool: + """Check if this snapshot is stale.""" + return time.time() - self.fetched_at > QUOTA_STALE_THRESHOLD_SECONDS + + def to_dict(self) -> Dict[str, Any]: + """Convert to dict for JSON serialization.""" + result: Dict[str, Any] = { + "identifier": self.identifier, + "fetched_at": self.fetched_at, + "is_stale": self.is_stale, + "status": self.status, + } + + if self.five_hour: + result["five_hour"] = { + "utilization": self.five_hour.utilization, + "remaining_percent": self.five_hour.remaining_percent, + "resets_at": self.five_hour.resets_at, + "is_exhausted": self.five_hour.is_exhausted, + } + + if self.seven_day: + result["seven_day"] = { + "utilization": self.seven_day.utilization, + "remaining_percent": self.seven_day.remaining_percent, + "resets_at": self.seven_day.resets_at, + "is_exhausted": self.seven_day.is_exhausted, + } + + + if self.error: + result["error"] = self.error + + return result + + +# ============================================================================= +# QUOTA TRACKER MIXIN +# ============================================================================= + + +class AnthropicQuotaTracker: + """ + Mixin class providing quota tracking functionality for 
Anthropic provider. + + Capabilities: + - Fetch quota utilization from /api/oauth/usage endpoint + - Cache quota snapshots per credential + - Push quota data to UsageManager for TUI display + + Usage: + class AnthropicProvider(AnthropicOAuthBase, AnthropicQuotaTracker, ProviderInterface): + ... + + The provider class must call self._init_quota_tracker() in __init__. + """ + + # Type hints for attributes from provider + _credentials_cache: Dict[str, Dict[str, Any]] + _quota_cache: Dict[str, AnthropicQuotaSnapshot] + _quota_refresh_interval: int + + def _init_quota_tracker(self) -> None: + """Initialize quota tracker state. Call from provider's __init__.""" + self._quota_cache: Dict[str, AnthropicQuotaSnapshot] = {} + self._quota_refresh_interval: int = 300 # 5 min default + self._usage_manager: Optional["UsageManager"] = None + + def set_usage_manager(self, usage_manager: "UsageManager") -> None: + """Set the UsageManager reference for pushing quota updates.""" + self._usage_manager = usage_manager + + # ========================================================================= + # API-BASED QUOTA FETCH + # ========================================================================= + + async def fetch_quota_from_api( + self, + credential_path: str, + ) -> AnthropicQuotaSnapshot: + """ + Fetch quota utilization from the Anthropic /api/oauth/usage endpoint. 
+ + Args: + credential_path: Path to OAuth credential file + + Returns: + AnthropicQuotaSnapshot with utilization data + """ + identifier = _get_credential_identifier(credential_path) + + try: + # Get auth header from the OAuth base class + auth_headers = await self.get_anthropic_auth_header(credential_path) + + async with httpx.AsyncClient() as client: + response = await client.get( + ANTHROPIC_USAGE_URL, + headers={ + **auth_headers, + "anthropic-beta": ANTHROPIC_BETA_HEADER, + }, + timeout=5.0, + ) + + if response.status_code != 200: + lib_logger.debug( + f"Anthropic usage API returned {response.status_code} " + f"for {identifier}: {response.text[:200]}" + ) + return AnthropicQuotaSnapshot( + credential_path=credential_path, + identifier=identifier, + status="error", + error=f"HTTP {response.status_code}", + ) + + data = response.json() + + # Parse five_hour window + five_hour = None + fh_data = data.get("five_hour") + if fh_data and isinstance(fh_data, dict): + utilization = fh_data.get("utilization") + if utilization is not None: + resets_at = None + if fh_data.get("resets_at"): + resets_at = _parse_iso_timestamp(fh_data["resets_at"]) + five_hour = AnthropicQuotaWindow( + utilization=float(utilization), + resets_at=resets_at, + ) + + # Parse seven_day window + seven_day = None + sd_data = data.get("seven_day") + if sd_data and isinstance(sd_data, dict): + utilization = sd_data.get("utilization") + if utilization is not None: + resets_at = None + if sd_data.get("resets_at"): + resets_at = _parse_iso_timestamp(sd_data["resets_at"]) + seven_day = AnthropicQuotaWindow( + utilization=float(utilization), + resets_at=resets_at, + ) + + snapshot = AnthropicQuotaSnapshot( + credential_path=credential_path, + identifier=identifier, + five_hour=five_hour, + seven_day=seven_day, + status="success", + ) + + # Log + parts = [] + if five_hour: + parts.append(f"5h={five_hour.utilization:.0f}%") + if seven_day: + parts.append(f"7d={seven_day.utilization:.0f}%") + 
lib_logger.debug( + f"Anthropic usage API ({identifier}): {', '.join(parts) or 'no windows'}" + ) + + # Cache and push + self._quota_cache[credential_path] = snapshot + if self._usage_manager: + self._push_quota_to_usage_manager(credential_path, snapshot) + + return snapshot + + except Exception as e: + lib_logger.debug( + f"Failed to fetch Anthropic usage for {identifier}: {e}" + ) + return AnthropicQuotaSnapshot( + credential_path=credential_path, + identifier=identifier, + status="error", + error=str(e), + ) + + + # ========================================================================= + # USAGE MANAGER INTEGRATION + # ========================================================================= + + def _push_quota_to_usage_manager( + self, + credential_path: str, + snapshot: AnthropicQuotaSnapshot, + ) -> None: + """ + Push quota snapshot to the UsageManager. + + Follows the Codex pattern: treats utilization percentage as + quota_used on a 100-scale (quota_max_requests=100). + """ + if not self._usage_manager: + return + + try: + loop = asyncio.get_event_loop() + except RuntimeError: + return + + async def _push() -> None: + try: + if snapshot.five_hour: + quota_used = int(snapshot.five_hour.utilization) + await self._usage_manager.update_quota_baseline( + accessor=credential_path, + model="anthropic/_5h_window", + quota_max_requests=100, + quota_reset_ts=snapshot.five_hour.resets_at, + quota_used=quota_used, + quota_group="5h-limit", + force=True, + apply_exhaustion=snapshot.five_hour.is_exhausted, + ) + + if snapshot.seven_day: + quota_used = int(snapshot.seven_day.utilization) + await self._usage_manager.update_quota_baseline( + accessor=credential_path, + model="anthropic/_weekly_window", + quota_max_requests=100, + quota_reset_ts=snapshot.seven_day.resets_at, + quota_used=quota_used, + quota_group="weekly-limit", + force=True, + apply_exhaustion=snapshot.seven_day.is_exhausted, + ) + except Exception as e: + lib_logger.debug( + f"Failed to push Anthropic 
quota to UsageManager: {e}" + ) + + if loop.is_running(): + asyncio.ensure_future(_push()) + else: + loop.run_until_complete(_push()) + + # ========================================================================= + # BACKGROUND JOB SUPPORT + # ========================================================================= + + def get_background_job_config(self) -> Optional[Dict[str, Any]]: + """ + Return configuration for quota refresh background job. + + Returns: + Background job config dict + """ + return { + "interval": self._quota_refresh_interval, + "name": "anthropic_quota_refresh", + "run_on_start": True, + } + + async def run_background_job( + self, + usage_manager: "UsageManager", + credentials: List[str], + ) -> None: + """ + Execute periodic quota refresh for active credentials. + + Called by BackgroundRefresher at the configured interval. + + Args: + usage_manager: UsageManager instance + credentials: List of credential paths for this provider + """ + if usage_manager and not self._usage_manager: + self._usage_manager = usage_manager + + if not credentials: + return + + # Filter to OAuth credentials only + oauth_creds = [c for c in credentials if _is_oauth_path(c)] + + if not oauth_creds: + lib_logger.debug("No OAuth Anthropic credentials to refresh quota for") + return + + lib_logger.debug( + f"Refreshing Anthropic quota for {len(oauth_creds)} OAuth credentials" + ) + + # Fetch quotas with limited concurrency + semaphore = asyncio.Semaphore(3) + + async def fetch_with_semaphore(cred_path: str): + async with semaphore: + return await self.fetch_quota_from_api(cred_path) + + tasks = [fetch_with_semaphore(cred) for cred in oauth_creds] + results = await asyncio.gather(*tasks, return_exceptions=True) + + success_count = sum( + 1 + for r in results + if isinstance(r, AnthropicQuotaSnapshot) and r.status == "success" + ) + + lib_logger.debug( + f"Anthropic quota refresh complete: {success_count}/{len(oauth_creds)} successful" + ) + + # 
========================================================================= + # CACHE ACCESS + # ========================================================================= + + def get_cached_quota( + self, + credential_path: str, + ) -> Optional[AnthropicQuotaSnapshot]: + """Get cached quota snapshot for a credential.""" + return self._quota_cache.get(credential_path) + + # ========================================================================= + # QUOTA INFO AGGREGATION (for /quota-stats) + # ========================================================================= + + def get_all_quota_info( + self, + credential_paths: List[str], + ) -> Dict[str, Any]: + """ + Get cached quota info for all credentials. + + Args: + credential_paths: List of credential paths to report on + + Returns: + Structured quota info dict for /quota-stats endpoint + """ + results = {} + exhausted_count = 0 + + for cred_path in credential_paths: + identifier = _get_credential_identifier(cred_path) + cached = self._quota_cache.get(cred_path) + + if cached: + entry = cached.to_dict() + entry["file_path"] = ( + cred_path if not cred_path.startswith("env://") else None + ) + if cached.five_hour and cached.five_hour.is_exhausted: + exhausted_count += 1 + else: + entry = { + "identifier": identifier, + "file_path": ( + cred_path if not cred_path.startswith("env://") else None + ), + "status": "no_data", + "fetched_at": None, + "is_stale": True, + } + + results[identifier] = entry + + return { + "credentials": results, + "summary": { + "total_credentials": len(credential_paths), + "exhausted_count": exhausted_count, + "data_source": "oauth_usage_api", + }, + "timestamp": time.time(), + } + + +def _is_oauth_path(path: str) -> bool: + """Check if a credential path is for an OAuth credential.""" + return "oauth" in path.lower() or path.startswith("env://anthropic/") diff --git a/src/rotator_library/providers/utilities/antigravity_quota_tracker.py 
b/src/rotator_library/providers/utilities/antigravity_quota_tracker.py index e9711bce..5a2cbc8d 100644 --- a/src/rotator_library/providers/utilities/antigravity_quota_tracker.py +++ b/src/rotator_library/providers/utilities/antigravity_quota_tracker.py @@ -32,6 +32,7 @@ import httpx from .base_quota_tracker import BaseQuotaTracker, QUOTA_DISCOVERY_DELAY_SECONDS +from .gemini_shared_utils import is_paid_tier, normalize_tier_name if TYPE_CHECKING: from ...usage import UsageManager @@ -104,11 +105,41 @@ # Gemini 2.5 Pro - UNVERIFIED/UNUSED (assumed 0.1% = 1000 requests) "gemini-2.5-pro": 1, }, + # ULTRA tier - estimated ~5x PRO for premium models (seed values). + # These are provisional starting points that will be automatically + # overridden by dynamic learning from observed API fraction changes. + "ULTRA": { + # Claude/GPT-OSS group (~5x PRO: 750 requests) + "claude-sonnet-4-5": 750, + "claude-sonnet-4-5-thinking": 750, + "claude-opus-4-5": 750, + "claude-opus-4-5-thinking": 750, + "claude-opus-4-6": 750, + "claude-opus-4-6-thinking": 750, + "claude-sonnet-4.5": 750, + "claude-opus-4.5": 750, + "claude-opus-4.6": 750, + "gpt-oss-120b-medium": 750, + # Gemini 3 Pro group (~5x PRO: 1600 requests) + "gemini-3-pro-high": 1600, + "gemini-3-pro-low": 1600, + "gemini-3-pro-preview": 1600, + # Gemini 3 Flash (~5x PRO: 2000 requests) + "gemini-3-flash": 2000, + # Gemini 2.5 Flash group (same as PRO - already high limits) + "gemini-2.5-flash": 3000, + "gemini-2.5-flash-thinking": 3000, + # Gemini 2.5 Flash Lite (same as PRO - already high limits) + "gemini-2.5-flash-lite": 5000, + # Gemini 2.5 Pro - UNVERIFIED/UNUSED + "gemini-2.5-pro": 1, + }, } # Legacy tier name aliases (backwards compatibility) DEFAULT_MAX_REQUESTS["standard-tier"] = DEFAULT_MAX_REQUESTS["PRO"] DEFAULT_MAX_REQUESTS["free-tier"] = DEFAULT_MAX_REQUESTS["FREE"] +DEFAULT_MAX_REQUESTS["ultra-tier"] = DEFAULT_MAX_REQUESTS["ULTRA"] # Default max requests for unknown models (1% = 100 requests) 
DEFAULT_MAX_REQUESTS_UNKNOWN = 100 @@ -178,6 +209,7 @@ class AntigravityProvider(GoogleOAuthBase, AntigravityQuotaTracker): _quota_refresh_interval: int project_tier_cache: Dict[str, str] project_id_cache: Dict[str, str] + _fraction_tracking: Dict[str, Dict[str, Any]] # ========================================================================= # ANTIGRAVITY-SPECIFIC HELPERS @@ -288,6 +320,9 @@ def get_max_requests_for_model(self, model: str, tier: str) -> int: # Ensure learned values are loaded self._load_learned_costs() + # Normalize tier to canonical name (e.g., "g1-ultra-tier" -> "ULTRA") + tier = normalize_tier_name(tier) + # Strip provider prefix if present clean_model = model.split("/")[-1] if "/" in model else model @@ -301,10 +336,21 @@ def get_max_requests_for_model(self, model: str, tier: str) -> int: if clean_model in DEFAULT_MAX_REQUESTS[tier]: return DEFAULT_MAX_REQUESTS[tier][clean_model] - # Unknown model - use conservative default - lib_logger.debug( + # Unknown model/tier combo - try PRO fallback for paid tiers + if is_paid_tier(tier) and "PRO" in DEFAULT_MAX_REQUESTS: + if clean_model in DEFAULT_MAX_REQUESTS["PRO"]: + lib_logger.warning( + f"No max requests for model={clean_model}, tier={tier}. " + f"Falling back to PRO tier limits. Consider running " + f"discover_quota_costs to learn actual limits." + ) + return DEFAULT_MAX_REQUESTS["PRO"][clean_model] + + # Truly unknown model/tier - use conservative default + lib_logger.warning( f"Unknown max requests for model={clean_model}, tier={tier}. " - f"Using default {DEFAULT_MAX_REQUESTS_UNKNOWN}" + f"Using default {DEFAULT_MAX_REQUESTS_UNKNOWN}. " + f"Consider running discover_quota_costs to learn actual limits." 
    def _try_learn_max_requests_from_fraction(
        self,
        cred_path: str,
        model: str,
        tier: str,
        new_remaining: float,
        usage_manager: "UsageManager",
        quota_group: Optional[str] = None,
    ) -> Optional[int]:
        """Try to derive max_requests from observed API fraction changes.

        Compares the current remaining_fraction with a previously stored value.
        If the fraction has decreased by at least 5% (one API step), uses the
        actual request count from the usage_manager to estimate max_requests.

        This enables automatic learning of quota limits for any tier, including
        ULTRA and future tiers, without needing hardcoded values.

        Args:
            cred_path: Credential path identifier
            model: User-facing model name (without provider prefix)
            tier: Account tier (e.g., "ULTRA", "PRO")
            new_remaining: Current remaining_fraction from API (0.0-1.0)
            usage_manager: UsageManager instance for request count lookup
            quota_group: Optional quota group name

        Returns:
            Learned max_requests if derivable, None otherwise.
        """
        # Lazily create the per-instance observation store; this method may
        # run before any explicit initialization of _fraction_tracking.
        if not hasattr(self, "_fraction_tracking"):
            self._fraction_tracking = {}

        # Normalize tier to canonical name for consistent storage
        tier = normalize_tier_name(tier)

        tracking_key = f"{cred_path}:{model}"
        prev = self._fraction_tracking.get(tracking_key)

        # Get current request count from usage_manager BEFORE baseline update
        prefixed_model = f"antigravity/{model}"
        current_count = usage_manager.get_window_request_count(
            cred_path, prefixed_model, quota_group=quota_group
        )

        # Store current state for next observation.
        # NOTE: this happens unconditionally, so even early-return paths below
        # still advance the baseline for the next comparison.
        self._fraction_tracking[tracking_key] = {
            "fraction": new_remaining,
            "request_count": current_count or 0,
            "timestamp": time.time(),
        }

        if prev is None:
            return None  # First observation - nothing to compare

        prev_fraction = prev["fraction"]
        prev_count = prev["request_count"]

        if current_count is None:
            return None  # No request tracking available yet

        # Calculate fraction consumed and requests made between observations
        fraction_consumed = prev_fraction - new_remaining
        requests_made = current_count - prev_count

        # Guard: need meaningful consumption and actual requests
        # API updates in ~20% increments, so 5% threshold avoids noise
        if fraction_consumed < 0.05:
            return None  # Too small a change, or quota reset (negative)

        if requests_made < 1:
            return None  # No requests between observations

        # Derive max_requests: if N requests consumed F fraction of quota,
        # then total capacity = N / F
        derived_max = int(round(requests_made / fraction_consumed))

        # Sanity bounds
        if derived_max < 10:
            lib_logger.warning(
                f"Dynamic learning: derived unreasonably low max_requests="
                f"{derived_max} for {model} tier={tier} "
                f"(fraction_consumed={fraction_consumed:.4f}, "
                f"requests={requests_made}). Ignoring."
            )
            return None

        if derived_max > 100000:
            lib_logger.warning(
                f"Dynamic learning: derived unreasonably high max_requests="
                f"{derived_max} for {model} tier={tier} "
                f"(fraction_consumed={fraction_consumed:.4f}, "
                f"requests={requests_made}). Ignoring."
            )
            return None

        # Smooth with existing learned value if available
        self._load_learned_costs()
        existing = None
        if tier in self._learned_costs:
            existing = self._learned_costs[tier].get(model)

        if existing is not None:
            # Weighted average: 60% new observation, 40% existing
            smoothed = int(round(0.6 * derived_max + 0.4 * existing))
            lib_logger.info(
                f"Dynamic learning: {model} tier={tier} derived "
                f"max_requests={derived_max} (existing={existing}, "
                f"smoothed={smoothed}, fraction_consumed="
                f"{fraction_consumed:.4f}, requests={requests_made})"
            )
            derived_max = smoothed
        else:
            lib_logger.info(
                f"Dynamic learning: {model} tier={tier} derived "
                f"max_requests={derived_max} "
                f"(fraction_consumed={fraction_consumed:.4f}, "
                f"requests={requests_made})"
            )

        # Persist learned value (updates in-memory dict + saves to file)
        if tier not in self._learned_costs:
            self._learned_costs[tier] = {}
        self._learned_costs[tier][model] = derived_max
        self._save_learned_costs()

        return derived_max
+ self._try_learn_max_requests_from_fraction( + cred_path, + user_model, + tier, + remaining, + usage_manager, + quota_group=quota_group, + ) + # Calculate max_requests for this model/tier + # (will use dynamically learned values if available) max_requests = self.get_max_requests_for_model(user_model, tier) # Extract reset_timestamp (already parsed to float in fetch_quota_from_api) @@ -1087,7 +1280,6 @@ async def _store_baselines_to_usage_manager( quota_used = None if max_requests is not None: quota_used = int((1.0 - remaining) * max_requests) - quota_group = self.get_model_quota_group(user_model) # ANTIGRAVITY-SPECIFIC: Only apply exhaustion on initial fetch # (API only updates in ~20% increments, so we rely on local tracking diff --git a/src/rotator_library/providers/utilities/codex_quota_tracker.py b/src/rotator_library/providers/utilities/codex_quota_tracker.py new file mode 100644 index 00000000..8c623148 --- /dev/null +++ b/src/rotator_library/providers/utilities/codex_quota_tracker.py @@ -0,0 +1,997 @@ +# src/rotator_library/providers/utilities/codex_quota_tracker.py +""" +Codex Quota Tracking Mixin + +Provides quota tracking functionality for the Codex provider by: +1. Fetching rate limit status from the /usage endpoint +2. Parsing rate limit headers from API responses +3. 
Storing quota baselines in UsageManager + +Rate Limit Structure (from Codex API): +- Primary window: Short-term rate limit (e.g., 5 hours) +- Secondary window: Long-term rate limit (e.g., weekly/monthly) +- Credits: Account credit balance info + +Required from provider: + - self.get_auth_header(credential_path) -> Dict[str, str] + - self.get_account_id(credential_path) -> Optional[str] + - self._credentials_cache: Dict[str, Dict[str, Any]] +""" + +from __future__ import annotations + +import asyncio +import logging +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +import httpx + +if TYPE_CHECKING: + from ...usage_manager import UsageManager + +lib_logger = logging.getLogger("rotator_library") + + +# ============================================================================= +# HELPER FUNCTIONS +# ============================================================================= + + +def _get_credential_identifier(credential_path: str) -> str: + """Extract a short identifier from a credential path.""" + if credential_path.startswith("env://"): + return credential_path + return Path(credential_path).name + + +def _seconds_to_minutes(seconds: Optional[int]) -> Optional[int]: + """Convert seconds to minutes, or None if input is None.""" + if seconds is None: + return None + return seconds // 60 + + +# ============================================================================= +# CONFIGURATION +# ============================================================================= + +# Codex usage API endpoint +# The Codex CLI uses different paths based on PathStyle: +# - If base contains /backend-api: use /wham/usage (ChatGptApi style) +# - Otherwise: use /api/codex/usage (CodexApi style) +# Since we use chatgpt.com/backend-api, we need /wham/usage +CODEX_USAGE_URL = "https://chatgpt.com/backend-api/wham/usage" + +# Rate limit header names (from Codex API) +HEADER_PRIMARY_USED_PERCENT = 
"x-codex-primary-used-percent" +HEADER_PRIMARY_WINDOW_MINUTES = "x-codex-primary-window-minutes" +HEADER_PRIMARY_RESET_AT = "x-codex-primary-reset-at" +HEADER_SECONDARY_USED_PERCENT = "x-codex-secondary-used-percent" +HEADER_SECONDARY_WINDOW_MINUTES = "x-codex-secondary-window-minutes" +HEADER_SECONDARY_RESET_AT = "x-codex-secondary-reset-at" +HEADER_CREDITS_HAS_CREDITS = "x-codex-credits-has-credits" +HEADER_CREDITS_UNLIMITED = "x-codex-credits-unlimited" +HEADER_CREDITS_BALANCE = "x-codex-credits-balance" + +# Default quota refresh interval (5 minutes) +DEFAULT_QUOTA_REFRESH_INTERVAL = 300 + +# Stale threshold - quota data older than this is considered stale (15 minutes) +QUOTA_STALE_THRESHOLD_SECONDS = 900 + + +# ============================================================================= +# DATA CLASSES +# ============================================================================= + + +@dataclass +class RateLimitWindow: + """Rate limit window info from Codex API.""" + + used_percent: float # 0-100 + remaining_percent: float # 100 - used_percent + window_minutes: Optional[int] + reset_at: Optional[int] # Unix timestamp + + @property + def remaining_fraction(self) -> float: + """Get remaining quota as a fraction (0.0 to 1.0).""" + return max(0.0, min(1.0, (100 - self.used_percent) / 100)) + + @property + def is_exhausted(self) -> bool: + """Check if this window's quota is exhausted.""" + return self.used_percent >= 100 + + def seconds_until_reset(self) -> Optional[float]: + """Calculate seconds until reset, or None if unknown.""" + if self.reset_at is None: + return None + return max(0, self.reset_at - time.time()) + + +@dataclass +class CreditsInfo: + """Credits info from Codex API.""" + + has_credits: bool + unlimited: bool + balance: Optional[str] # Could be numeric string or "unlimited" + + +@dataclass +class CodexQuotaSnapshot: + """Complete quota snapshot for a Codex credential.""" + + credential_path: str + identifier: str + plan_type: Optional[str] + 
primary: Optional[RateLimitWindow] + secondary: Optional[RateLimitWindow] + credits: Optional[CreditsInfo] + fetched_at: float + status: str # "success" or "error" + error: Optional[str] + + @property + def is_stale(self) -> bool: + """Check if this snapshot is stale.""" + return time.time() - self.fetched_at > QUOTA_STALE_THRESHOLD_SECONDS + + +def _window_to_dict(window: RateLimitWindow) -> Dict[str, Any]: + """Convert RateLimitWindow to dict for JSON serialization.""" + return { + "remaining_percent": window.remaining_percent, + "remaining_fraction": window.remaining_fraction, + "used_percent": window.used_percent, + "window_minutes": window.window_minutes, + "reset_at": window.reset_at, + "reset_in_seconds": window.seconds_until_reset(), + "is_exhausted": window.is_exhausted, + } + + +def _credits_to_dict(credits: CreditsInfo) -> Dict[str, Any]: + """Convert CreditsInfo to dict for JSON serialization.""" + return { + "has_credits": credits.has_credits, + "unlimited": credits.unlimited, + "balance": credits.balance, + } + + +# ============================================================================= +# HEADER PARSING +# ============================================================================= + + +def parse_rate_limit_headers(headers: Dict[str, str]) -> CodexQuotaSnapshot: + """ + Parse rate limit information from Codex API response headers. 
+ + Args: + headers: Response headers dict + + Returns: + CodexQuotaSnapshot with parsed rate limit data + """ + primary = _parse_window_from_headers( + headers, + HEADER_PRIMARY_USED_PERCENT, + HEADER_PRIMARY_WINDOW_MINUTES, + HEADER_PRIMARY_RESET_AT, + ) + + secondary = _parse_window_from_headers( + headers, + HEADER_SECONDARY_USED_PERCENT, + HEADER_SECONDARY_WINDOW_MINUTES, + HEADER_SECONDARY_RESET_AT, + ) + + credits = _parse_credits_from_headers(headers) + + return CodexQuotaSnapshot( + credential_path="", + identifier="", + plan_type=None, + primary=primary, + secondary=secondary, + credits=credits, + fetched_at=time.time(), + status="success" if (primary or secondary or credits) else "no_data", + error=None, + ) + + +def _parse_window_from_headers( + headers: Dict[str, str], + used_percent_header: str, + window_minutes_header: str, + reset_at_header: str, +) -> Optional[RateLimitWindow]: + """Parse a single rate limit window from headers.""" + used_percent_str = headers.get(used_percent_header) + if not used_percent_str: + return None + + try: + used_percent = float(used_percent_str) + except (ValueError, TypeError): + return None + + # Parse optional fields + window_minutes = None + window_minutes_str = headers.get(window_minutes_header) + if window_minutes_str: + try: + window_minutes = int(window_minutes_str) + except (ValueError, TypeError): + pass + + reset_at = None + reset_at_str = headers.get(reset_at_header) + if reset_at_str: + try: + reset_at = int(reset_at_str) + except (ValueError, TypeError): + pass + + return RateLimitWindow( + used_percent=used_percent, + remaining_percent=100 - used_percent, + window_minutes=window_minutes, + reset_at=reset_at, + ) + + +def _parse_credits_from_headers(headers: Dict[str, str]) -> Optional[CreditsInfo]: + """Parse credits info from headers.""" + has_credits_str = headers.get(HEADER_CREDITS_HAS_CREDITS) + if has_credits_str is None: + return None + + has_credits = has_credits_str.lower() in ("true", "1") + 
unlimited_str = headers.get(HEADER_CREDITS_UNLIMITED, "false") + unlimited = unlimited_str.lower() in ("true", "1") + balance = headers.get(HEADER_CREDITS_BALANCE) + + return CreditsInfo( + has_credits=has_credits, + unlimited=unlimited, + balance=balance, + ) + + +# ============================================================================= +# QUOTA TRACKER MIXIN +# ============================================================================= + + +class CodexQuotaTracker: + """ + Mixin class providing quota tracking functionality for Codex provider. + + This mixin adds the following capabilities: + - Fetch rate limit status from the Codex /usage API endpoint + - Parse rate limit headers from streaming responses + - Store quota baselines in UsageManager + - Get structured quota info for all credentials + + Usage: + class CodexProvider(OpenAIOAuthBase, CodexQuotaTracker, ProviderInterface): + ... + + The provider class must initialize these instance attributes in __init__: + self._quota_cache: Dict[str, CodexQuotaSnapshot] = {} + self._quota_refresh_interval: int = 300 + """ + + # Type hints for attributes from provider + _credentials_cache: Dict[str, Dict[str, Any]] + _quota_cache: Dict[str, CodexQuotaSnapshot] + _quota_refresh_interval: int + + def _init_quota_tracker(self): + """Initialize quota tracker state. 
    def _init_quota_tracker(self):
        """Initialize quota tracker state. Call from provider's __init__."""
        self._quota_cache: Dict[str, CodexQuotaSnapshot] = {}
        self._quota_refresh_interval: int = DEFAULT_QUOTA_REFRESH_INTERVAL
        self._usage_manager: Optional["UsageManager"] = None
        self._initial_baselines_fetched: bool = False

    def set_usage_manager(self, usage_manager: "UsageManager") -> None:
        """Set the UsageManager reference for pushing quota updates."""
        self._usage_manager = usage_manager

    # =========================================================================
    # QUOTA API FETCHING
    # =========================================================================

    async def fetch_quota_from_api(
        self,
        credential_path: str,
        api_base: str = "https://chatgpt.com/backend-api/codex",
    ) -> CodexQuotaSnapshot:
        """
        Fetch quota information from the Codex /usage API endpoint.

        Note: the request always goes to CODEX_USAGE_URL; the api_base
        parameter is accepted for interface compatibility but is not used
        in this method.

        Args:
            credential_path: Path to credential file or env:// URI
            api_base: Base URL for the Codex API (currently unused here)

        Returns:
            CodexQuotaSnapshot with rate limit and credits info. On any
            failure a snapshot with status="error" is returned; only
            successful snapshots are written to _quota_cache.
        """
        identifier = _get_credential_identifier(credential_path)

        try:
            # Get auth headers
            auth_headers = await self.get_auth_header(credential_path)
            account_id = await self.get_account_id(credential_path)

            headers = {
                **auth_headers,
                "Content-Type": "application/json",
                "User-Agent": "codex-cli",  # Required by Codex API
            }
            if account_id:
                headers["ChatGPT-Account-Id"] = account_id  # Exact capitalization from Codex CLI

            # Use the correct Codex API URL
            url = CODEX_USAGE_URL

            async with httpx.AsyncClient() as client:
                response = await client.get(url, headers=headers, timeout=30)
                response.raise_for_status()
                data = response.json()

            # Parse response
            plan_type = data.get("plan_type")

            # Parse rate_limit section.
            # NOTE(review): field names (primary_window, used_percent,
            # limit_window_seconds, reset_at) assume the /usage response
            # schema - confirm against the Codex API if it changes.
            rate_limit = data.get("rate_limit")
            primary = None
            secondary = None

            if rate_limit:
                primary_data = rate_limit.get("primary_window")
                if primary_data:
                    primary = RateLimitWindow(
                        used_percent=float(primary_data.get("used_percent", 0)),
                        remaining_percent=100 - float(primary_data.get("used_percent", 0)),
                        window_minutes=_seconds_to_minutes(
                            primary_data.get("limit_window_seconds")
                        ),
                        reset_at=primary_data.get("reset_at"),
                    )

                secondary_data = rate_limit.get("secondary_window")
                if secondary_data:
                    secondary = RateLimitWindow(
                        used_percent=float(secondary_data.get("used_percent", 0)),
                        remaining_percent=100 - float(secondary_data.get("used_percent", 0)),
                        window_minutes=_seconds_to_minutes(
                            secondary_data.get("limit_window_seconds")
                        ),
                        reset_at=secondary_data.get("reset_at"),
                    )

            # Parse credits section
            credits_data = data.get("credits")
            credits = None
            if credits_data:
                credits = CreditsInfo(
                    has_credits=credits_data.get("has_credits", False),
                    unlimited=credits_data.get("unlimited", False),
                    balance=credits_data.get("balance"),
                )

            snapshot = CodexQuotaSnapshot(
                credential_path=credential_path,
                identifier=identifier,
                plan_type=plan_type,
                primary=primary,
                secondary=secondary,
                credits=credits,
                fetched_at=time.time(),
                status="success",
                error=None,
            )

            # Cache the snapshot (success path only; errors are not cached)
            self._quota_cache[credential_path] = snapshot

            # The two adjacent f-strings concatenate into one message before
            # the conditional expression selects which message to log.
            lib_logger.debug(
                f"Fetched Codex quota for {identifier}: "
                f"primary={primary.remaining_percent:.1f}% remaining"
                if primary
                else f"Fetched Codex quota for {identifier}: no rate limit data"
            )

            return snapshot

        except httpx.HTTPStatusError as e:
            error_msg = f"HTTP {e.response.status_code}: {e.response.text[:200]}"
            lib_logger.warning(f"Failed to fetch Codex quota for {identifier}: {error_msg}")
            return CodexQuotaSnapshot(
                credential_path=credential_path,
                identifier=identifier,
                plan_type=None,
                primary=None,
                secondary=None,
                credits=None,
                fetched_at=time.time(),
                status="error",
                error=error_msg,
            )

        except Exception as e:
            error_msg = str(e)
            lib_logger.warning(f"Failed to fetch Codex quota for {identifier}: {error_msg}")
            return CodexQuotaSnapshot(
                credential_path=credential_path,
                identifier=identifier,
                plan_type=None,
                primary=None,
                secondary=None,
                credits=None,
                fetched_at=time.time(),
                status="error",
                error=error_msg,
            )
return CodexQuotaSnapshot( + credential_path=credential_path, + identifier=identifier, + plan_type=None, + primary=None, + secondary=None, + credits=None, + fetched_at=time.time(), + status="error", + error=error_msg, + ) + + def update_quota_from_headers( + self, + credential_path: str, + headers: Dict[str, str], + ) -> Optional[CodexQuotaSnapshot]: + """ + Update cached quota info from response headers. + + Call this after each API response to keep quota cache up-to-date. + Also pushes quota data to the UsageManager if available. + + Args: + credential_path: Credential that made the request + headers: Response headers dict + + Returns: + Updated CodexQuotaSnapshot or None if no quota headers present + """ + snapshot = parse_rate_limit_headers(headers) + + if snapshot.status == "no_data": + return None + + # Preserve existing metadata + existing = self._quota_cache.get(credential_path) + if existing: + snapshot.plan_type = existing.plan_type + + snapshot.credential_path = credential_path + snapshot.identifier = _get_credential_identifier(credential_path) + + self._quota_cache[credential_path] = snapshot + + # Log quota info when captured from headers + if snapshot.primary: + remaining = snapshot.primary.remaining_percent + reset_secs = snapshot.primary.seconds_until_reset() + if reset_secs is not None: + reset_str = f"{int(reset_secs // 60)}m" + else: + reset_str = "?" + lib_logger.debug( + f"Codex quota from headers ({snapshot.identifier}): " + f"{remaining:.0f}% remaining, resets in {reset_str}" + ) + + # Push quota data to UsageManager if available + if self._usage_manager: + self._push_quota_to_usage_manager(credential_path, snapshot) + + return snapshot + + def _push_quota_to_usage_manager( + self, + credential_path: str, + snapshot: CodexQuotaSnapshot, + ) -> None: + """ + Push parsed quota snapshot to the UsageManager. + + Translates the primary/secondary rate limit windows into + update_quota_baseline calls so the TUI can display quota status. 
    def _push_quota_to_usage_manager(
        self,
        credential_path: str,
        snapshot: CodexQuotaSnapshot,
    ) -> None:
        """
        Push parsed quota snapshot to the UsageManager.

        Translates the primary/secondary rate limit windows into
        update_quota_baseline calls so the TUI can display quota status.
        No-op when no UsageManager has been attached.
        """
        if not self._usage_manager:
            return

        provider_prefix = getattr(self, "provider_env_name", "codex")

        try:
            # NOTE(review): asyncio.get_event_loop() is deprecated outside a
            # running loop on Python 3.10+; consider get_running_loop() -
            # confirm the supported Python versions before changing.
            import asyncio
            loop = asyncio.get_event_loop()
        except RuntimeError:
            return

        async def _push():
            try:
                if snapshot.primary:
                    used_pct = snapshot.primary.used_percent
                    # Convert percentage to a request count on a 100-scale
                    quota_used = int(used_pct)
                    await self._usage_manager.update_quota_baseline(
                        accessor=credential_path,
                        model=f"{provider_prefix}/_5h_window",
                        quota_max_requests=100,
                        quota_reset_ts=snapshot.primary.reset_at,
                        quota_used=quota_used,
                        quota_group="5h-limit",
                        force=True,
                        apply_exhaustion=snapshot.primary.is_exhausted,
                    )

                if snapshot.secondary:
                    used_pct = snapshot.secondary.used_percent
                    quota_used = int(used_pct)
                    await self._usage_manager.update_quota_baseline(
                        accessor=credential_path,
                        model=f"{provider_prefix}/_weekly_window",
                        quota_max_requests=100,
                        quota_reset_ts=snapshot.secondary.reset_at,
                        quota_used=quota_used,
                        quota_group="weekly-limit",
                        force=True,
                        apply_exhaustion=snapshot.secondary.is_exhausted,
                    )
            except Exception as e:
                lib_logger.debug(
                    f"Failed to push Codex quota to UsageManager: {e}"
                )

        # Schedule the async push - we're already in an async context
        # when this is called from the streaming/non-streaming handlers.
        # NOTE(review): the ensure_future task reference is not retained;
        # a fire-and-forget task can be garbage-collected before completion -
        # verify whether a strong reference should be kept.
        if loop.is_running():
            asyncio.ensure_future(_push())
        else:
            loop.run_until_complete(_push())
    def get_cached_quota(
        self,
        credential_path: str,
    ) -> Optional[CodexQuotaSnapshot]:
        """
        Get cached quota snapshot for a credential.

        Args:
            credential_path: Credential to look up

        Returns:
            Cached CodexQuotaSnapshot or None if not cached
        """
        return self._quota_cache.get(credential_path)

    # =========================================================================
    # QUOTA INFO AGGREGATION
    # =========================================================================

    async def get_all_quota_info(
        self,
        credential_paths: List[str],
        force_refresh: bool = False,
        api_base: str = "https://chatgpt.com/backend-api/codex",
    ) -> Dict[str, Any]:
        """
        Get quota info for all credentials.

        Fresh (non-stale) cache entries are reused unless force_refresh is
        set; everything else is fetched from the API.

        Args:
            credential_paths: List of credential paths to query
            force_refresh: If True, fetch fresh data; if False, use cache if available
            api_base: Base URL for the Codex API

        Returns:
            Dict with three keys:
            - "credentials": identifier -> entry with identifier, file_path
              (None for env:// URIs), plan_type, status ("success" /
              "error" / "cached"), error, primary/secondary window dicts
              (see _window_to_dict), credits dict (see _credits_to_dict),
              fetched_at, and is_stale.
            - "summary": total_credentials, by_plan_type counts, and
              exhausted_count (primary-window exhaustion only).
            - "timestamp": time of aggregation.
        """
        results = {}
        plan_type_counts: Dict[str, int] = {}
        exhausted_count = 0

        for cred_path in credential_paths:
            identifier = _get_credential_identifier(cred_path)

            # Check cache first unless force_refresh
            cached = self._quota_cache.get(cred_path)
            if not force_refresh and cached and not cached.is_stale:
                snapshot = cached
                status = "cached"
            else:
                snapshot = await self.fetch_quota_from_api(cred_path, api_base)
                status = snapshot.status

            # Count plan types
            if snapshot.plan_type:
                plan_type_counts[snapshot.plan_type] = (
                    plan_type_counts.get(snapshot.plan_type, 0) + 1
                )

            # Check if exhausted (primary window only)
            if snapshot.primary and snapshot.primary.is_exhausted:
                exhausted_count += 1

            # Build result entry
            entry = {
                "identifier": identifier,
                "file_path": cred_path if not cred_path.startswith("env://") else None,
                "plan_type": snapshot.plan_type,
                "status": status,
                "error": snapshot.error,
                "primary": _window_to_dict(snapshot.primary) if snapshot.primary else None,
                "secondary": _window_to_dict(snapshot.secondary) if snapshot.secondary else None,
                "credits": _credits_to_dict(snapshot.credits) if snapshot.credits else None,
                "fetched_at": snapshot.fetched_at,
                "is_stale": snapshot.is_stale,
            }

            results[identifier] = entry

        return {
            "credentials": results,
            "summary": {
                "total_credentials": len(credential_paths),
                "by_plan_type": plan_type_counts,
                "exhausted_count": exhausted_count,
            },
            "timestamp": time.time(),
        }

    # =========================================================================
    # BACKGROUND JOB SUPPORT
    # =========================================================================

    def get_background_job_config(self) -> Optional[Dict[str, Any]]:
        """
        Return configuration for the quota refresh background job.

        Returns:
            Background job config dict with the refresh interval, job name,
            and the run-on-start flag.
        """
        return {
            "interval": self._quota_refresh_interval,
            "name": "codex_quota_refresh",
            "run_on_start": True,
        }
    async def run_background_job(
        self,
        usage_manager: "UsageManager",
        credentials: List[str],
    ) -> None:
        """
        Execute periodic quota refresh for active credentials.

        Called by BackgroundRefresher at the configured interval.
        On first run, fetches baselines for ALL credentials, stores them in
        the UsageManager, and applies exhaustion cooldowns so we don't waste
        requests on depleted keys. Subsequent runs only refresh credentials
        whose cached quota was fetched within the last hour.

        Args:
            usage_manager: UsageManager instance used to store quota baselines
            credentials: List of credential paths for this provider
        """
        if not credentials:
            return

        # On first run, fetch baselines for ALL credentials to detect exhaustion
        if not self._initial_baselines_fetched:
            self._initial_baselines_fetched = True
            try:
                quota_results = await self.fetch_initial_baselines(credentials)
                stored = await self._store_baselines_to_usage_manager(
                    quota_results,
                    usage_manager,
                    force=True,
                    is_initial_fetch=True,
                )
                # Log any exhausted credentials detected on startup
                exhausted = []
                for cred_path, data in quota_results.items():
                    if data.get("status") != "success":
                        continue
                    primary = data.get("primary")
                    secondary = data.get("secondary")
                    if primary and primary.get("is_exhausted"):
                        exhausted.append(
                            f"{_get_credential_identifier(cred_path)} (5h window)"
                        )
                    if secondary and secondary.get("is_exhausted"):
                        exhausted.append(
                            f"{_get_credential_identifier(cred_path)} (weekly)"
                        )
                if exhausted:
                    lib_logger.warning(
                        f"Codex startup: {len(exhausted)} exhausted quota(s) detected, "
                        f"cooldowns applied: {', '.join(exhausted)}"
                    )
                else:
                    lib_logger.info(
                        f"Codex startup: {stored} baselines stored, no exhausted credentials"
                    )
            except Exception as e:
                lib_logger.error(f"Codex startup baseline fetch failed: {e}")
            # First run ends here; incremental refresh starts next cycle.
            return

        # Subsequent runs: only refresh credentials that have been used recently
        now = time.time()
        active_credentials = []

        for cred_path in credentials:
            cached = self._quota_cache.get(cred_path)
            # Refresh if cached and was fetched within the last hour
            if cached and (now - cached.fetched_at) < 3600:
                active_credentials.append(cred_path)

        if not active_credentials:
            lib_logger.debug("No active Codex credentials to refresh quota for")
            return

        lib_logger.debug(
            f"Refreshing Codex quota for {len(active_credentials)} active credentials"
        )

        # Fetch quotas with limited concurrency
        semaphore = asyncio.Semaphore(3)

        async def fetch_with_semaphore(cred_path: str):
            async with semaphore:
                return await self.fetch_quota_from_api(cred_path)

        tasks = [fetch_with_semaphore(cred) for cred in active_credentials]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        success_count = sum(
            1
            for r in results
            if isinstance(r, CodexQuotaSnapshot) and r.status == "success"
        )

        lib_logger.debug(
            f"Codex quota refresh complete: {success_count}/{len(active_credentials)} successful"
        )
+ + Args: + quota_results: Dict from fetch_initial_baselines mapping cred_path -> quota data + usage_manager: UsageManager instance to store baselines in + force: If True, always overwrite existing values + is_initial_fetch: If True, apply exhaustion cooldowns + + Returns: + Number of baselines successfully stored + """ + stored_count = 0 + + # Get available models from the provider (will be set by CodexProvider) + models = getattr(self, "_available_models_for_quota", []) + provider_prefix = getattr(self, "provider_env_name", "codex") + + for cred_path, quota_data in quota_results.items(): + if quota_data.get("status") != "success": + continue + + # Get remaining fraction from primary and secondary windows + primary = quota_data.get("primary") + secondary = quota_data.get("secondary") + + # Short credential name for logging + if cred_path.startswith("env://"): + short_cred = cred_path.split("/")[-1] + else: + short_cred = Path(cred_path).stem + + # Store primary window (5h limit) under virtual model "_5h_window" + if primary: + primary_remaining = primary.get("remaining_fraction", 1.0) + primary_used_pct = primary.get("used_percent", 0) + primary_reset = primary.get("reset_at") + is_exhausted = primary.get("is_exhausted", False) + try: + await usage_manager.update_quota_baseline( + accessor=cred_path, + model=f"{provider_prefix}/_5h_window", + quota_max_requests=100, + quota_reset_ts=primary_reset, + quota_used=int(primary_used_pct), + quota_group="5h-limit", + force=force, + apply_exhaustion=is_exhausted and is_initial_fetch, + ) + stored_count += 1 + lib_logger.debug( + f"Stored Codex 5h baseline for {short_cred}: " + f"{primary_remaining * 100:.1f}% remaining" + ) + except Exception as e: + lib_logger.warning( + f"Failed to store Codex 5h baseline for {short_cred}: {e}" + ) + + # Store secondary window (weekly limit) under virtual model "_weekly_window" + if secondary: + secondary_remaining = secondary.get("remaining_fraction", 1.0) + secondary_used_pct = 
secondary.get("used_percent", 0) + secondary_reset = secondary.get("reset_at") + is_exhausted = secondary.get("is_exhausted", False) + try: + await usage_manager.update_quota_baseline( + accessor=cred_path, + model=f"{provider_prefix}/_weekly_window", + quota_max_requests=100, + quota_reset_ts=secondary_reset, + quota_used=int(secondary_used_pct), + quota_group="weekly-limit", + force=force, + apply_exhaustion=is_exhausted and is_initial_fetch, + ) + stored_count += 1 + lib_logger.debug( + f"Stored Codex weekly baseline for {short_cred}: " + f"{secondary_remaining * 100:.1f}% remaining" + ) + except Exception as e: + lib_logger.warning( + f"Failed to store Codex weekly baseline for {short_cred}: {e}" + ) + + return stored_count + + async def fetch_initial_baselines( + self, + credential_paths: List[str], + api_base: str = "https://chatgpt.com/backend-api/codex", + ) -> Dict[str, Dict[str, Any]]: + """ + Fetch quota baselines for all credentials. + + This matches the interface expected by RotatingClient for quota tracking. + + Args: + credential_paths: All credential paths to fetch baselines for + api_base: Base URL for the Codex API + + Returns: + Dict mapping credential_path -> quota data in format: + { + "status": "success" | "error", + "error": str | None, + "primary": { + "remaining_fraction": float, + "remaining_percent": float, + "used_percent": float, + "reset_at": int | None, + ... + }, + "secondary": {...} | None, + "plan_type": str | None, + } + """ + if not credential_paths: + return {} + + lib_logger.info( + f"codex: Fetching initial quota baselines for {len(credential_paths)} credentials..." 
async def fetch_initial_baselines(
    self,
    credential_paths: List[str],
    api_base: str = "https://chatgpt.com/backend-api/codex",
) -> Dict[str, Dict[str, Any]]:
    """
    Fetch quota baselines for all credentials.

    This matches the interface expected by RotatingClient for quota tracking.

    Args:
        credential_paths: All credential paths to fetch baselines for
        api_base: Base URL for the Codex API

    Returns:
        Dict mapping credential_path -> quota data in format:
        {
            "status": "success" | "error",
            "error": str | None,
            "primary": {"remaining_fraction", "remaining_percent",
                        "used_percent", "reset_at", "window_minutes",
                        "is_exhausted"} | None,
            "secondary": {...} | None,
            "credits": {...} | None,
            "plan_type": str | None,
        }
    """
    if not credential_paths:
        return {}

    lib_logger.info(
        f"codex: Fetching initial quota baselines for {len(credential_paths)} credentials..."
    )

    results: Dict[str, Dict[str, Any]] = {}

    # Fetch quotas concurrently with limited concurrency
    semaphore = asyncio.Semaphore(3)

    async def fetch_with_semaphore(cred_path: str):
        async with semaphore:
            snapshot = await self.fetch_quota_from_api(cred_path, api_base)
            return cred_path, snapshot

    def window_to_dict(window) -> Optional[Dict[str, Any]]:
        # Convert a rate-limit window snapshot into the plain-dict shape
        # client.py expects; a missing window stays None.
        if window is None:
            return None
        return {
            "remaining_fraction": window.remaining_fraction,
            "remaining_percent": window.remaining_percent,
            "used_percent": window.used_percent,
            "reset_at": window.reset_at,
            "window_minutes": window.window_minutes,
            "is_exhausted": window.is_exhausted,
        }

    tasks = [fetch_with_semaphore(cred) for cred in credential_paths]
    fetch_results = await asyncio.gather(*tasks, return_exceptions=True)

    for result in fetch_results:
        if isinstance(result, Exception):
            lib_logger.warning(f"Codex quota fetch error: {result}")
            continue

        cred_path, snapshot = result

        if snapshot.status == "success":
            results[cred_path] = {
                "status": "success",
                "error": None,
                "plan_type": snapshot.plan_type,
                "primary": window_to_dict(snapshot.primary),
                "secondary": window_to_dict(snapshot.secondary),
                "credits": {
                    "has_credits": snapshot.credits.has_credits,
                    "unlimited": snapshot.credits.unlimited,
                    "balance": snapshot.credits.balance,
                } if snapshot.credits else None,
            }
        else:
            results[cred_path] = {
                "status": "error",
                "error": snapshot.error or "Unknown error",
            }

    success_count = sum(1 for v in results.values() if v.get("status") == "success")
    lib_logger.info(
        f"codex: Fetched {success_count}/{len(credential_paths)} quota baselines"
    )

    return results
+ + The provider class must initialize these instance attributes in __init__: + self._quota_cache: Dict[str, Dict[str, Any]] = {} + self._quota_refresh_interval: int = 300 # 5 min default + """ + + # Type hints for attributes from provider + _quota_cache: Dict[str, Dict[str, Any]] + _quota_refresh_interval: int + + # ========================================================================= + # TOKEN PARSING + # ========================================================================= + + def _extract_user_id_from_token(self, session_token: str) -> Optional[str]: + """ + Extract user ID from the session token. + + Token format: user_XXXX%3A%3Ajwt... (URL-encoded :: separator) + or: user_XXXX::jwt... (decoded format) + + Args: + session_token: The WorkosCursorSessionToken value + + Returns: + User ID (e.g., "user_01JWV7FARDJPMQ5QZSANMJDS9A") or None + """ + try: + # URL-decode first in case it's encoded + decoded = urllib.parse.unquote(session_token) + + # Split on :: separator + if "::" in decoded: + user_id = decoded.split("::")[0] + if user_id.startswith("user_"): + return user_id + + # Try extracting from the token prefix directly + if session_token.startswith("user_"): + # Find the separator (either %3A%3A or ::) + if "%3A%3A" in session_token: + return session_token.split("%3A%3A")[0] + elif "::" in session_token: + return session_token.split("::")[0] + + return None + except Exception as e: + lib_logger.warning(f"Failed to extract user ID from session token: {e}") + return None + + # ========================================================================= + # QUOTA USAGE API + # ========================================================================= + + async def fetch_cursor_quota_usage( + self, + session_token: str, + client: Optional[httpx.AsyncClient] = None, + ) -> Dict[str, Any]: + """ + Fetch quota usage from the Cursor web API. 
async def fetch_cursor_quota_usage(
    self,
    session_token: str,
    client: Optional[httpx.AsyncClient] = None,
) -> Dict[str, Any]:
    """
    Fetch quota usage from the Cursor web API.

    Args:
        session_token: The WorkosCursorSessionToken cookie value
        client: Optional HTTP client for connection reuse

    Returns:
        {
            "status": "success" | "error",
            "error": str | None,
            "models": {
                "gpt-4": {"numRequests": int, "maxRequestUsage": int, "remaining_fraction": float},
                ...
            },
            "start_of_month": str | None,
            "fetched_at": float,
        }
    """

    def _error_payload(message: str) -> Dict[str, Any]:
        # Every failure path returns the same shape.
        return {
            "status": "error",
            "error": message,
            "models": {},
            "start_of_month": None,
            "fetched_at": time.time(),
        }

    try:
        user_id = self._extract_user_id_from_token(session_token)
        if not user_id:
            return _error_payload("Could not extract user ID from session token")

        headers = {
            "Accept": "application/json",
            "Cookie": f"WorkosCursorSessionToken={session_token}",
            "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36",
        }

        # URL-encode user_id for safety
        encoded_user_id = urllib.parse.quote(user_id, safe="")
        url = f"{CURSOR_API_BASE}{CURSOR_USAGE_ENDPOINT}?user={encoded_user_id}"

        if client is None:
            async with httpx.AsyncClient() as own_client:
                response = await own_client.get(
                    url, headers=headers, timeout=30, follow_redirects=True
                )
        else:
            response = await client.get(
                url, headers=headers, timeout=30, follow_redirects=True
            )

        response.raise_for_status()
        data = response.json()

        # A 200 response can still carry an application-level error body.
        if "error" in data:
            error_msg = data.get("description", data.get("error", "Unknown error"))
            if data.get("error") == "not_authenticated":
                lib_logger.warning(
                    "Cursor session token expired or invalid. "
                    "Please update CURSOR_SESSION_TOKEN with a fresh cookie from cursor.com"
                )
            return _error_payload(error_msg)

        # Response format: {"gpt-4": {...}, "startOfMonth": "2026-01-23T22:27:08.000Z"}
        start_of_month = data.pop("startOfMonth", None)

        models: Dict[str, Any] = {}
        for model_name, usage_data in data.items():
            if not isinstance(usage_data, dict):
                continue

            num_requests = usage_data.get("numRequests", 0)
            max_requests = usage_data.get("maxRequestUsage")

            if max_requests and max_requests > 0:
                remaining_fraction = max(0, max_requests - num_requests) / max_requests
            else:
                # No limit or unknown limit
                remaining_fraction = 1.0

            models[model_name] = {
                "numRequests": num_requests,
                "numRequestsTotal": usage_data.get("numRequestsTotal", num_requests),
                "numTokens": usage_data.get("numTokens", 0),
                "maxRequestUsage": max_requests,
                "maxTokenUsage": usage_data.get("maxTokenUsage"),
                "remaining_fraction": remaining_fraction,
            }

        return {
            "status": "success",
            "error": None,
            "models": models,
            "start_of_month": start_of_month,
            "fetched_at": time.time(),
        }

    except httpx.HTTPStatusError as e:
        error_msg = f"HTTP {e.response.status_code}"
        if e.response.status_code in (401, 403):
            lib_logger.warning(
                f"Cursor API authentication failed ({error_msg}). "
                "Please update CURSOR_SESSION_TOKEN with a fresh cookie from cursor.com"
            )
        else:
            lib_logger.warning(f"Failed to fetch Cursor quota: {error_msg}")
        return _error_payload(error_msg)
    except Exception as e:
        lib_logger.warning(f"Failed to fetch Cursor quota: {type(e).__name__}: {e}")
        return _error_payload(str(e))
async def refresh_cursor_quota_usage(
    self,
    credential_identifier: str,
) -> Dict[str, Any]:
    """
    Refresh and cache quota usage for a credential.

    The session token itself is read from the CURSOR_SESSION_TOKEN
    environment variable; ``credential_identifier`` only serves as the
    cache key.

    Args:
        credential_identifier: Identifier for caching (typically "cursor_session")

    Returns:
        Usage data from fetch_cursor_quota_usage()
    """
    session_token = os.environ.get("CURSOR_SESSION_TOKEN")
    if not session_token:
        lib_logger.warning(
            "CURSOR_SESSION_TOKEN not set - cannot fetch quota"
        )
        return {
            "status": "error",
            "error": "CURSOR_SESSION_TOKEN not configured",
            "models": {},
            "start_of_month": None,
            "fetched_at": time.time(),
        }

    usage_data = await self.fetch_cursor_quota_usage(session_token)

    # Cache only successful fetches so a transient failure doesn't evict
    # the last known-good snapshot.
    if usage_data.get("status") == "success":
        self._quota_cache[credential_identifier] = usage_data

        models = usage_data.get("models", {})
        if models:
            summary_parts = [
                f"{name}: {info.get('remaining_fraction', 0) * 100:.1f}%"
                for name, info in models.items()
            ]
            lib_logger.debug(f"Cursor quota: {', '.join(summary_parts)}")

    return usage_data
def extract_cursor_model_quotas(
    self, usage_data: Dict[str, Any]
) -> List[Tuple[str, float, Optional[int]]]:
    """
    Extract model quota information from usage data.

    Args:
        usage_data: Response from fetch_cursor_quota_usage()

    Returns:
        List of tuples: (model_name, remaining_fraction, max_requests)
        - model_name: Model name from Cursor API (e.g., "gpt-4")
        - remaining_fraction: 0.0 to 1.0 (defaults to 1.0 when absent)
        - max_requests: Maximum requests for this model, or None if unlimited
    """
    return [
        (name, info.get("remaining_fraction", 1.0), info.get("maxRequestUsage"))
        for name, info in usage_data.get("models", {}).items()
    ]
async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
    """
    Fetch available models from ZenMux.

    The models endpoint is public and doesn't require authentication.

    Args:
        api_key: Unused; kept for interface parity (endpoint needs no auth).
        client: Shared HTTP client used for the request.

    Returns:
        Model IDs prefixed with "zenmux/"; whatever was collected before a
        failure (usually an empty list) on error.
    """
    discovered: List[str] = []
    models_url = f"{self.api_base.rstrip('/')}/models"
    try:
        response = await client.get(
            models_url,
            headers=self._get_headers(),
            timeout=30.0,
        )
        response.raise_for_status()

        for entry in response.json().get("data", []):
            entry_id = entry.get("id")
            if entry_id:
                discovered.append(f"zenmux/{entry_id}")

        lib_logger.info(f"Discovered {len(discovered)} models from ZenMux")

    except Exception as e:
        lib_logger.warning(f"Failed to fetch models from ZenMux: {e}")

    return discovered
+ """ + # Clean up kwargs not needed by LiteLLM + kwargs.pop("credential_identifier", None) + kwargs.pop("transaction_context", None) + + # Transform model name for LiteLLM's OpenAI provider + # "zenmux/gpt-4-free" -> "openai/gpt-4-free" + model = kwargs.get("model", "") + if model.startswith("zenmux/"): + kwargs["model"] = "openai/" + model[len("zenmux/") :] + + # Add custom headers to the kwargs (without mutating caller's dict) + extra_headers = self._get_headers() + existing_headers = kwargs.get("extra_headers") or {} + kwargs["extra_headers"] = {**existing_headers, **extra_headers} + + # Ensure api_base is set + kwargs["api_base"] = self.api_base + + # Use the public API key for the OpenCode Zen gateway + if not kwargs.get("api_key"): + kwargs["api_key"] = "public" + + # Call LiteLLM with the custom headers + is_streaming = kwargs.get("stream", False) + if is_streaming: + # Return an async generator for streaming + async def stream_wrapper(): + async for chunk in await litellm.acompletion(**kwargs): + yield chunk + + return stream_wrapper() + else: + return await litellm.acompletion(**kwargs) + + async def aembedding( + self, + client: httpx.AsyncClient, + **kwargs, # client unused - LiteLLM manages its own + ) -> litellm.EmbeddingResponse: + """ + Handle embedding calls with ZenMux custom headers. 
+ """ + # Clean up kwargs not needed by LiteLLM + kwargs.pop("credential_identifier", None) + kwargs.pop("transaction_context", None) + + # Transform model name for LiteLLM's OpenAI provider + model = kwargs.get("model", "") + if model.startswith("zenmux/"): + kwargs["model"] = "openai/" + model[len("zenmux/") :] + + # Add custom headers (without mutating caller's dict) + extra_headers = self._get_headers() + existing_headers = kwargs.get("extra_headers") or {} + kwargs["extra_headers"] = {**existing_headers, **extra_headers} + + kwargs["api_base"] = self.api_base + + if not kwargs.get("api_key"): + kwargs["api_key"] = "public" + + return await litellm.aembedding(**kwargs) + + def convert_safety_settings( + self, settings: Dict[str, str] + ) -> Optional[List[Dict[str, Any]]]: + """ + ZenMux doesn't have specific safety settings to convert. + """ + return None + + def get_credential_tier_name(self, credential: str) -> Optional[str]: + """ + ZenMux free models are all free tier. + """ + return "free-tier" + + def get_model_tier_requirement(self, model: str) -> Optional[int]: + """ + All ZenMux models available through this provider are free tier. 
+ """ + return None diff --git a/src/rotator_library/transaction_logger.py b/src/rotator_library/transaction_logger.py index e1de4d67..61f01e91 100644 --- a/src/rotator_library/transaction_logger.py +++ b/src/rotator_library/transaction_logger.py @@ -265,8 +265,12 @@ def _log_metadata( model = response_data.get("model", self.model) finish_reason = "N/A" + # Handle OpenAI format (choices[0].finish_reason) if "choices" in response_data and response_data["choices"]: finish_reason = response_data["choices"][0].get("finish_reason", "N/A") + # Handle Anthropic format (stop_reason at top level) + elif "stop_reason" in response_data: + finish_reason = response_data.get("stop_reason", "N/A") # Check for provider subdirectory has_provider_logs = False @@ -279,6 +283,19 @@ def _log_metadata( except OSError: has_provider_logs = False + # Extract token counts - support both OpenAI and Anthropic formats + # Prefers OpenAI format if available: prompt_tokens, completion_tokens + # Falls back to Anthropic format: input_tokens, output_tokens + prompt_tokens = usage.get("prompt_tokens") + if prompt_tokens is None: + prompt_tokens = usage.get("input_tokens") + completion_tokens = usage.get("completion_tokens") + if completion_tokens is None: + completion_tokens = usage.get("output_tokens") + total_tokens = usage.get("total_tokens") + if total_tokens is None and prompt_tokens is not None and completion_tokens is not None: + total_tokens = prompt_tokens + completion_tokens + metadata = { "request_id": self.request_id, "timestamp_utc": datetime.utcnow().isoformat(), @@ -288,9 +305,9 @@ def _log_metadata( "model": model, "streaming": self.streaming, "usage": { - "prompt_tokens": usage.get("prompt_tokens"), - "completion_tokens": usage.get("completion_tokens"), - "total_tokens": usage.get("total_tokens"), + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, }, "finish_reason": finish_reason, "has_provider_logs": has_provider_logs, diff 
def _get_group_models_from_data(
    self, state: "CredentialState", group: str
) -> List[str]:
    """
    Get models from actual usage data that belong to a quota group.

    Unlike _get_grouped_models, which returns a static provider-supplied
    list, this scans the credential's recorded usage dynamically. That is
    necessary for providers like Firmware where all models share a quota
    pool but the provider can't enumerate all possible models upfront.

    Args:
        state: Credential state containing model usage data
        group: Group name (e.g., "firmware_global")

    Returns:
        List of model names from model_usage that belong to the group
    """
    matching: List[str] = []
    for model in state.model_usage:
        if self._get_model_quota_group(model) == group:
            matching.append(model)
    return matching
+ """ + stable_id = self._registry.get_stable_id(accessor, self.provider) + state = self._states.get(stable_id) + if not state: + return None + + normalized_model = self._normalize_model(model) + group_key = quota_group or self._get_model_quota_group(normalized_model) + + primary_def = self._window_manager.get_primary_definition() + if not primary_def: + return None + + if group_key: + group_stats = state.get_group_stats(group_key) + window = group_stats.windows.get(primary_def.name) + else: + model_stats = state.get_model_stats(normalized_model) + window = model_stats.windows.get(primary_def.name) + + return window.request_count if window else None + # ========================================================================= # WINDOW CLEANUP # ========================================================================= @@ -1806,13 +1871,19 @@ def _sync_group_timing_to_models( consistent started_at, reset_at, and limit values. All models in a quota group share the same timing since they share API quota. + Uses dynamic model discovery from actual usage data, which is necessary + for providers like Firmware where all models share a quota pool but + the provider can't enumerate all possible models upfront. 
+ Args: state: Credential state containing model stats group_key: Quota group name group_window: The authoritative group window window_name: Name of the window to sync (e.g., "5h") """ - models_in_group = self._get_grouped_models(group_key) + # Use dynamic model discovery from actual usage data + # This handles providers like Firmware where models can't be enumerated upfront + models_in_group = self._get_group_models_from_data(state, group_key) for model_name in models_in_group: model_stats = state.get_model_stats(model_name, create=False) if model_stats: diff --git a/src/rotator_library/utils/resilient_io.py b/src/rotator_library/utils/resilient_io.py index 91e96f37..11809a08 100644 --- a/src/rotator_library/utils/resilient_io.py +++ b/src/rotator_library/utils/resilient_io.py @@ -35,6 +35,41 @@ DEFAULT_BUFFERED_WRITE_RETRY_INTERVAL: float = 30.0 +# ============================================================================= +# SYMLINK-AWARE ATOMIC WRITE HELPER +# ============================================================================= + + +def _resolve_write_target(path: Path, logger: Optional[logging.Logger] = None) -> Path: + """ + Resolve symlinks to get the actual write target. + + When writing atomically with tempfile + shutil.move(), we must write to the + resolved path (symlink target) rather than the symlink itself. Otherwise, + shutil.move() replaces the symlink with a regular file instead of writing + through the symlink to the target. + + This is critical for Docker volume mounts where a symlink points to a + persistent volume - writing to the symlink path would write to the + container's ephemeral overlay filesystem instead. 
+ + Args: + path: Original path (may be a symlink) + logger: Optional logger for warning on resolution failure + + Returns: + Resolved path (symlink target if path is a symlink, otherwise unchanged) + """ + try: + # resolve() follows all symlinks and returns the canonical absolute path + return path.resolve() + except (OSError, RuntimeError) as e: + # Resolution failed (permissions, symlink loops, etc.) - use original path + if logger: + logger.warning(f"Symlink resolution failed for {path.name}: {e}") + return path + + # ============================================================================= # BUFFERED WRITE REGISTRY (SINGLETON) # ============================================================================= @@ -193,9 +228,11 @@ def _try_write(self, path_str: str, remove_on_success: bool = True) -> bool: data, serializer, options = self._pending[path_str] path = Path(path_str) + # Resolve symlinks to write to actual target (critical for Docker volume mounts) + write_path = _resolve_write_target(path, self._logger) try: # Ensure directory exists - path.parent.mkdir(parents=True, exist_ok=True) + write_path.parent.mkdir(parents=True, exist_ok=True) # Serialize data content = serializer(data) @@ -205,7 +242,7 @@ def _try_write(self, path_str: str, remove_on_success: bool = True) -> bool: tmp_path = None try: tmp_fd, tmp_path = tempfile.mkstemp( - dir=path.parent, prefix=".tmp_", suffix=".json", text=True + dir=write_path.parent, prefix=".tmp_", suffix=".json", text=True ) with os.fdopen(tmp_fd, "w", encoding="utf-8") as f: f.write(content) @@ -218,7 +255,7 @@ def _try_write(self, path_str: str, remove_on_success: bool = True) -> bool: except (OSError, AttributeError): pass - shutil.move(tmp_path, path) + shutil.move(tmp_path, write_path) tmp_path = None finally: @@ -426,9 +463,12 @@ def _try_disk_write(self) -> bool: self._last_attempt = time.time() + # Resolve symlinks to write to actual target (critical for Docker volume mounts) + write_path = 
_resolve_write_target(self.path, self.logger) + try: # Ensure directory exists - self.path.parent.mkdir(parents=True, exist_ok=True) + write_path.parent.mkdir(parents=True, exist_ok=True) # Serialize data content = self._serializer(self._current_state) @@ -438,7 +478,7 @@ def _try_disk_write(self) -> bool: tmp_path = None try: tmp_fd, tmp_path = tempfile.mkstemp( - dir=self.path.parent, prefix=".tmp_", suffix=".json", text=True + dir=write_path.parent, prefix=".tmp_", suffix=".json", text=True ) with os.fdopen(tmp_fd, "w", encoding="utf-8") as f: @@ -446,7 +486,7 @@ def _try_disk_write(self) -> bool: tmp_fd = None # fdopen closes the fd # Atomic move - shutil.move(tmp_path, self.path) + shutil.move(tmp_path, write_path) tmp_path = None finally: @@ -551,13 +591,15 @@ def safe_write_json( True on success, False on failure (never raises) """ path = Path(path) + # Resolve symlinks to write to actual target (critical for Docker volume mounts) + write_path = _resolve_write_target(path, logger) # Create serializer function that matches the requested formatting def serializer(d: Any) -> str: return json.dumps(d, indent=indent, ensure_ascii=ensure_ascii) try: - path.parent.mkdir(parents=True, exist_ok=True) + write_path.parent.mkdir(parents=True, exist_ok=True) content = serializer(data) if atomic: @@ -565,7 +607,7 @@ def serializer(d: Any) -> str: tmp_path = None try: tmp_fd, tmp_path = tempfile.mkstemp( - dir=path.parent, prefix=".tmp_", suffix=".json", text=True + dir=write_path.parent, prefix=".tmp_", suffix=".json", text=True ) with os.fdopen(tmp_fd, "w", encoding="utf-8") as f: f.write(content) @@ -579,7 +621,7 @@ def serializer(d: Any) -> str: # Windows may not support chmod, ignore pass - shutil.move(tmp_path, path) + shutil.move(tmp_path, write_path) tmp_path = None finally: if tmp_fd is not None: @@ -593,13 +635,13 @@ def serializer(d: Any) -> str: except OSError: pass else: - with open(path, "w", encoding="utf-8") as f: + with open(write_path, "w", 
encoding="utf-8") as f: f.write(content) # Set secure permissions if requested if secure_permissions: try: - os.chmod(path, 0o600) + os.chmod(write_path, 0o600) except (OSError, AttributeError): pass