From 35ba74b8872fca1108b66cfe3fa4b3c77ab079f5 Mon Sep 17 00:00:00 2001 From: LarytheLord Date: Mon, 4 May 2026 16:22:15 +0530 Subject: [PATCH] ci: add pytest, ruff, mypy, and Docker build verification to CI - Add test-lint.yml workflow with 4 jobs: - Python Tests: runs all 116 existing pytest tests - Ruff Lint & Format: enforces code style - Mypy Type Check: advisory type checking (non-blocking) - Docker Build Check: builds image and verifies container starts - Add pyproject.toml with ruff and mypy configuration - Fix all ruff violations: - Sort imports, fix duplicate imports in admin.py - Replace nested ifs with combined conditions - Fix implicit Optional in usage_tracker.py - Use contextlib.suppress for ValueError handling - Fix unused variables in tests - Apply ruff formatting across all files - Add E501 and other per-file ignores for admin.py HTML templates --- .github/workflows/test-lint.yml | 120 +++++++++ auth-proxy/admin.py | 446 +++++++++++++++----------------- auth-proxy/config.py | 32 +-- auth-proxy/key_manager.py | 82 +++--- auth-proxy/server.py | 293 ++++++++++----------- auth-proxy/usage_tracker.py | 169 ++++++------ auth-proxy/utils.py | 7 +- pyproject.toml | 41 +++ scripts/manage_keys.py | 8 +- scripts/scrape_docs.py | 55 ++-- tests/conftest.py | 38 +-- tests/helpers.py | 1 - tests/test_admin.py | 74 +++--- tests/test_auth.py | 66 ++--- tests/test_endpoints.py | 130 +++++----- tests/test_proxy.py | 33 +-- tests/test_rate_limiting.py | 154 ++++++----- tests/test_usage_tracker.py | 50 ++-- tests/test_utils.py | 35 +-- 19 files changed, 991 insertions(+), 843 deletions(-) create mode 100644 .github/workflows/test-lint.yml create mode 100644 pyproject.toml diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml new file mode 100644 index 0000000..6eda8bd --- /dev/null +++ b/.github/workflows/test-lint.yml @@ -0,0 +1,120 @@ +name: Test & Lint + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + workflow_dispatch: + +permissions: + contents: read + +jobs: + test: + name: Python Tests + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v6 + + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + cache: 'pip' + cache-dependency-path: | + auth-proxy/requirements.txt + requirements-test.txt + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r auth-proxy/requirements.txt + pip install -r requirements-test.txt + + - name: Run tests + working-directory: . + run: | + python -m pytest tests/ -v --tb=short + + lint: + name: Ruff Lint & Format + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v6 + + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + cache: 'pip' + + - name: Install ruff + run: pip install ruff + + - name: Run ruff check + run: ruff check auth-proxy/ scripts/ tests/ + + - name: Run ruff format check + run: ruff format --check auth-proxy/ scripts/ tests/ + + typecheck: + name: Mypy Type Check + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v6 + + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + cache: 'pip' + cache-dependency-path: | + auth-proxy/requirements.txt + requirements-test.txt + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r auth-proxy/requirements.txt + pip install -r requirements-test.txt + pip install mypy + + - name: Run mypy + run: | + echo "Mypy is advisory — errors are reported but do not fail CI." + echo "Fix errors incrementally as you touch files." + mypy auth-proxy/ scripts/ --ignore-missing-imports --check-untyped-defs --no-strict-optional || true + + docker: + name: Docker Build Check + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v6 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Build Docker image + uses: docker/build-push-action@v6 + with: + context: . + push: false + load: true + tags: privatemode-proxy:test + cache-from: type=gha + cache-to: type=gha,mode=max + + - name: Verify container starts + run: | + docker run --rm -d --name test-container \ + -e API_KEYS_FILE=/app/secrets/keys.json \ + -e UPSTREAM_URL=http://localhost:8081 \ + -e PORT=8080 \ + -e PRIVATEMODE_API_KEY=test-key \ + -e FORCE_HTTPS=false \ + privatemode-proxy:test + sleep 5 + docker logs test-container + docker stop test-container diff --git a/auth-proxy/admin.py b/auth-proxy/admin.py index 46b4238..8501a44 100644 --- a/auth-proxy/admin.py +++ b/auth-proxy/admin.py @@ -11,24 +11,32 @@ Protected by admin password. """ -import os -import json -import time +import base64 +import contextlib import hashlib +import json +import os import secrets -import base64 +import time from collections import defaultdict -from html import escape from datetime import datetime +from html import escape from pathlib import Path + from aiohttp import web from cryptography.fernet import Fernet -from usage_tracker import get_tracker, get_time_range + from config import ( - API_KEYS_FILE, SETTINGS_FILE, ADMIN_PASSWORD, PRIVATEMODE_API_KEY, - DEFAULT_RATE_LIMIT_REQUESTS, DEFAULT_RATE_LIMIT_WINDOW, - DEFAULT_IP_RATE_LIMIT_REQUESTS, DEFAULT_IP_RATE_LIMIT_WINDOW + ADMIN_PASSWORD, + API_KEYS_FILE, + DEFAULT_IP_RATE_LIMIT_REQUESTS, + DEFAULT_IP_RATE_LIMIT_WINDOW, + DEFAULT_RATE_LIMIT_REQUESTS, + DEFAULT_RATE_LIMIT_WINDOW, + PRIVATEMODE_API_KEY, + SETTINGS_FILE, ) +from usage_tracker import get_time_range, get_tracker from utils import get_client_ip # Aliases for backwards compatibility @@ -62,9 +70,7 @@ def validate_session(token: str, ip: str) -> bool: del _sessions[token] return False # Validate IP matches the session's original IP - if ip != session_ip: - return False - return True + return ip == session_ip def delete_session(token: str) -> None: @@ -83,10 +89,10 @@ def _cleanup_sessions() -> None: def get_default_settings() -> dict: """Return default settings with rate limits from config.""" return { - 'rate_limit_requests': DEFAULT_RATE_LIMIT_REQUESTS, - 'rate_limit_window': DEFAULT_RATE_LIMIT_WINDOW, - 'ip_rate_limit_requests': DEFAULT_IP_RATE_LIMIT_REQUESTS, - 'ip_rate_limit_window': DEFAULT_IP_RATE_LIMIT_WINDOW, + "rate_limit_requests": DEFAULT_RATE_LIMIT_REQUESTS, + "rate_limit_window": DEFAULT_RATE_LIMIT_WINDOW, + "ip_rate_limit_requests": DEFAULT_IP_RATE_LIMIT_REQUESTS, + "ip_rate_limit_window": DEFAULT_IP_RATE_LIMIT_WINDOW, } @@ -100,14 +106,14 @@ def load_settings() -> dict: saved = json.load(f) # Merge saved settings with defaults (saved takes precedence) return {**defaults, **saved} - except (json.JSONDecodeError, IOError): + except (OSError, json.JSONDecodeError): return defaults def save_settings(data: dict) -> None: """Save settings to file.""" Path(SETTINGS_FILE).parent.mkdir(parents=True, exist_ok=True) - with open(SETTINGS_FILE, 'w') as f: + with open(SETTINGS_FILE, "w") as f: json.dump(data, f, indent=2) @@ -115,11 +121,12 @@ def get_privatemode_key_status() -> tuple[str, str]: """Get status of Privatemode API key. Returns (status_text, css_class).""" if PRIVATEMODE_API_KEY: # Only show prefix to minimize exposure - if PRIVATEMODE_API_KEY.startswith('pm_'): + if PRIVATEMODE_API_KEY.startswith("pm_"): return "Configured (pm_****)", "status-active" return "Configured (****)", "status-active" return "Not configured", "status-revoked" + # Temporary storage for newly generated keys (in-memory, short-lived) # Maps key_id -> (encrypted_key, timestamp) _pending_keys: dict[str, tuple[str, float]] = {} @@ -130,7 +137,7 @@ def get_privatemode_key_status() -> tuple[str, str]: CSRF_TTL = 3600 # 1 hour -PBKDF2_SALT = os.environ.get('PBKDF2_SALT', '').encode() +PBKDF2_SALT = os.environ.get("PBKDF2_SALT", "").encode() if len(PBKDF2_SALT) < 16: raise ValueError("PBKDF2_SALT environment variable must be at least 16 bytes") @@ -151,11 +158,11 @@ def _get_fernet_key() -> bytes: return _FERNET_KEY # Use PBKDF2 with 600,000 iterations (OWASP recommended for HMAC-SHA256) key_material = hashlib.pbkdf2_hmac( - 'sha256', + "sha256", ADMIN_PASSWORD.encode(), PBKDF2_SALT, iterations=600_000, - dklen=32 # Fernet requires 32 bytes + dklen=32, # Fernet requires 32 bytes ) _FERNET_KEY = base64.urlsafe_b64encode(key_material) return _FERNET_KEY @@ -234,7 +241,7 @@ def check_admin_auth(request: web.Request) -> bool: if not ADMIN_PASSWORD: return False - token = request.cookies.get('admin_session') + token = request.cookies.get("admin_session") if not token: return False @@ -253,19 +260,19 @@ def load_keys() -> dict: def save_keys(data: dict) -> None: """Save keys to file.""" Path(KEYS_FILE).parent.mkdir(parents=True, exist_ok=True) - with open(KEYS_FILE, 'w') as f: + with open(KEYS_FILE, "w") as f: json.dump(data, f, indent=2) def update_key_rate_limit(key_id: str, rate_limit: int | None) -> bool: """Update the rate limit for a specific key. Returns True if successful.""" keys_data = load_keys() - for key in keys_data['keys']: - if key['key_id'] == key_id: + for key in keys_data["keys"]: + if key["key_id"] == key_id: if rate_limit is None: - key.pop('rate_limit', None) + key.pop("rate_limit", None) else: - key['rate_limit'] = rate_limit + key["rate_limit"] = rate_limit save_keys(keys_data) return True return False @@ -280,9 +287,9 @@ def format_timestamp(ts: float | None) -> str: def get_key_status(key: dict) -> tuple[str, str]: """Get status and CSS class for a key.""" - if not key.get('enabled', True): + if not key.get("enabled", True): return "Revoked", "status-revoked" - expires_at = key.get('expires_at') + expires_at = key.get("expires_at") if expires_at and time.time() > expires_at: return "Expired", "status-expired" return "Active", "status-active" @@ -634,23 +641,18 @@ def get_key_status(key: dict) -> tuple[str, str]: async def admin_login_page(request: web.Request) -> web.Response: """Show login page.""" if not ADMIN_PASSWORD: - return web.Response( - text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.", - status=403 - ) + return web.Response(text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.", status=403) if check_admin_auth(request): - raise web.HTTPFound('/admin') + raise web.HTTPFound("/admin") - error = request.query.get('error', '') + error = request.query.get("error", "") # Escape error message to prevent XSS attacks - error_html = f'

{escape(error)}

' if error else '' + error_html = f'

{escape(error)}

' if error else "" csrf_token = generate_csrf_token() - html = HTML_TEMPLATE.format( - content=LOGIN_CONTENT.format(error=error_html, csrf_token=csrf_token) - ) - return web.Response(text=html, content_type='text/html') + html = HTML_TEMPLATE.format(content=LOGIN_CONTENT.format(error=error_html, csrf_token=csrf_token)) + return web.Response(text=html, content_type="text/html") async def admin_login_post(request: web.Request) -> web.Response: @@ -662,14 +664,11 @@ async def admin_login_post(request: web.Request) -> web.Response: # Check rate limit BEFORE password validation if not check_login_rate_limit(client_ip): - return web.Response( - text="Too many login attempts. Please try again later.", - status=429 - ) + return web.Response(text="Too many login attempts. Please try again later.", status=429) data = await request.post() - csrf_token = data.get('csrf_token', '') - password = data.get('password', '') + csrf_token = data.get("csrf_token", "") + password = data.get("password", "") # Validate CSRF token if not validate_csrf_token(csrf_token): @@ -678,56 +677,53 @@ async def admin_login_post(request: web.Request) -> web.Response: if secrets.compare_digest(password, ADMIN_PASSWORD): # Create a new random session token token = create_session(client_ip) - response = web.HTTPFound('/admin') + response = web.HTTPFound("/admin") # Detect if running behind HTTPS (Fly.io, nginx, etc.) - is_https = request.headers.get('X-Forwarded-Proto', '').lower() == 'https' + is_https = request.headers.get("X-Forwarded-Proto", "").lower() == "https" response.set_cookie( - 'admin_session', + "admin_session", token, max_age=86400, # 24 hours httponly=True, - samesite='Strict', - secure=is_https # Only send cookie over HTTPS in production + samesite="Strict", + secure=is_https, # Only send cookie over HTTPS in production ) return response # Record failed login attempt check_login_rate_limit(client_ip, record_attempt=True) - raise web.HTTPFound('/admin/login?error=Invalid password') + raise web.HTTPFound("/admin/login?error=Invalid password") async def admin_logout(request: web.Request) -> web.Response: """Handle logout.""" # Delete the session from store - token = request.cookies.get('admin_session') + token = request.cookies.get("admin_session") if token: delete_session(token) - response = web.HTTPFound('/admin/login') - response.del_cookie('admin_session') + response = web.HTTPFound("/admin/login") + response.del_cookie("admin_session") return response async def admin_dashboard(request: web.Request) -> web.Response: """Show admin dashboard.""" if not ADMIN_PASSWORD: - return web.Response( - text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.", - status=403 - ) + return web.Response(text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.", status=403) if not check_admin_auth(request): - raise web.HTTPFound('/admin/login') + raise web.HTTPFound("/admin/login") # Check for newly generated key to display (one-time retrieval) - new_key = '' - show_new_key = '' - show_key_id = request.query.get('show_key', '') + new_key = "" + show_new_key = "" + show_key_id = request.query.get("show_key", "") if show_key_id and show_key_id in _pending_keys: encrypted_key, timestamp = _pending_keys.pop(show_key_id) # Remove after retrieval if time.time() - timestamp < PENDING_KEY_TTL: new_key = _decrypt_key_for_display(encrypted_key) - show_new_key = 'show' + show_new_key = "show" # Load keys data = load_keys() @@ -736,30 +732,30 @@ async def admin_dashboard(request: web.Request) -> web.Response: csrf_token = generate_csrf_token() # Build keys table - if not data['keys']: + if not data["keys"]: keys_table = '
No API keys configured. Generate one above.
' else: rows = [] - for key in data['keys']: + for key in data["keys"]: status, status_class = get_key_status(key) - if key.get('enabled', True): + if key.get("enabled", True): actions = f''' -
+
''' else: actions = f''' -
+
''' actions += f''' -
@@ -767,10 +763,10 @@ async def admin_dashboard(request: web.Request) -> web.Response: ''' # Format rate limit display with inline edit form - key_rate_limit = key.get('rate_limit') + key_rate_limit = key.get("rate_limit") if key_rate_limit: rate_limit_display = f''' - + @@ -779,58 +775,56 @@ async def admin_dashboard(request: web.Request) -> web.Response: ''' else: rate_limit_display = f''' - +
''' - rows.append(KEY_ROW.format( - key_id=escape(key['key_id']), - description=escape(key.get('description', '-')), - status=status, - status_class=status_class, - rate_limit_display=rate_limit_display, - created=format_timestamp(key.get('created_at')), - expires=format_timestamp(key.get('expires_at')), - actions=actions - )) - - keys_table = KEYS_TABLE.format(rows=''.join(rows)) + rows.append( + KEY_ROW.format( + key_id=escape(key["key_id"]), + description=escape(key.get("description", "-")), + status=status, + status_class=status_class, + rate_limit_display=rate_limit_display, + created=format_timestamp(key.get("created_at")), + expires=format_timestamp(key.get("expires_at")), + actions=actions, + ) + ) + + keys_table = KEYS_TABLE.format(rows="".join(rows)) # Determine base URL for usage examples - scheme = request.headers.get('X-Forwarded-Proto', request.scheme) - host = request.headers.get('X-Forwarded-Host', request.host) + scheme = request.headers.get("X-Forwarded-Proto", request.scheme) + host = request.headers.get("X-Forwarded-Host", request.host) base_url = escape(f"{scheme}://{host}") content = DASHBOARD_CONTENT.format( - keys_table=keys_table, - new_key=new_key, - show_new_key=show_new_key, - base_url=base_url, - csrf_token=csrf_token + keys_table=keys_table, new_key=new_key, show_new_key=show_new_key, base_url=base_url, csrf_token=csrf_token ) html = HTML_TEMPLATE.format(content=content) - return web.Response(text=html, content_type='text/html') + return web.Response(text=html, content_type="text/html") async def admin_generate_key(request: web.Request) -> web.Response: """Generate a new API key.""" if not check_admin_auth(request): - raise web.HTTPFound('/admin/login') + raise web.HTTPFound("/admin/login") data = await request.post() - csrf_token = data.get('csrf_token', '') + csrf_token = data.get("csrf_token", "") # Validate CSRF token if not validate_csrf_token(csrf_token): return web.Response(text="Invalid or expired CSRF token", status=403) - description = data.get('description', '') - expires_days = data.get('expires_days', '') - rate_limit = data.get('rate_limit', '') + description = data.get("description", "") + expires_days = data.get("expires_days", "") + rate_limit = data.get("rate_limit", "") # Generate key new_key = f"pm_{secrets.token_urlsafe(32)}" @@ -838,28 +832,24 @@ async def admin_generate_key(request: web.Request) -> web.Response: key_id = f"key_{secrets.token_hex(4)}" entry = { - 'key_id': key_id, - 'key_hash': key_hash, - 'created_at': time.time(), - 'description': description, - 'enabled': True + "key_id": key_id, + "key_hash": key_hash, + "created_at": time.time(), + "description": description, + "enabled": True, } if expires_days: - try: - entry['expires_at'] = time.time() + (int(expires_days) * 86400) - except ValueError: - pass + with contextlib.suppress(ValueError): + entry["expires_at"] = time.time() + (int(expires_days) * 86400) if rate_limit: - try: - entry['rate_limit'] = int(rate_limit) - except ValueError: - pass + with contextlib.suppress(ValueError): + entry["rate_limit"] = int(rate_limit) # Save keys_data = load_keys() - keys_data['keys'].append(entry) + keys_data["keys"].append(entry) save_keys(keys_data) # Store encrypted key temporarily for one-time display @@ -867,100 +857,100 @@ async def admin_generate_key(request: web.Request) -> web.Response: _pending_keys[key_id] = (_encrypt_key_for_display(new_key), time.time()) # Redirect with just the key_id (not the actual key) - raise web.HTTPFound(f'/admin?show_key={key_id}') + raise web.HTTPFound(f"/admin?show_key={key_id}") async def admin_revoke_key(request: web.Request) -> web.Response: """Revoke an API key.""" if not check_admin_auth(request): - raise web.HTTPFound('/admin/login') + raise web.HTTPFound("/admin/login") data = await request.post() - csrf_token = data.get('csrf_token', '') + csrf_token = data.get("csrf_token", "") # Validate CSRF token if not validate_csrf_token(csrf_token): return web.Response(text="Invalid or expired CSRF token", status=403) - key_id = request.match_info['key_id'] + key_id = request.match_info["key_id"] keys_data = load_keys() - for key in keys_data['keys']: - if key['key_id'] == key_id: - key['enabled'] = False - key['revoked_at'] = time.time() + for key in keys_data["keys"]: + if key["key_id"] == key_id: + key["enabled"] = False + key["revoked_at"] = time.time() break save_keys(keys_data) - raise web.HTTPFound('/admin') + raise web.HTTPFound("/admin") async def admin_enable_key(request: web.Request) -> web.Response: """Re-enable an API key.""" if not check_admin_auth(request): - raise web.HTTPFound('/admin/login') + raise web.HTTPFound("/admin/login") data = await request.post() - csrf_token = data.get('csrf_token', '') + csrf_token = data.get("csrf_token", "") # Validate CSRF token if not validate_csrf_token(csrf_token): return web.Response(text="Invalid or expired CSRF token", status=403) - key_id = request.match_info['key_id'] + key_id = request.match_info["key_id"] keys_data = load_keys() - for key in keys_data['keys']: - if key['key_id'] == key_id: - key['enabled'] = True - if 'revoked_at' in key: - del key['revoked_at'] + for key in keys_data["keys"]: + if key["key_id"] == key_id: + key["enabled"] = True + if "revoked_at" in key: + del key["revoked_at"] break save_keys(keys_data) - raise web.HTTPFound('/admin') + raise web.HTTPFound("/admin") async def admin_delete_key(request: web.Request) -> web.Response: """Delete an API key permanently.""" if not check_admin_auth(request): - raise web.HTTPFound('/admin/login') + raise web.HTTPFound("/admin/login") data = await request.post() - csrf_token = data.get('csrf_token', '') + csrf_token = data.get("csrf_token", "") # Validate CSRF token if not validate_csrf_token(csrf_token): return web.Response(text="Invalid or expired CSRF token", status=403) - key_id = request.match_info['key_id'] + key_id = request.match_info["key_id"] keys_data = load_keys() - keys_data['keys'] = [k for k in keys_data['keys'] if k['key_id'] != key_id] + keys_data["keys"] = [k for k in keys_data["keys"] if k["key_id"] != key_id] save_keys(keys_data) - raise web.HTTPFound('/admin') + raise web.HTTPFound("/admin") async def admin_update_key_rate_limit(request: web.Request) -> web.Response: """Update rate limit for a specific API key.""" if not check_admin_auth(request): - raise web.HTTPFound('/admin/login') + raise web.HTTPFound("/admin/login") data = await request.post() - csrf_token = data.get('csrf_token', '') + csrf_token = data.get("csrf_token", "") # Validate CSRF token if not validate_csrf_token(csrf_token): return web.Response(text="Invalid or expired CSRF token", status=403) - key_id = request.match_info['key_id'] + key_id = request.match_info["key_id"] # Check if clearing the rate limit - if data.get('clear'): + if data.get("clear"): update_key_rate_limit(key_id, None) else: - rate_limit_str = data.get('rate_limit', '') + rate_limit_str = data.get("rate_limit", "") if rate_limit_str: try: rate_limit = int(rate_limit_str) @@ -969,16 +959,16 @@ async def admin_update_key_rate_limit(request: web.Request) -> web.Response: except ValueError: pass - raise web.HTTPFound('/admin') + raise web.HTTPFound("/admin") async def admin_save_rate_limits(request: web.Request) -> web.Response: """Save global rate limit settings.""" if not check_admin_auth(request): - raise web.HTTPFound('/admin/login') + raise web.HTTPFound("/admin/login") data = await request.post() - csrf_token = data.get('csrf_token', '') + csrf_token = data.get("csrf_token", "") # Validate CSRF token if not validate_csrf_token(csrf_token): @@ -989,20 +979,20 @@ async def admin_save_rate_limits(request: web.Request) -> web.Response: # Update rate limit settings try: - if data.get('rate_limit_requests'): - settings['rate_limit_requests'] = int(data['rate_limit_requests']) - if data.get('rate_limit_window'): - settings['rate_limit_window'] = int(data['rate_limit_window']) - if data.get('ip_rate_limit_requests'): - settings['ip_rate_limit_requests'] = int(data['ip_rate_limit_requests']) - if data.get('ip_rate_limit_window'): - settings['ip_rate_limit_window'] = int(data['ip_rate_limit_window']) + if data.get("rate_limit_requests"): + settings["rate_limit_requests"] = int(data["rate_limit_requests"]) + if data.get("rate_limit_window"): + settings["rate_limit_window"] = int(data["rate_limit_window"]) + if data.get("ip_rate_limit_requests"): + settings["ip_rate_limit_requests"] = int(data["ip_rate_limit_requests"]) + if data.get("ip_rate_limit_window"): + settings["ip_rate_limit_window"] = int(data["ip_rate_limit_window"]) except ValueError: pass save_settings(settings) - raise web.HTTPFound('/admin/settings?success=rate_limits') + raise web.HTTPFound("/admin/settings?success=rate_limits") USAGE_CONTENT = """ @@ -1186,18 +1176,15 @@ async def admin_save_rate_limits(request: web.Request) -> web.Response: async def admin_settings(request: web.Request) -> web.Response: """Show settings page.""" if not ADMIN_PASSWORD: - return web.Response( - text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.", - status=403 - ) + return web.Response(text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.", status=403) if not check_admin_auth(request): - raise web.HTTPFound('/admin/login') + raise web.HTTPFound("/admin/login") # Check for success message - success = request.query.get('success', '') - success_message = '' - if success == 'rate_limits': + success = request.query.get("success", "") + success_message = "" + if success == "rate_limits": success_message = '
Rate limit settings saved successfully.
' # Get Privatemode key status @@ -1206,10 +1193,10 @@ async def admin_settings(request: web.Request) -> web.Response: # Build message and dot color based on key status if PRIVATEMODE_API_KEY: pm_key_message = '

E2E encryption active

' - pm_dot_color = '#22c55e' + pm_dot_color = "#22c55e" else: pm_key_message = '

Not configured - set PRIVATEMODE_API_KEY

' - pm_dot_color = '#ef4444' + pm_dot_color = "#ef4444" # Load current rate limit settings settings = load_settings() @@ -1222,48 +1209,45 @@ async def admin_settings(request: web.Request) -> web.Response: pm_dot_color=pm_dot_color, csrf_token=csrf_token, success_message=success_message, - rate_limit_requests=settings.get('rate_limit_requests', 100), - rate_limit_window=settings.get('rate_limit_window', 60), - ip_rate_limit_requests=settings.get('ip_rate_limit_requests', 1000), - ip_rate_limit_window=settings.get('ip_rate_limit_window', 60), + rate_limit_requests=settings.get("rate_limit_requests", 100), + rate_limit_window=settings.get("rate_limit_window", 60), + ip_rate_limit_requests=settings.get("ip_rate_limit_requests", 1000), + ip_rate_limit_window=settings.get("ip_rate_limit_window", 60), ) html = HTML_TEMPLATE.format(content=content) - return web.Response(text=html, content_type='text/html') + return web.Response(text=html, content_type="text/html") async def admin_usage(request: web.Request) -> web.Response: """Show usage dashboard.""" if not ADMIN_PASSWORD: - return web.Response( - text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.", - status=403 - ) + return web.Response(text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.", status=403) if not check_admin_auth(request): - raise web.HTTPFound('/admin/login') + raise web.HTTPFound("/admin/login") # Get time period from query params - period = request.query.get('period', 'month') + period = request.query.get("period", "month") start_time, end_time = get_time_range(period) # Period labels and active states period_labels = { - 'today': 'Today', - 'week': 'Last 7 Days', - 'month': 'Last 30 Days', - 'year': 'Last Year', - 'all': 'All Time' + "today": "Today", + "week": "Last 7 Days", + "month": "Last 30 Days", + "year": "Last Year", + "all": "All Time", } - period_label = period_labels.get(period, 'Last 30 Days') + period_label = period_labels.get(period, "Last 30 Days") # Active button states active_states = { - 'active_today': 'btn-primary' if period == 'today' else '', - 'active_week': 'btn-primary' if period == 'week' else '', - 'active_month': 'btn-primary' if period == 'month' else '', - 'active_year': 'btn-primary' if period == 'year' else '', - 'active_all': 'btn-primary' if period == 'all' else '' + "active_today": "btn-primary" if period == "today" else "", + "active_week": "btn-primary" if period == "week" else "", + "active_month": "btn-primary" if period == "month" else "", + "active_year": "btn-primary" if period == "year" else "", + "active_all": "btn-primary" if period == "all" else "", } tracker = get_tracker() @@ -1276,7 +1260,7 @@ async def admin_usage(request: web.Request) -> web.Response: # Load keys to get descriptions keys_data = load_keys() - key_descriptions = {k['key_id']: k.get('description', '-') for k in keys_data.get('keys', [])} + key_descriptions = {k["key_id"]: k.get("description", "-") for k in keys_data.get("keys", [])} # Build usage by key table if not usage_by_key: @@ -1284,44 +1268,45 @@ async def admin_usage(request: web.Request) -> web.Response: else: rows = [] # Sort by cost descending - sorted_keys = sorted(usage_by_key.items(), key=lambda x: x[1]['cost_eur'], reverse=True) + sorted_keys = sorted(usage_by_key.items(), key=lambda x: x[1]["cost_eur"], reverse=True) for key_id, data in sorted_keys: - rows.append(USAGE_BY_KEY_ROW.format( - description=escape(key_descriptions.get(key_id, 'Unknown Key')), - tokens=data['tokens'], - requests=data['requests'], - cost=data['cost_eur'] - )) - usage_by_key_table = USAGE_BY_KEY_TABLE.format(rows=''.join(rows)) + rows.append( + USAGE_BY_KEY_ROW.format( + description=escape(key_descriptions.get(key_id, "Unknown Key")), + tokens=data["tokens"], + requests=data["requests"], + cost=data["cost_eur"], + ) + ) + usage_by_key_table = USAGE_BY_KEY_TABLE.format(rows="".join(rows)) # Build usage by model table - if not summary['by_model']: + if not summary["by_model"]: usage_by_model_table = '
No usage data for this period.
' else: rows = [] # Sort by cost descending - sorted_models = sorted(summary['by_model'].items(), key=lambda x: x[1]['cost'], reverse=True) + sorted_models = sorted(summary["by_model"].items(), key=lambda x: x[1]["cost"], reverse=True) for model, data in sorted_models: - rows.append(USAGE_BY_MODEL_ROW.format( - model=escape(model), - tokens=data['tokens'], - requests=data['requests'], - cost=data['cost'] - )) - usage_by_model_table = USAGE_BY_MODEL_TABLE.format(rows=''.join(rows)) + rows.append( + USAGE_BY_MODEL_ROW.format( + model=escape(model), tokens=data["tokens"], requests=data["requests"], cost=data["cost"] + ) + ) + usage_by_model_table = USAGE_BY_MODEL_TABLE.format(rows="".join(rows)) content = USAGE_CONTENT.format( period_label=period_label, - total_cost=summary['total_cost_eur'], - total_tokens=summary['total_tokens'], - total_requests=summary['requests'], + total_cost=summary["total_cost_eur"], + total_tokens=summary["total_tokens"], + total_requests=summary["requests"], usage_by_key_table=usage_by_key_table, usage_by_model_table=usage_by_model_table, - **active_states + **active_states, ) html = HTML_TEMPLATE.format(content=content) - return web.Response(text=html, content_type='text/html') + return web.Response(text=html, content_type="text/html") ABOUT_CONTENT = """ @@ -1502,28 +1487,25 @@ async def admin_usage(request: web.Request) -> web.Response: async def admin_about(request: web.Request) -> web.Response: """Show about page.""" if not ADMIN_PASSWORD: - return web.Response( - text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.", - status=403 - ) + return web.Response(text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.", status=403) if not check_admin_auth(request): - raise web.HTTPFound('/admin/login') + raise web.HTTPFound("/admin/login") html = HTML_TEMPLATE.format(content=ABOUT_CONTENT) - return web.Response(text=html, content_type='text/html') + return web.Response(text=html, content_type="text/html") async def admin_static(request: web.Request) -> web.Response: """Serve static files (logo, etc.).""" - filename = request.match_info.get('filename', '') + filename = request.match_info.get("filename", "") # Security: allowlist maps filenames to (relative path, content type). # The user-controlled value is only used as a dict key — the filesystem # path is constructed entirely from hardcoded values, breaking taint flow. - static_dir = Path(__file__).parent / 'static' + static_dir = Path(__file__).parent / "static" allowed_files = { - 'logo.png': (static_dir / 'logo.png', 'image/png'), + "logo.png": (static_dir / "logo.png", "image/png"), } entry = allowed_files.get(filename) @@ -1534,23 +1516,23 @@ async def admin_static(request: web.Request) -> web.Response: if not file_path.exists(): raise web.HTTPNotFound() - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: return web.Response(body=f.read(), content_type=content_type) def setup_admin_routes(app: web.Application) -> None: """Add admin routes to the app.""" - app.router.add_get('/admin', admin_dashboard) - app.router.add_get('/admin/settings', admin_settings) - app.router.add_post('/admin/settings/rate-limits', admin_save_rate_limits) - app.router.add_get('/admin/usage', admin_usage) - app.router.add_get('/admin/about', admin_about) - app.router.add_get('/admin/static/{filename}', admin_static) - app.router.add_get('/admin/login', admin_login_page) - app.router.add_post('/admin/login', admin_login_post) - app.router.add_get('/admin/logout', admin_logout) - app.router.add_post('/admin/keys/generate', admin_generate_key) - app.router.add_post('/admin/keys/{key_id}/revoke', admin_revoke_key) - app.router.add_post('/admin/keys/{key_id}/enable', admin_enable_key) - app.router.add_post('/admin/keys/{key_id}/delete', admin_delete_key) - app.router.add_post('/admin/keys/{key_id}/rate-limit', admin_update_key_rate_limit) + app.router.add_get("/admin", admin_dashboard) + app.router.add_get("/admin/settings", admin_settings) + app.router.add_post("/admin/settings/rate-limits", admin_save_rate_limits) + app.router.add_get("/admin/usage", admin_usage) + app.router.add_get("/admin/about", admin_about) + app.router.add_get("/admin/static/{filename}", admin_static) + app.router.add_get("/admin/login", admin_login_page) + app.router.add_post("/admin/login", admin_login_post) + app.router.add_get("/admin/logout", admin_logout) + app.router.add_post("/admin/keys/generate", admin_generate_key) + app.router.add_post("/admin/keys/{key_id}/revoke", admin_revoke_key) + app.router.add_post("/admin/keys/{key_id}/enable", admin_enable_key) + app.router.add_post("/admin/keys/{key_id}/delete", admin_delete_key) + app.router.add_post("/admin/keys/{key_id}/rate-limit", admin_update_key_rate_limit) diff --git a/auth-proxy/config.py b/auth-proxy/config.py index 3605a06..9472914 100644 --- a/auth-proxy/config.py +++ b/auth-proxy/config.py @@ -8,38 +8,38 @@ import os # File paths -API_KEYS_FILE = os.environ.get('API_KEYS_FILE', '/app/secrets/api_keys.json') -SETTINGS_FILE = os.environ.get('SETTINGS_FILE', '/app/secrets/settings.json') -USAGE_FILE = os.environ.get('USAGE_FILE', '/app/data/usage.json') +API_KEYS_FILE = os.environ.get("API_KEYS_FILE", "/app/secrets/api_keys.json") +SETTINGS_FILE = os.environ.get("SETTINGS_FILE", "/app/secrets/settings.json") +USAGE_FILE = os.environ.get("USAGE_FILE", "/app/data/usage.json") # Server configuration # Default assumes both services run in same container (supervisord setup) # Override with UPSTREAM_URL env var for different deployments -UPSTREAM_URL = os.environ.get('UPSTREAM_URL', 'http://localhost:8081') -PORT = int(os.environ.get('PORT', '8080')) -PRIVATEMODE_API_KEY = os.environ.get('PRIVATEMODE_API_KEY', '') -ADMIN_PASSWORD = os.environ.get('ADMIN_PASSWORD', '') +UPSTREAM_URL = os.environ.get("UPSTREAM_URL", "http://localhost:8081") +PORT = int(os.environ.get("PORT", "8080")) +PRIVATEMODE_API_KEY = os.environ.get("PRIVATEMODE_API_KEY", "") +ADMIN_PASSWORD = os.environ.get("ADMIN_PASSWORD", "") # Trust proxy headers (X-Forwarded-For, X-Forwarded-Proto, etc.) # Set to 'true' when running behind a trusted reverse proxy (nginx, Fly.io, etc.) # When false, X-Forwarded-For headers are ignored to prevent IP spoofing -TRUST_PROXY = os.environ.get('TRUST_PROXY', 'false').lower() in ('true', '1', 'yes') +TRUST_PROXY = os.environ.get("TRUST_PROXY", "false").lower() in ("true", "1", "yes") # Default rate limits -DEFAULT_RATE_LIMIT_REQUESTS = int(os.environ.get('RATE_LIMIT_REQUESTS', '100')) -DEFAULT_RATE_LIMIT_WINDOW = int(os.environ.get('RATE_LIMIT_WINDOW', '60')) -DEFAULT_IP_RATE_LIMIT_REQUESTS = int(os.environ.get('IP_RATE_LIMIT_REQUESTS', '1000')) -DEFAULT_IP_RATE_LIMIT_WINDOW = int(os.environ.get('IP_RATE_LIMIT_WINDOW', '60')) +DEFAULT_RATE_LIMIT_REQUESTS = int(os.environ.get("RATE_LIMIT_REQUESTS", "100")) +DEFAULT_RATE_LIMIT_WINDOW = int(os.environ.get("RATE_LIMIT_WINDOW", "60")) +DEFAULT_IP_RATE_LIMIT_REQUESTS = int(os.environ.get("IP_RATE_LIMIT_REQUESTS", "1000")) +DEFAULT_IP_RATE_LIMIT_WINDOW = int(os.environ.get("IP_RATE_LIMIT_WINDOW", "60")) # TLS configuration # Set TLS_CERT_FILE and TLS_KEY_FILE to enable HTTPS # If not set, server runs in HTTP mode -TLS_CERT_FILE = os.environ.get('TLS_CERT_FILE', '') -TLS_KEY_FILE = os.environ.get('TLS_KEY_FILE', '') +TLS_CERT_FILE = os.environ.get("TLS_CERT_FILE", "") +TLS_KEY_FILE = os.environ.get("TLS_KEY_FILE", "") TLS_ENABLED = bool(TLS_CERT_FILE and TLS_KEY_FILE) # Force HTTPS - reject non-HTTPS requests (default: true when TLS is enabled) # When behind a trusted proxy, checks X-Forwarded-Proto header # Set to 'false' to allow HTTP (not recommended for production) -_force_https_default = 'true' if TLS_ENABLED else 'false' -FORCE_HTTPS = os.environ.get('FORCE_HTTPS', _force_https_default).lower() in ('true', '1', 'yes') +_force_https_default = "true" if TLS_ENABLED else "false" +FORCE_HTTPS = os.environ.get("FORCE_HTTPS", _force_https_default).lower() in ("true", "1", "yes") diff --git a/auth-proxy/key_manager.py b/auth-proxy/key_manager.py index 69f893a..d5a1e99 100644 --- a/auth-proxy/key_manager.py +++ b/auth-proxy/key_manager.py @@ -9,24 +9,24 @@ - Description/owner info """ -import json import hashlib +import json +import os import secrets -import time import threading -import os +import time from dataclasses import dataclass -from typing import Optional @dataclass class APIKey: """Represents an API key with metadata.""" + key_id: str key_hash: str # We store hash, not plaintext created_at: float - expires_at: Optional[float] = None - rate_limit: Optional[int] = None + expires_at: float | None = None + rate_limit: int | None = None description: str = "" enabled: bool = True @@ -34,9 +34,7 @@ def is_valid(self) -> bool: """Check if key is valid (enabled and not expired).""" if not self.enabled: return False - if self.expires_at and time.time() > self.expires_at: - return False - return True + return not (self.expires_at and time.time() > self.expires_at) class KeyManager: @@ -56,12 +54,12 @@ def _hash_key(self, key: str) -> str: def _load_keys_from_env(self) -> None: """Load keys from environment variables (for cloud deployment).""" # API_KEYS env var: comma-separated list of keys - env_keys = os.environ.get('API_KEYS', '') + env_keys = os.environ.get("API_KEYS", "") if not env_keys: return with self._lock: - for i, key in enumerate(env_keys.split(',')): + for i, key in enumerate(env_keys.split(",")): key = key.strip() if not key: continue @@ -71,7 +69,7 @@ def _load_keys_from_env(self) -> None: key_hash=key_hash, created_at=time.time(), description="From API_KEYS environment variable", - enabled=True + enabled=True, ) self.keys[key_hash] = api_key @@ -94,28 +92,28 @@ def _load_keys(self) -> None: if mtime <= self._last_modified: return - with open(self.keys_file, 'r') as f: + with open(self.keys_file) as f: data = json.load(f) with self._lock: self.keys.clear() - for key_data in data.get('keys', []): + for key_data in data.get("keys", []): # Support both hashed and plaintext keys in config - if 'key_hash' in key_data: - key_hash = key_data['key_hash'] - elif 'key' in key_data: - key_hash = self._hash_key(key_data['key']) + if "key_hash" in key_data: + key_hash = key_data["key_hash"] + elif "key" in key_data: + key_hash = self._hash_key(key_data["key"]) else: continue api_key = APIKey( - key_id=key_data.get('key_id', key_hash[:8]), + key_id=key_data.get("key_id", key_hash[:8]), key_hash=key_hash, - created_at=key_data.get('created_at', time.time()), - expires_at=key_data.get('expires_at'), - rate_limit=key_data.get('rate_limit'), - description=key_data.get('description', ''), - enabled=key_data.get('enabled', True) + created_at=key_data.get("created_at", time.time()), + expires_at=key_data.get("expires_at"), + rate_limit=key_data.get("rate_limit"), + description=key_data.get("description", ""), + enabled=key_data.get("enabled", True), ) self.keys[key_hash] = api_key @@ -135,7 +133,7 @@ def reload_if_changed(self) -> None: except Exception as e: print(f"Error checking keys file: {e}") - def validate_key(self, key: str) -> tuple[bool, Optional[APIKey]]: + def validate_key(self, key: str) -> tuple[bool, APIKey | None]: """ Validate an API key. Returns (is_valid, api_key_obj or None). @@ -149,17 +147,17 @@ def validate_key(self, key: str) -> tuple[bool, Optional[APIKey]]: return True, api_key return False, None - def get_key_info(self, key: str) -> Optional[dict]: + def get_key_info(self, key: str) -> dict | None: """Get non-sensitive info about a key.""" valid, api_key = self.validate_key(key) if not valid or not api_key: return None return { - 'key_id': api_key.key_id, - 'created_at': api_key.created_at, - 'expires_at': api_key.expires_at, - 'rate_limit': api_key.rate_limit, - 'description': api_key.description + "key_id": api_key.key_id, + "created_at": api_key.created_at, + "expires_at": api_key.expires_at, + "rate_limit": api_key.rate_limit, + "description": api_key.description, } @@ -172,27 +170,27 @@ def generate_api_key(prefix: str = "pm") -> str: def create_key_entry( key: str, description: str = "", - expires_in_days: Optional[int] = None, - rate_limit: Optional[int] = None, - store_hash_only: bool = True + expires_in_days: int | None = None, + rate_limit: int | None = None, + store_hash_only: bool = True, ) -> dict: """Create a key entry for the keys file.""" entry = { - 'key_id': f"key_{secrets.token_hex(4)}", - 'created_at': time.time(), - 'description': description, - 'enabled': True + "key_id": f"key_{secrets.token_hex(4)}", + "created_at": time.time(), + "description": description, + "enabled": True, } if store_hash_only: - entry['key_hash'] = hashlib.sha256(key.encode()).hexdigest() + entry["key_hash"] = hashlib.sha256(key.encode()).hexdigest() else: - entry['key'] = key + entry["key"] = key if expires_in_days: - entry['expires_at'] = time.time() + (expires_in_days * 86400) + entry["expires_at"] = time.time() + (expires_in_days * 86400) if rate_limit: - entry['rate_limit'] = rate_limit + entry["rate_limit"] = rate_limit return entry diff --git a/auth-proxy/server.py b/auth-proxy/server.py index 0d93318..53a9ef9 100644 --- a/auth-proxy/server.py +++ b/auth-proxy/server.py @@ -9,21 +9,31 @@ - Key rotation support via hot-reload """ +import json import ssl import time -import json -import asyncio from collections import defaultdict -from aiohttp import web, ClientSession, ClientTimeout -from key_manager import KeyManager -from admin import setup_admin_routes, load_settings -from usage_tracker import get_tracker + +from aiohttp import ClientSession, ClientTimeout, web + +from admin import load_settings, setup_admin_routes from config import ( - API_KEYS_FILE, UPSTREAM_URL, PORT, PRIVATEMODE_API_KEY, - DEFAULT_RATE_LIMIT_REQUESTS, DEFAULT_RATE_LIMIT_WINDOW, - DEFAULT_IP_RATE_LIMIT_REQUESTS, DEFAULT_IP_RATE_LIMIT_WINDOW, - TLS_ENABLED, TLS_CERT_FILE, TLS_KEY_FILE, FORCE_HTTPS, TRUST_PROXY + API_KEYS_FILE, + DEFAULT_IP_RATE_LIMIT_REQUESTS, + DEFAULT_IP_RATE_LIMIT_WINDOW, + DEFAULT_RATE_LIMIT_REQUESTS, + DEFAULT_RATE_LIMIT_WINDOW, + FORCE_HTTPS, + PORT, + PRIVATEMODE_API_KEY, + TLS_CERT_FILE, + TLS_ENABLED, + TLS_KEY_FILE, + TRUST_PROXY, + UPSTREAM_URL, ) +from key_manager import KeyManager +from usage_tracker import get_tracker from utils import get_client_ip @@ -31,12 +41,13 @@ def get_rate_limit_settings() -> dict: """Get current rate limit settings from settings file or defaults.""" settings = load_settings() return { - 'rate_limit_requests': settings.get('rate_limit_requests', DEFAULT_RATE_LIMIT_REQUESTS), - 'rate_limit_window': settings.get('rate_limit_window', DEFAULT_RATE_LIMIT_WINDOW), - 'ip_rate_limit_requests': settings.get('ip_rate_limit_requests', DEFAULT_IP_RATE_LIMIT_REQUESTS), - 'ip_rate_limit_window': settings.get('ip_rate_limit_window', DEFAULT_IP_RATE_LIMIT_WINDOW), + "rate_limit_requests": settings.get("rate_limit_requests", DEFAULT_RATE_LIMIT_REQUESTS), + "rate_limit_window": settings.get("rate_limit_window", DEFAULT_RATE_LIMIT_WINDOW), + "ip_rate_limit_requests": settings.get("ip_rate_limit_requests", DEFAULT_IP_RATE_LIMIT_REQUESTS), + "ip_rate_limit_window": settings.get("ip_rate_limit_window", DEFAULT_IP_RATE_LIMIT_WINDOW), } + # Rate limiting storage: key_id -> list of request timestamps rate_limit_store: dict[str, list[float]] = defaultdict(list) @@ -50,12 +61,12 @@ def get_rate_limit_settings() -> dict: def extract_api_key(request: web.Request) -> str | None: """Extract API key from request headers.""" # Check Authorization header (Bearer token) - auth_header = request.headers.get('Authorization', '') - if auth_header.startswith('Bearer '): + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): return auth_header[7:] # Check X-API-Key header - api_key = request.headers.get('X-API-Key') + api_key = request.headers.get("X-API-Key") if api_key: return api_key @@ -69,8 +80,8 @@ def check_global_rate_limit() -> tuple[bool, int, int, int]: """ global global_rate_limit_store settings = get_rate_limit_settings() - limit = settings['rate_limit_requests'] - window = settings['rate_limit_window'] + limit = settings["rate_limit_requests"] + window = settings["rate_limit_window"] now = time.time() window_start = now - window @@ -99,15 +110,13 @@ def check_per_key_rate_limit(key_id: str, limit: int | None) -> tuple[bool, int, return True, -1, 0 settings = get_rate_limit_settings() - window = settings['rate_limit_window'] + window = settings["rate_limit_window"] now = time.time() window_start = now - window # Clean old entries - rate_limit_store[key_id] = [ - ts for ts in rate_limit_store[key_id] if ts > window_start - ] + rate_limit_store[key_id] = [ts for ts in rate_limit_store[key_id] if ts > window_start] current_count = len(rate_limit_store[key_id]) remaining = max(0, limit - current_count) @@ -122,8 +131,8 @@ def check_per_key_rate_limit(key_id: str, limit: int | None) -> tuple[bool, int, def check_ip_rate_limit(ip: str) -> tuple[bool, int, int, int]: """Check if IP is within global rate limit. Returns (allowed, remaining, limit, window).""" settings = get_rate_limit_settings() - ip_limit = settings['ip_rate_limit_requests'] - ip_window = settings['ip_rate_limit_window'] + ip_limit = settings["ip_rate_limit_requests"] + ip_window = settings["ip_rate_limit_window"] now = time.time() window_start = now - ip_window @@ -143,24 +152,24 @@ def check_ip_rate_limit(ip: str) -> tuple[bool, int, int, int]: def detect_endpoint_type(path: str) -> str: """Detect the API endpoint type from path.""" - if '/chat/completions' in path: - return 'chat' - elif '/embeddings' in path: - return 'embeddings' - elif '/audio/transcriptions' in path: - return 'transcriptions' - elif '/completions' in path: - return 'completions' - return 'other' + if "/chat/completions" in path: + return "chat" + elif "/embeddings" in path: + return "embeddings" + elif "/audio/transcriptions" in path: + return "transcriptions" + elif "/completions" in path: + return "completions" + return "other" def extract_model_from_request(body: bytes) -> str: """Extract model name from request body.""" try: data = json.loads(body) - return data.get('model', 'unknown') + return data.get("model", "unknown") except (json.JSONDecodeError, UnicodeDecodeError): - return 'unknown' + return "unknown" def extract_usage_from_response(response_body: bytes, endpoint: str) -> dict: @@ -168,31 +177,26 @@ def extract_usage_from_response(response_body: bytes, endpoint: str) -> dict: Extract token usage from response body. Only extracts usage metadata, never the actual content. """ - usage = { - 'prompt_tokens': 0, - 'completion_tokens': 0, - 'total_tokens': 0, - 'model': 'unknown' - } + usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "model": "unknown"} try: data = json.loads(response_body) # Get model from response - usage['model'] = data.get('model', 'unknown') + usage["model"] = data.get("model", "unknown") # Chat completions and completions have usage object - if 'usage' in data: - usage_data = data['usage'] - usage['prompt_tokens'] = usage_data.get('prompt_tokens', 0) - usage['completion_tokens'] = usage_data.get('completion_tokens', 0) - usage['total_tokens'] = usage_data.get('total_tokens', 0) + if "usage" in data: + usage_data = data["usage"] + usage["prompt_tokens"] = usage_data.get("prompt_tokens", 0) + usage["completion_tokens"] = usage_data.get("completion_tokens", 0) + usage["total_tokens"] = usage_data.get("total_tokens", 0) # Embeddings also have usage - elif endpoint == 'embeddings' and 'usage' in data: - usage_data = data['usage'] - usage['total_tokens'] = usage_data.get('total_tokens', 0) - usage['prompt_tokens'] = usage_data.get('prompt_tokens', usage['total_tokens']) + elif endpoint == "embeddings" and "usage" in data: + usage_data = data["usage"] + usage["total_tokens"] = usage_data.get("total_tokens", 0) + usage["prompt_tokens"] = usage_data.get("prompt_tokens", usage["total_tokens"]) except (json.JSONDecodeError, UnicodeDecodeError, KeyError): pass @@ -210,40 +214,42 @@ async def proxy_request(request: web.Request, session: ClientSession) -> web.Res # Forward headers (excluding hop-by-hop headers) headers = {} - hop_by_hop = {'connection', 'keep-alive', 'proxy-authenticate', - 'proxy-authorization', 'te', 'trailers', 'transfer-encoding', - 'upgrade', 'host'} + hop_by_hop = { + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + "host", + } for key, value in request.headers.items(): - if key.lower() not in hop_by_hop: - # Don't forward our auth headers to upstream - if key.lower() not in ('authorization', 'x-api-key'): - headers[key] = value + if key.lower() not in hop_by_hop and key.lower() not in ("authorization", "x-api-key"): + headers[key] = value # Add Privatemode API key for upstream authentication if PRIVATEMODE_API_KEY: - headers['Authorization'] = f'Bearer {PRIVATEMODE_API_KEY}' + headers["Authorization"] = f"Bearer {PRIVATEMODE_API_KEY}" # Read request body body = await request.read() # Detect endpoint type and model for usage tracking endpoint = detect_endpoint_type(path) - request_model = extract_model_from_request(body) if body else 'unknown' + request_model = extract_model_from_request(body) if body else "unknown" # Track audio file size for transcription requests audio_bytes = 0 - if endpoint == 'transcriptions': + if endpoint == "transcriptions": audio_bytes = len(body) try: # Forward request to upstream async with session.request( - method=request.method, - url=upstream_url, - headers=headers, - data=body, - allow_redirects=False + method=request.method, url=upstream_url, headers=headers, data=body, allow_redirects=False ) as upstream_response: # Read response response_body = await upstream_response.read() @@ -255,38 +261,28 @@ async def proxy_request(request: web.Request, session: ClientSession) -> web.Res response_headers[key] = value # Track usage if request was successful and we have a key_id - if upstream_response.status == 200 and 'key_id' in request: + if upstream_response.status == 200 and "key_id" in request: usage = extract_usage_from_response(response_body, endpoint) - model = usage['model'] if usage['model'] != 'unknown' else request_model + model = usage["model"] if usage["model"] != "unknown" else request_model tracker = get_tracker() tracker.record_usage( - key_id=request['key_id'], + key_id=request["key_id"], model=model, endpoint=endpoint, - prompt_tokens=usage['prompt_tokens'], - completion_tokens=usage['completion_tokens'], - total_tokens=usage['total_tokens'], - audio_bytes=audio_bytes + prompt_tokens=usage["prompt_tokens"], + completion_tokens=usage["completion_tokens"], + total_tokens=usage["total_tokens"], + audio_bytes=audio_bytes, ) - return web.Response( - status=upstream_response.status, - headers=response_headers, - body=response_body - ) + return web.Response(status=upstream_response.status, headers=response_headers, body=response_body) - except asyncio.TimeoutError: - return web.json_response( - {'error': 'Upstream timeout'}, - status=504 - ) + except TimeoutError: + return web.json_response({"error": "Upstream timeout"}, status=504) except Exception as e: print(f"Proxy error: {e}") - return web.json_response( - {'error': 'Proxy error'}, - status=502 - ) + return web.json_response({"error": "Proxy error"}, status=502) @web.middleware @@ -300,18 +296,15 @@ async def https_enforcement_middleware(request: web.Request, handler): # If behind a trusted proxy, also check X-Forwarded-Proto if not is_https and TRUST_PROXY: - forwarded_proto = request.headers.get('X-Forwarded-Proto', '') - is_https = forwarded_proto.lower() == 'https' + forwarded_proto = request.headers.get("X-Forwarded-Proto", "") + is_https = forwarded_proto.lower() == "https" # Allow health checks over HTTP for internal load balancer probes - if request.path == '/health': + if request.path == "/health": return await handler(request) if not is_https: - return web.json_response( - {'error': 'HTTPS required. Please use a secure connection.'}, - status=403 - ) + return web.json_response({"error": "HTTPS required. Please use a secure connection."}, status=403) return await handler(request) @@ -322,8 +315,8 @@ async def security_headers_middleware(request: web.Request, handler): response = await handler(request) # Content Security Policy - restrictive for admin UI - if request.path.startswith('/admin'): - response.headers['Content-Security-Policy'] = ( + if request.path.startswith("/admin"): + response.headers["Content-Security-Policy"] = ( "default-src 'self'; " "script-src 'self' 'unsafe-inline'; " "style-src 'self' 'unsafe-inline'; " @@ -333,27 +326,23 @@ async def security_headers_middleware(request: web.Request, handler): ) # Prevent MIME type sniffing - response.headers['X-Content-Type-Options'] = 'nosniff' + response.headers["X-Content-Type-Options"] = "nosniff" # Prevent clickjacking - response.headers['X-Frame-Options'] = 'DENY' + response.headers["X-Frame-Options"] = "DENY" # XSS protection (legacy, but still useful for older browsers) - response.headers['X-XSS-Protection'] = '1; mode=block' + response.headers["X-XSS-Protection"] = "1; mode=block" # Referrer policy - response.headers['Referrer-Policy'] = 'strict-origin-when-cross-origin' + response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" # Permissions policy (restrict sensitive APIs) - response.headers['Permissions-Policy'] = ( - 'geolocation=(), microphone=(), camera=()' - ) + response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()" # HSTS - enforce HTTPS for 1 year when TLS is enabled if TLS_ENABLED or FORCE_HTTPS: - response.headers['Strict-Transport-Security'] = ( - 'max-age=31536000; includeSubDomains' - ) + response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains" return response @@ -364,49 +353,45 @@ def create_auth_middleware(key_manager: KeyManager): @web.middleware async def auth_middleware(request: web.Request, handler): # Skip auth for health check and admin routes - if request.path == '/health' or request.path.startswith('/admin'): + if request.path == "/health" or request.path.startswith("/admin"): return await handler(request) # Check global IP rate limit first client_ip = get_client_ip(request) - ip_allowed, ip_remaining, ip_limit, ip_window = check_ip_rate_limit(client_ip) + ip_allowed, _ip_remaining, ip_limit, ip_window = check_ip_rate_limit(client_ip) if not ip_allowed: return web.json_response( - {'error': 'Global rate limit exceeded'}, + {"error": "Global rate limit exceeded"}, status=429, headers={ - 'X-RateLimit-Limit': str(ip_limit), - 'X-RateLimit-Remaining': '0', - 'X-RateLimit-Reset': str(int(time.time()) + ip_window) - } + "X-RateLimit-Limit": str(ip_limit), + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(int(time.time()) + ip_window), + }, ) # Extract and validate API key api_key = extract_api_key(request) if not api_key: return web.json_response( - {'error': 'Missing API key. Use Authorization: Bearer or X-API-Key header'}, - status=401 + {"error": "Missing API key. Use Authorization: Bearer or X-API-Key header"}, status=401 ) valid, key_obj = key_manager.validate_key(api_key) if not valid: - return web.json_response( - {'error': 'Invalid or expired API key'}, - status=401 - ) + return web.json_response({"error": "Invalid or expired API key"}, status=401) # Check global rate limit first (shared across ALL keys) global_allowed, global_remaining, global_limit, global_window = check_global_rate_limit() if not global_allowed: return web.json_response( - {'error': 'Global rate limit exceeded'}, + {"error": "Global rate limit exceeded"}, status=429, headers={ - 'X-RateLimit-Limit': str(global_limit), - 'X-RateLimit-Remaining': '0', - 'X-RateLimit-Reset': str(int(time.time()) + global_window) - } + "X-RateLimit-Limit": str(global_limit), + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(int(time.time()) + global_window), + }, ) # Check per-key rate limit (only if key has a specific limit set) @@ -415,24 +400,24 @@ async def auth_middleware(request: web.Request, handler): if not key_allowed: return web.json_response( - {'error': 'Per-key rate limit exceeded'}, + {"error": "Per-key rate limit exceeded"}, status=429, headers={ - 'X-RateLimit-Limit': str(key_limit), - 'X-RateLimit-Remaining': '0', - 'X-RateLimit-Reset': str(int(time.time()) + global_window) - } + "X-RateLimit-Limit": str(key_limit), + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(int(time.time()) + global_window), + }, ) # Add rate limit headers to request for handler # Show per-key limit if set, otherwise show global if per_key_limit: - request['rate_limit_remaining'] = key_remaining - request['rate_limit_limit'] = key_limit + request["rate_limit_remaining"] = key_remaining + request["rate_limit_limit"] = key_limit else: - request['rate_limit_remaining'] = global_remaining - request['rate_limit_limit'] = global_limit - request['key_id'] = key_obj.key_id + request["rate_limit_remaining"] = global_remaining + request["rate_limit_limit"] = global_limit + request["key_id"] = key_obj.key_id return await handler(request) @@ -441,34 +426,34 @@ async def auth_middleware(request: web.Request, handler): async def health_handler(request: web.Request) -> web.Response: """Health check endpoint.""" - return web.json_response({'status': 'healthy'}) + return web.json_response({"status": "healthy"}) async def key_info_handler(request: web.Request) -> web.Response: """Return info about the current API key (non-sensitive).""" - key_manager = request.app['key_manager'] + key_manager = request.app["key_manager"] api_key = extract_api_key(request) if not api_key: - return web.json_response({'error': 'No API key provided'}, status=401) + return web.json_response({"error": "No API key provided"}, status=401) info = key_manager.get_key_info(api_key) if not info: - return web.json_response({'error': 'Invalid API key'}, status=401) + return web.json_response({"error": "Invalid API key"}, status=401) return web.json_response(info) async def catch_all_handler(request: web.Request) -> web.Response: """Proxy all other requests to upstream.""" - session = request.app['client_session'] + session = request.app["client_session"] response = await proxy_request(request, session) # Add rate limit headers - if 'rate_limit_remaining' in request: - response.headers['X-RateLimit-Remaining'] = str(request['rate_limit_remaining']) - response.headers['X-RateLimit-Limit'] = str(request.get('rate_limit_limit', 100)) + if "rate_limit_remaining" in request: + response.headers["X-RateLimit-Remaining"] = str(request["rate_limit_remaining"]) + response.headers["X-RateLimit-Limit"] = str(request.get("rate_limit_limit", 100)) return response @@ -476,13 +461,13 @@ async def catch_all_handler(request: web.Request) -> web.Response: async def on_startup(app: web.Application): """Initialize client session on startup.""" timeout = ClientTimeout(total=300) # 5 minute timeout for LLM requests - app['client_session'] = ClientSession(timeout=timeout) + app["client_session"] = ClientSession(timeout=timeout) print(f"Auth proxy started, forwarding to {UPSTREAM_URL}") async def on_cleanup(app: web.Application): """Cleanup client session and flush usage data.""" - await app['client_session'].close() + await app["client_session"].close() # Flush usage data to disk get_tracker().flush() @@ -492,23 +477,19 @@ def create_app() -> web.Application: key_manager = KeyManager(API_KEYS_FILE) # Middleware order: HTTPS enforcement -> security headers -> auth - middlewares = [ - https_enforcement_middleware, - security_headers_middleware, - create_auth_middleware(key_manager) - ] + middlewares = [https_enforcement_middleware, security_headers_middleware, create_auth_middleware(key_manager)] app = web.Application(middlewares=middlewares) - app['key_manager'] = key_manager + app["key_manager"] = key_manager # Routes - app.router.add_get('/health', health_handler) - app.router.add_get('/auth/key-info', key_info_handler) + app.router.add_get("/health", health_handler) + app.router.add_get("/auth/key-info", key_info_handler) # Admin UI routes setup_admin_routes(app) # Catch-all for proxying (must be last) - app.router.add_route('*', '/{path_info:.*}', catch_all_handler) + app.router.add_route("*", "/{path_info:.*}", catch_all_handler) app.on_startup.append(on_startup) app.on_cleanup.append(on_cleanup) @@ -528,7 +509,7 @@ def create_ssl_context(): return ssl_context -if __name__ == '__main__': +if __name__ == "__main__": app = create_app() ssl_context = create_ssl_context() @@ -537,4 +518,4 @@ def create_ssl_context(): else: print(f"Starting HTTP server on port {PORT} (TLS not configured)") - web.run_app(app, host='0.0.0.0', port=PORT, ssl_context=ssl_context) + web.run_app(app, host="0.0.0.0", port=PORT, ssl_context=ssl_context) diff --git a/auth-proxy/usage_tracker.py b/auth-proxy/usage_tracker.py index 7aaa9c8..14595ba 100644 --- a/auth-proxy/usage_tracker.py +++ b/auth-proxy/usage_tracker.py @@ -5,39 +5,39 @@ Calculates costs based on Privatemode pricing. """ -import os import json +import os +import threading import time +from collections import defaultdict +from dataclasses import asdict, dataclass from datetime import datetime, timedelta from pathlib import Path -from dataclasses import dataclass, asdict -from typing import Optional -from collections import defaultdict -import threading from config import USAGE_FILE # Privatemode pricing (EUR per unit) PRICING = { # Chat models: EUR per 1M tokens - 'gpt-oss-120b': {'type': 'token', 'rate': 5.0, 'per': 1_000_000}, - 'llama-3.3-70b': {'type': 'token', 'rate': 5.0, 'per': 1_000_000}, - 'gemma-3-27b': {'type': 'token', 'rate': 5.0, 'per': 1_000_000}, - 'qwen3-coder-30b-a3b': {'type': 'token', 'rate': 5.0, 'per': 1_000_000}, + "gpt-oss-120b": {"type": "token", "rate": 5.0, "per": 1_000_000}, + "llama-3.3-70b": {"type": "token", "rate": 5.0, "per": 1_000_000}, + "gemma-3-27b": {"type": "token", "rate": 5.0, "per": 1_000_000}, + "qwen3-coder-30b-a3b": {"type": "token", "rate": 5.0, "per": 1_000_000}, # Embedding models: EUR per 1M tokens - 'multilingual-e5': {'type': 'token', 'rate': 0.13, 'per': 1_000_000}, - 'qwen3-embedding-4b': {'type': 'token', 'rate': 0.13, 'per': 1_000_000}, + "multilingual-e5": {"type": "token", "rate": 0.13, "per": 1_000_000}, + "qwen3-embedding-4b": {"type": "token", "rate": 0.13, "per": 1_000_000}, # Speech-to-text: EUR per MB - 'whisper-large-v3': {'type': 'audio', 'rate': 0.096, 'per': 1}, # per MB + "whisper-large-v3": {"type": "audio", "rate": 0.096, "per": 1}, # per MB } # Default pricing for unknown models -DEFAULT_PRICING = {'type': 'token', 'rate': 5.0, 'per': 1_000_000} +DEFAULT_PRICING = {"type": "token", "rate": 5.0, "per": 1_000_000} @dataclass class UsageRecord: """Single usage record.""" + timestamp: float key_id: str model: str @@ -52,7 +52,7 @@ class UsageRecord: class UsageTracker: """Thread-safe usage tracker with file persistence.""" - def __init__(self, usage_file: str = None): + def __init__(self, usage_file: str | None = None): self.usage_file = usage_file or USAGE_FILE self._lock = threading.Lock() self._records: list[dict] = [] @@ -62,32 +62,32 @@ def _load(self): """Load usage data from file.""" try: if os.path.exists(self.usage_file): - with open(self.usage_file, 'r') as f: + with open(self.usage_file) as f: data = json.load(f) - self._records = data.get('records', []) - except (json.JSONDecodeError, IOError): + self._records = data.get("records", []) + except (OSError, json.JSONDecodeError): self._records = [] def _save(self): """Save usage data to file.""" try: Path(self.usage_file).parent.mkdir(parents=True, exist_ok=True) - with open(self.usage_file, 'w') as f: - json.dump({'records': self._records}, f) - except IOError as e: + with open(self.usage_file, "w") as f: + json.dump({"records": self._records}, f) + except OSError as e: print(f"Failed to save usage data: {e}") def calculate_cost(self, model: str, tokens: int = 0, audio_bytes: int = 0) -> float: """Calculate cost in EUR based on model and usage.""" pricing = PRICING.get(model, DEFAULT_PRICING) - if pricing['type'] == 'audio': + if pricing["type"] == "audio": # Audio: cost per MB mb = audio_bytes / (1024 * 1024) - return mb * pricing['rate'] + return mb * pricing["rate"] else: # Tokens: cost per million tokens - return (tokens / pricing['per']) * pricing['rate'] + return (tokens / pricing["per"]) * pricing["rate"] def record_usage( self, @@ -97,7 +97,7 @@ def record_usage( prompt_tokens: int = 0, completion_tokens: int = 0, total_tokens: int = 0, - audio_bytes: int = 0 + audio_bytes: int = 0, ): """Record a usage event.""" if total_tokens == 0 and prompt_tokens + completion_tokens > 0: @@ -114,7 +114,7 @@ def record_usage( completion_tokens=completion_tokens, total_tokens=total_tokens, audio_bytes=audio_bytes, - cost_eur=cost + cost_eur=cost, ) with self._lock: @@ -129,10 +129,7 @@ def flush(self): self._save() def get_usage_summary( - self, - key_id: Optional[str] = None, - start_time: Optional[float] = None, - end_time: Optional[float] = None + self, key_id: str | None = None, start_time: float | None = None, end_time: float | None = None ) -> dict: """ Get usage summary for a key or all keys within a time range. @@ -152,51 +149,47 @@ def get_usage_summary( # Filter records filtered = [] for r in records: - if key_id and r['key_id'] != key_id: + if key_id and r["key_id"] != key_id: continue - if start_time and r['timestamp'] < start_time: + if start_time and r["timestamp"] < start_time: continue - if end_time and r['timestamp'] > end_time: + if end_time and r["timestamp"] > end_time: continue filtered.append(r) # Aggregate summary = { - 'total_tokens': 0, - 'total_audio_bytes': 0, - 'total_cost_eur': 0.0, - 'by_model': defaultdict(lambda: {'tokens': 0, 'audio_bytes': 0, 'cost': 0.0, 'requests': 0}), - 'by_endpoint': defaultdict(lambda: {'tokens': 0, 'requests': 0, 'cost': 0.0}), - 'requests': len(filtered) + "total_tokens": 0, + "total_audio_bytes": 0, + "total_cost_eur": 0.0, + "by_model": defaultdict(lambda: {"tokens": 0, "audio_bytes": 0, "cost": 0.0, "requests": 0}), + "by_endpoint": defaultdict(lambda: {"tokens": 0, "requests": 0, "cost": 0.0}), + "requests": len(filtered), } for r in filtered: - summary['total_tokens'] += r['total_tokens'] - summary['total_audio_bytes'] += r['audio_bytes'] - summary['total_cost_eur'] += r['cost_eur'] + summary["total_tokens"] += r["total_tokens"] + summary["total_audio_bytes"] += r["audio_bytes"] + summary["total_cost_eur"] += r["cost_eur"] - model = r['model'] - summary['by_model'][model]['tokens'] += r['total_tokens'] - summary['by_model'][model]['audio_bytes'] += r['audio_bytes'] - summary['by_model'][model]['cost'] += r['cost_eur'] - summary['by_model'][model]['requests'] += 1 + model = r["model"] + summary["by_model"][model]["tokens"] += r["total_tokens"] + summary["by_model"][model]["audio_bytes"] += r["audio_bytes"] + summary["by_model"][model]["cost"] += r["cost_eur"] + summary["by_model"][model]["requests"] += 1 - endpoint = r['endpoint'] - summary['by_endpoint'][endpoint]['tokens'] += r['total_tokens'] - summary['by_endpoint'][endpoint]['requests'] += 1 - summary['by_endpoint'][endpoint]['cost'] += r['cost_eur'] + endpoint = r["endpoint"] + summary["by_endpoint"][endpoint]["tokens"] += r["total_tokens"] + summary["by_endpoint"][endpoint]["requests"] += 1 + summary["by_endpoint"][endpoint]["cost"] += r["cost_eur"] # Convert defaultdicts to regular dicts - summary['by_model'] = dict(summary['by_model']) - summary['by_endpoint'] = dict(summary['by_endpoint']) + summary["by_model"] = dict(summary["by_model"]) + summary["by_endpoint"] = dict(summary["by_endpoint"]) return summary - def get_usage_by_key( - self, - start_time: Optional[float] = None, - end_time: Optional[float] = None - ) -> dict[str, dict]: + def get_usage_by_key(self, start_time: float | None = None, end_time: float | None = None) -> dict[str, dict]: """Get usage breakdown by key.""" with self._lock: records = self._records.copy() @@ -204,34 +197,25 @@ def get_usage_by_key( # Filter by time filtered = [] for r in records: - if start_time and r['timestamp'] < start_time: + if start_time and r["timestamp"] < start_time: continue - if end_time and r['timestamp'] > end_time: + if end_time and r["timestamp"] > end_time: continue filtered.append(r) # Group by key - by_key = defaultdict(lambda: { - 'tokens': 0, - 'audio_bytes': 0, - 'cost_eur': 0.0, - 'requests': 0 - }) + by_key = defaultdict(lambda: {"tokens": 0, "audio_bytes": 0, "cost_eur": 0.0, "requests": 0}) for r in filtered: - key_id = r['key_id'] - by_key[key_id]['tokens'] += r['total_tokens'] - by_key[key_id]['audio_bytes'] += r['audio_bytes'] - by_key[key_id]['cost_eur'] += r['cost_eur'] - by_key[key_id]['requests'] += 1 + key_id = r["key_id"] + by_key[key_id]["tokens"] += r["total_tokens"] + by_key[key_id]["audio_bytes"] += r["audio_bytes"] + by_key[key_id]["cost_eur"] += r["cost_eur"] + by_key[key_id]["requests"] += 1 return dict(by_key) - def get_daily_breakdown( - self, - key_id: Optional[str] = None, - days: int = 30 - ) -> list[dict]: + def get_daily_breakdown(self, key_id: str | None = None, days: int = 30) -> list[dict]: """Get daily usage breakdown for the last N days.""" with self._lock: records = self._records.copy() @@ -244,28 +228,25 @@ def get_daily_breakdown( # Filter records filtered = [] for r in records: - if r['timestamp'] < start_time: + if r["timestamp"] < start_time: continue - if key_id and r['key_id'] != key_id: + if key_id and r["key_id"] != key_id: continue filtered.append(r) # Group by day - daily = defaultdict(lambda: {'tokens': 0, 'cost_eur': 0.0, 'requests': 0}) + daily = defaultdict(lambda: {"tokens": 0, "cost_eur": 0.0, "requests": 0}) for r in filtered: - day = datetime.fromtimestamp(r['timestamp']).strftime('%Y-%m-%d') - daily[day]['tokens'] += r['total_tokens'] - daily[day]['cost_eur'] += r['cost_eur'] - daily[day]['requests'] += 1 + day = datetime.fromtimestamp(r["timestamp"]).strftime("%Y-%m-%d") + daily[day]["tokens"] += r["total_tokens"] + daily[day]["cost_eur"] += r["cost_eur"] + daily[day]["requests"] += 1 # Convert to sorted list result = [] for date_str in sorted(daily.keys()): - result.append({ - 'date': date_str, - **daily[date_str] - }) + result.append({"date": date_str, **daily[date_str]}) return result @@ -284,29 +265,29 @@ def get_time_range(period: str) -> tuple[float, float]: now = datetime.now() end = now.timestamp() - if period == 'today': + if period == "today": start = now.replace(hour=0, minute=0, second=0, microsecond=0) - elif period == 'yesterday': + elif period == "yesterday": yesterday = now - timedelta(days=1) start = yesterday.replace(hour=0, minute=0, second=0, microsecond=0) end = now.replace(hour=0, minute=0, second=0, microsecond=0).timestamp() - elif period == 'week': + elif period == "week": start = now - timedelta(days=7) - elif period == 'month': + elif period == "month": start = now - timedelta(days=30) - elif period == 'year': + elif period == "year": start = now - timedelta(days=365) - elif period == 'all': + elif period == "all": return None, None else: # Default to all time return None, None - return start.timestamp() if hasattr(start, 'timestamp') else start, end + return start.timestamp() if hasattr(start, "timestamp") else start, end # Global instance -_tracker: Optional[UsageTracker] = None +_tracker: UsageTracker | None = None def get_tracker() -> UsageTracker: diff --git a/auth-proxy/utils.py b/auth-proxy/utils.py index 0d83f7f..8713dc9 100644 --- a/auth-proxy/utils.py +++ b/auth-proxy/utils.py @@ -3,6 +3,7 @@ """ from aiohttp import web + from config import TRUST_PROXY @@ -14,7 +15,7 @@ def get_client_ip(request: web.Request) -> str: This prevents IP spoofing when not behind a trusted reverse proxy. """ if TRUST_PROXY: - forwarded = request.headers.get('X-Forwarded-For', '') + forwarded = request.headers.get("X-Forwarded-For", "") if forwarded: - return forwarded.split(',')[0].strip() - return request.remote or 'unknown' + return forwarded.split(",")[0].strip() + return request.remote or "unknown" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..efe6fe2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,41 @@ +[tool.ruff] +target-version = "py312" +line-length = 120 + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "F", # pyflakes + "W", # pycodestyle warnings + "I", # isort + "UP", # pyupgrade + "B", # flake8-bugbear + "SIM", # flake8-simplify + "RUF", # ruff-specific rules +] +ignore = [ + "B904", # allow raise without from in some cases + "SIM108", # allow if-else blocks for readability + "RUF001", # allow unicode characters (e.g. × symbol in admin UI) +] + +[tool.ruff.lint.per-file-ignores] +"auth-proxy/admin.py" = [ + "E501", # HTML template strings are inherently long + "SIM103", # inline conditions would reduce readability in HTML templates + "SIM105", # contextlib.suppress less readable in HTML-heavy file +] + +[tool.ruff.lint.isort] +known-first-party = ["key_manager", "admin", "usage_tracker", "config", "utils"] + +[tool.mypy] +python_version = "3.12" +warn_return_any = true +warn_unused_configs = true +check_untyped_defs = true +no_strict_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +strict_equality = true diff --git a/scripts/manage_keys.py b/scripts/manage_keys.py index f5e7d40..9e51d09 100755 --- a/scripts/manage_keys.py +++ b/scripts/manage_keys.py @@ -45,7 +45,7 @@ def load_keys() -> dict: def save_keys(data: dict) -> None: """Save keys to file.""" KEYS_FILE.parent.mkdir(parents=True, exist_ok=True) - with open(KEYS_FILE, 'w') as f: + with open(KEYS_FILE, "w") as f: json.dump(data, f, indent=2) print(f"Keys saved to {KEYS_FILE}") @@ -82,7 +82,7 @@ def cmd_generate(args): "key_hash": key_hash, "created_at": time.time(), "description": args.description or "", - "enabled": True + "enabled": True, } if args.expires_days: @@ -193,7 +193,7 @@ def cmd_rotate(args): "created_at": time.time(), "description": old_key.get("description", "") + " (rotated)", "enabled": True, - "rotated_from": args.key_id + "rotated_from": args.key_id, } if args.expires_days: @@ -242,7 +242,7 @@ def main(): parser = argparse.ArgumentParser( description="Manage API keys for Privatemode proxy", formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=__doc__ + epilog=__doc__, ) subparsers = parser.add_subparsers(dest="command", required=True) diff --git a/scripts/scrape_docs.py b/scripts/scrape_docs.py index 8fb2d44..188cf9b 100644 --- a/scripts/scrape_docs.py +++ b/scripts/scrape_docs.py @@ -6,11 +6,11 @@ import os import re from pathlib import Path +from urllib.parse import urljoin, urlparse import requests from bs4 import BeautifulSoup from markdownify import markdownify as md -from urllib.parse import urljoin, urlparse BASE_URL = "https://docs.privatemode.ai" DOCS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "docs") @@ -40,7 +40,7 @@ def get_page(url: str) -> BeautifulSoup | None: print(f"Fetching: {full_url}") response = requests.get(full_url, timeout=10) response.raise_for_status() - return BeautifulSoup(response.text, 'html.parser') + return BeautifulSoup(response.text, "html.parser") except Exception as e: print(f"Error fetching {url}: {e}") return None @@ -50,45 +50,52 @@ def extract_nav_links(soup: BeautifulSoup) -> list[str]: """Extract navigation links from the page.""" links = [] # Look for sidebar navigation - for a in soup.find_all('a', href=True): - href = a['href'] - if href.startswith('/') and not href.startswith('//'): - if not any(x in href for x in ['#', 'mailto:', 'javascript:']): - links.append(href) + for a in soup.find_all("a", href=True): + href = a["href"] + if ( + href.startswith("/") + and not href.startswith("//") + and not any(x in href for x in ["#", "mailto:", "javascript:"]) + ): + links.append(href) return list(set(links)) def extract_content(soup: BeautifulSoup) -> tuple[str, str]: """Extract main content from the page.""" # Try to find main content area - main = soup.find('main') or soup.find('article') or soup.find(class_=re.compile(r'content|docs|markdown')) + main = soup.find("main") or soup.find("article") or soup.find(class_=re.compile(r"content|docs|markdown")) if not main: # Fallback: try to find the largest div with text - main = soup.find('body') + main = soup.find("body") if not main: return "", "" # Get title title = "" - h1 = main.find('h1') + h1 = main.find("h1") if h1: title = h1.get_text(strip=True) else: - title_tag = soup.find('title') + title_tag = soup.find("title") if title_tag: - title = title_tag.get_text(strip=True).split('|')[0].strip() + title = title_tag.get_text(strip=True).split("|")[0].strip() # Remove navigation, footer, etc. - for tag in main.find_all(['nav', 'footer', 'header', 'script', 'style']): + for tag in main.find_all(["nav", "footer", "header", "script", "style"]): tag.decompose() # Convert to markdown - content = md(str(main), heading_style="ATX", code_language_callback=lambda el: "python" if "python" in str(el).lower() else "bash") + content = md( + str(main), + heading_style="ATX", + code_language_callback=lambda el: "python" if "python" in str(el).lower() else "bash", + ) # Clean up markdown - content = re.sub(r'\n{3,}', '\n\n', content) + content = re.sub(r"\n{3,}", "\n\n", content) content = content.strip() return title, content @@ -96,14 +103,14 @@ def extract_content(soup: BeautifulSoup) -> tuple[str, str]: def url_to_filename(url: str) -> str: """Convert URL to filename safely, preventing path traversal attacks.""" - path = urlparse(url).path.strip('/') + path = urlparse(url).path.strip("/") if not path: return "index.md" # Replace slashes with underscores - name = path.replace('/', '_') + name = path.replace("/", "_") # Remove any path traversal attempts and dangerous characters # Only allow alphanumeric, underscore, and hyphen (block backslashes too) - name = re.sub(r'[^a-zA-Z0-9_-]', '', name) + name = re.sub(r"[^a-zA-Z0-9_-]", "", name) if not name: return "index.md" return f"{name}.md" @@ -138,25 +145,25 @@ def scrape_all(): try: filepath = safe_join_path(DOCS_DIR, filename) - with open(filepath, 'w') as f: + with open(filepath, "w") as f: if title: f.write(f"# {title}\n\n") f.write(content) print(f" Saved: {filename}") - pages[url] = {'title': title, 'file': filename} + pages[url] = {"title": title, "file": filename} except (ValueError, requests.RequestException) as e: print(f" Skipping {url}: {e}") continue # Discover more links for link in extract_nav_links(soup): - if link not in visited and urljoin(BASE_URL, link) not in visited: - if link.startswith('/') and not link.startswith('//'): - to_visit.append(link) + full_url = urljoin(BASE_URL, link) + if link not in visited and full_url not in visited and link.startswith("/") and not link.startswith("//"): + to_visit.append(link) # Create index - with open(safe_join_path(DOCS_DIR, "README.md"), 'w') as f: + with open(safe_join_path(DOCS_DIR, "README.md"), "w") as f: f.write("# Privatemode Documentation\n\n") f.write("Scraped from https://docs.privatemode.ai\n\n") f.write("## Pages\n\n") diff --git a/tests/conftest.py b/tests/conftest.py index e424cab..5675e47 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,22 +12,22 @@ import pytest # Add auth-proxy to path so imports work -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy")) # ── Environment setup (must happen before importing app modules) ── # Provide required env vars before any module-level code runs. # PBKDF2_SALT must be >= 16 bytes or admin.py raises ValueError at import time. -os.environ.setdefault('PBKDF2_SALT', 'test-salt-value-1234567890') -os.environ.setdefault('ADMIN_PASSWORD', 'test-admin-password') -os.environ.setdefault('PRIVATEMODE_API_KEY', 'pm_test_upstream_key') -os.environ.setdefault('API_KEYS_FILE', '') -os.environ.setdefault('SETTINGS_FILE', '') -os.environ.setdefault('USAGE_FILE', '') -os.environ.setdefault('UPSTREAM_URL', 'http://localhost:19999') -os.environ.setdefault('TRUST_PROXY', 'false') -os.environ.setdefault('FORCE_HTTPS', 'false') +os.environ.setdefault("PBKDF2_SALT", "test-salt-value-1234567890") +os.environ.setdefault("ADMIN_PASSWORD", "test-admin-password") +os.environ.setdefault("PRIVATEMODE_API_KEY", "pm_test_upstream_key") +os.environ.setdefault("API_KEYS_FILE", "") +os.environ.setdefault("SETTINGS_FILE", "") +os.environ.setdefault("USAGE_FILE", "") +os.environ.setdefault("UPSTREAM_URL", "http://localhost:19999") +os.environ.setdefault("TRUST_PROXY", "false") +os.environ.setdefault("FORCE_HTTPS", "false") from tests.helpers import make_keys_file @@ -63,15 +63,15 @@ def app_env(keys_file, settings_file, usage_file): Returns a dict of the env vars set. """ env = { - 'API_KEYS_FILE': keys_file, - 'SETTINGS_FILE': settings_file, - 'USAGE_FILE': usage_file, - 'ADMIN_PASSWORD': 'test-admin-password', - 'PRIVATEMODE_API_KEY': 'pm_test_upstream_key', - 'UPSTREAM_URL': 'http://localhost:19999', - 'TRUST_PROXY': 'false', - 'FORCE_HTTPS': 'false', - 'PBKDF2_SALT': 'test-salt-value-1234567890', + "API_KEYS_FILE": keys_file, + "SETTINGS_FILE": settings_file, + "USAGE_FILE": usage_file, + "ADMIN_PASSWORD": "test-admin-password", + "PRIVATEMODE_API_KEY": "pm_test_upstream_key", + "UPSTREAM_URL": "http://localhost:19999", + "TRUST_PROXY": "false", + "FORCE_HTTPS": "false", + "PBKDF2_SALT": "test-salt-value-1234567890", } with patch.dict(os.environ, env): yield env diff --git a/tests/helpers.py b/tests/helpers.py index 62fae19..c995964 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,7 +7,6 @@ import os import time - # ── Test API key constants ── TEST_API_KEY = "pm_test-key-for-unit-tests-12345" diff --git a/tests/test_admin.py b/tests/test_admin.py index 0582557..7dc5085 100644 --- a/tests/test_admin.py +++ b/tests/test_admin.py @@ -9,25 +9,25 @@ import pytest -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy")) from admin import ( - create_session, - validate_session, - delete_session, + _csrf_tokens, + _login_attempts, _sessions, check_login_rate_limit, - _login_attempts, + create_session, + delete_session, + format_timestamp, generate_csrf_token, - validate_csrf_token, - _csrf_tokens, + get_key_status, load_keys, - save_keys, load_settings, + save_keys, save_settings, - get_key_status, - format_timestamp, update_key_rate_limit, + validate_csrf_token, + validate_session, ) from tests.helpers import make_keys_file @@ -145,48 +145,48 @@ class TestKeysCRUD: def test_load_keys(self, tmp_path): keys_file = make_keys_file(tmp_path) - with patch('admin.KEYS_FILE', keys_file): + with patch("admin.KEYS_FILE", keys_file): data = load_keys() - assert 'keys' in data - assert len(data['keys']) == 4 + assert "keys" in data + assert len(data["keys"]) == 4 def test_load_missing_file(self, tmp_path): - with patch('admin.KEYS_FILE', os.path.join(str(tmp_path), 'nope.json')): + with patch("admin.KEYS_FILE", os.path.join(str(tmp_path), "nope.json")): data = load_keys() assert data == {"keys": []} def test_save_and_reload(self, tmp_path): - keys_file = os.path.join(str(tmp_path), 'keys.json') - with patch('admin.KEYS_FILE', keys_file): + keys_file = os.path.join(str(tmp_path), "keys.json") + with patch("admin.KEYS_FILE", keys_file): save_keys({"keys": [{"key_id": "x", "enabled": True}]}) data = load_keys() - assert len(data['keys']) == 1 - assert data['keys'][0]['key_id'] == 'x' + assert len(data["keys"]) == 1 + assert data["keys"][0]["key_id"] == "x" def test_update_key_rate_limit(self, tmp_path): keys_file = make_keys_file(tmp_path) - with patch('admin.KEYS_FILE', keys_file): + with patch("admin.KEYS_FILE", keys_file): result = update_key_rate_limit("test_key_1", 50) assert result is True # Verify it was saved data = load_keys() - key = next(k for k in data['keys'] if k['key_id'] == 'test_key_1') - assert key['rate_limit'] == 50 + key = next(k for k in data["keys"] if k["key_id"] == "test_key_1") + assert key["rate_limit"] == 50 def test_update_key_rate_limit_clear(self, tmp_path): keys_file = make_keys_file(tmp_path) - with patch('admin.KEYS_FILE', keys_file): + with patch("admin.KEYS_FILE", keys_file): # key_2 has rate_limit=5 update_key_rate_limit("test_key_2", None) data = load_keys() - key = next(k for k in data['keys'] if k['key_id'] == 'test_key_2') - assert 'rate_limit' not in key + key = next(k for k in data["keys"] if k["key_id"] == "test_key_2") + assert "rate_limit" not in key def test_update_nonexistent_key(self, tmp_path): keys_file = make_keys_file(tmp_path) - with patch('admin.KEYS_FILE', keys_file): + with patch("admin.KEYS_FILE", keys_file): result = update_key_rate_limit("nonexistent", 10) assert result is False @@ -195,39 +195,39 @@ class TestSettings: """Test settings load/save.""" def test_load_defaults(self, tmp_path): - with patch('admin.SETTINGS_FILE', os.path.join(str(tmp_path), 'nope.json')): + with patch("admin.SETTINGS_FILE", os.path.join(str(tmp_path), "nope.json")): settings = load_settings() - assert 'rate_limit_requests' in settings - assert 'ip_rate_limit_requests' in settings + assert "rate_limit_requests" in settings + assert "ip_rate_limit_requests" in settings def test_save_and_load(self, tmp_path): - settings_file = os.path.join(str(tmp_path), 'settings.json') - with patch('admin.SETTINGS_FILE', settings_file): - save_settings({'rate_limit_requests': 200, 'custom': 'value'}) + settings_file = os.path.join(str(tmp_path), "settings.json") + with patch("admin.SETTINGS_FILE", settings_file): + save_settings({"rate_limit_requests": 200, "custom": "value"}) settings = load_settings() - assert settings['rate_limit_requests'] == 200 - assert settings['custom'] == 'value' + assert settings["rate_limit_requests"] == 200 + assert settings["custom"] == "value" # Defaults should still be present - assert 'ip_rate_limit_requests' in settings + assert "ip_rate_limit_requests" in settings class TestHelpers: """Test admin helper functions.""" def test_get_key_status_active(self): - key = {'enabled': True} + key = {"enabled": True} status, css = get_key_status(key) assert status == "Active" assert css == "status-active" def test_get_key_status_revoked(self): - key = {'enabled': False} + key = {"enabled": False} status, css = get_key_status(key) assert status == "Revoked" assert css == "status-revoked" def test_get_key_status_expired(self): - key = {'enabled': True, 'expires_at': time.time() - 3600} + key = {"enabled": True, "expires_at": time.time() - 3600} status, css = get_key_status(key) assert status == "Expired" assert css == "status-expired" diff --git a/tests/test_auth.py b/tests/test_auth.py index b848621..141e1bf 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -9,10 +9,13 @@ import pytest -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy")) from tests.helpers import ( - TEST_API_KEY, TEST_API_KEY_2, EXPIRED_API_KEY, DISABLED_API_KEY, + DISABLED_API_KEY, + EXPIRED_API_KEY, + TEST_API_KEY, + TEST_API_KEY_2, make_keys_file, ) @@ -21,54 +24,58 @@ class TestExtractApiKey: """Test API key extraction from request headers.""" def test_bearer_token(self): - from server import extract_api_key from unittest.mock import MagicMock + from server import extract_api_key + request = MagicMock() - request.headers = {'Authorization': f'Bearer {TEST_API_KEY}'} + request.headers = {"Authorization": f"Bearer {TEST_API_KEY}"} assert extract_api_key(request) == TEST_API_KEY def test_x_api_key_header(self): - from server import extract_api_key from unittest.mock import MagicMock + from server import extract_api_key + request = MagicMock() request.headers = MagicMock() - request.headers.get = MagicMock(side_effect=lambda key, default='': ( - '' if key == 'Authorization' else - TEST_API_KEY if key == 'X-API-Key' else default - )) + request.headers.get = MagicMock( + side_effect=lambda key, default="": ( + "" if key == "Authorization" else TEST_API_KEY if key == "X-API-Key" else default + ) + ) assert extract_api_key(request) == TEST_API_KEY def test_missing_key(self): - from server import extract_api_key from unittest.mock import MagicMock + from server import extract_api_key + request = MagicMock() request.headers = MagicMock() - request.headers.get = MagicMock(side_effect=lambda key, default='': '') + request.headers.get = MagicMock(side_effect=lambda key, default="": "") assert extract_api_key(request) is None def test_bearer_prefix_only(self): - from server import extract_api_key from unittest.mock import MagicMock + from server import extract_api_key + request = MagicMock() request.headers = MagicMock() - request.headers.get = lambda key, default='': 'Bearer ' if key == 'Authorization' else None + request.headers.get = lambda key, default="": "Bearer " if key == "Authorization" else None # "Bearer " with empty key returns empty string result = extract_api_key(request) - assert result == '' + assert result == "" def test_non_bearer_auth_header(self): - from server import extract_api_key from unittest.mock import MagicMock + from server import extract_api_key + request = MagicMock() request.headers = MagicMock() - request.headers.get = lambda key, default='': ( - 'Basic dXNlcjpwYXNz' if key == 'Authorization' else None - ) + request.headers.get = lambda key, default="": "Basic dXNlcjpwYXNz" if key == "Authorization" else None # Basic auth should not be extracted as API key assert extract_api_key(request) is None @@ -103,7 +110,7 @@ def test_validate_expired_key(self, tmp_path): keys_file = make_keys_file(tmp_path) km = KeyManager(keys_file) - valid, key_obj = km.validate_key(EXPIRED_API_KEY) + valid, _key_obj = km.validate_key(EXPIRED_API_KEY) assert valid is False def test_validate_disabled_key(self, tmp_path): @@ -112,7 +119,7 @@ def test_validate_disabled_key(self, tmp_path): keys_file = make_keys_file(tmp_path) km = KeyManager(keys_file) - valid, key_obj = km.validate_key(DISABLED_API_KEY) + valid, _key_obj = km.validate_key(DISABLED_API_KEY) assert valid is False def test_key_with_rate_limit(self, tmp_path): @@ -133,9 +140,9 @@ def test_get_key_info(self, tmp_path): info = km.get_key_info(TEST_API_KEY) assert info is not None - assert info['key_id'] == "test_key_1" - assert info['description'] == "Test key 1" - assert 'key_hash' not in info # Shouldn't leak the hash + assert info["key_id"] == "test_key_1" + assert info["description"] == "Test key 1" + assert "key_hash" not in info # Shouldn't leak the hash def test_get_key_info_invalid(self, tmp_path): from key_manager import KeyManager @@ -148,6 +155,7 @@ def test_get_key_info_invalid(self, tmp_path): def test_hot_reload(self, tmp_path): import json + from key_manager import KeyManager keys_file = make_keys_file(tmp_path) @@ -229,10 +237,10 @@ def test_create_key_entry_hash_only(self): from key_manager import create_key_entry entry = create_key_entry("pm_test", description="Test", store_hash_only=True) - assert 'key_hash' in entry - assert 'key' not in entry - assert entry['description'] == "Test" - assert entry['enabled'] is True + assert "key_hash" in entry + assert "key" not in entry + assert entry["description"] == "Test" + assert entry["enabled"] is True def test_create_key_entry_with_expiry(self): from key_manager import create_key_entry @@ -240,5 +248,5 @@ def test_create_key_entry_with_expiry(self): before = time.time() entry = create_key_entry("pm_test", expires_in_days=30) expected_expiry = before + (30 * 86400) - assert entry['expires_at'] >= expected_expiry - 1 - assert entry['expires_at'] <= expected_expiry + 2 + assert entry["expires_at"] >= expected_expiry - 1 + assert entry["expires_at"] <= expected_expiry + 2 diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index ea647c2..1ecad65 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -11,7 +11,7 @@ import pytest from aiohttp import web -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy")) from tests.helpers import TEST_API_KEY, make_keys_file @@ -24,24 +24,27 @@ def proxy_app(tmp_path): usage_file = os.path.join(str(tmp_path), "usage.json") env = { - 'API_KEYS_FILE': keys_file, - 'SETTINGS_FILE': settings_file, - 'USAGE_FILE': usage_file, - 'ADMIN_PASSWORD': 'test-admin-password', - 'PRIVATEMODE_API_KEY': 'pm_test_upstream_key', - 'UPSTREAM_URL': 'http://localhost:19999', - 'TRUST_PROXY': 'false', - 'FORCE_HTTPS': 'false', - 'PBKDF2_SALT': 'test-salt-value-1234567890', + "API_KEYS_FILE": keys_file, + "SETTINGS_FILE": settings_file, + "USAGE_FILE": usage_file, + "ADMIN_PASSWORD": "test-admin-password", + "PRIVATEMODE_API_KEY": "pm_test_upstream_key", + "UPSTREAM_URL": "http://localhost:19999", + "TRUST_PROXY": "false", + "FORCE_HTTPS": "false", + "PBKDF2_SALT": "test-salt-value-1234567890", } with patch.dict(os.environ, env): # Re-import to pick up new config values import importlib + import config + importlib.reload(config) # Need to reload server module too since it imports config at module level import server + importlib.reload(server) application = server.create_app() @@ -59,15 +62,15 @@ class TestHealthEndpoint: @pytest.mark.asyncio async def test_health_returns_ok(self, client): - resp = await client.get('/health') + resp = await client.get("/health") assert resp.status == 200 data = await resp.json() - assert data['status'] == 'healthy' + assert data["status"] == "healthy" @pytest.mark.asyncio async def test_health_no_auth_required(self, client): # No auth headers at all - resp = await client.get('/health') + resp = await client.get("/health") assert resp.status == 200 @@ -76,30 +79,31 @@ class TestAuthMiddleware: @pytest.mark.asyncio async def test_missing_api_key(self, client): - resp = await client.post('/v1/chat/completions', - json={"model": "test", "messages": []}) + resp = await client.post("/v1/chat/completions", json={"model": "test", "messages": []}) assert resp.status == 401 data = await resp.json() - assert 'Missing API key' in data['error'] + assert "Missing API key" in data["error"] @pytest.mark.asyncio async def test_invalid_api_key(self, client): - resp = await client.post('/v1/chat/completions', - headers={'Authorization': 'Bearer pm_invalid_key'}, - json={"model": "test", "messages": []}) + resp = await client.post( + "/v1/chat/completions", + headers={"Authorization": "Bearer pm_invalid_key"}, + json={"model": "test", "messages": []}, + ) assert resp.status == 401 data = await resp.json() - assert 'Invalid' in data['error'] + assert "Invalid" in data["error"] @pytest.mark.asyncio async def test_valid_api_key_bearer(self, client): # This will try to proxy to upstream (which doesn't exist in test) # but should get past auth - with patch('server.proxy_request', new_callable=AsyncMock) as mock_proxy: + with patch("server.proxy_request", new_callable=AsyncMock) as mock_proxy: mock_proxy.return_value = web.json_response({"ok": True}) resp = await client.post( - '/v1/chat/completions', - headers={'Authorization': f'Bearer {TEST_API_KEY}'}, + "/v1/chat/completions", + headers={"Authorization": f"Bearer {TEST_API_KEY}"}, json={"model": "test", "messages": []}, ) # Should get past auth (200 from mocked proxy or 502 if proxy fails) @@ -107,11 +111,11 @@ async def test_valid_api_key_bearer(self, client): @pytest.mark.asyncio async def test_valid_api_key_x_header(self, client): - with patch('server.proxy_request', new_callable=AsyncMock) as mock_proxy: + with patch("server.proxy_request", new_callable=AsyncMock) as mock_proxy: mock_proxy.return_value = web.json_response({"ok": True}) resp = await client.post( - '/v1/chat/completions', - headers={'X-API-Key': TEST_API_KEY}, + "/v1/chat/completions", + headers={"X-API-Key": TEST_API_KEY}, json={"model": "test", "messages": []}, ) assert resp.status in (200, 502) @@ -122,21 +126,19 @@ class TestKeyInfoEndpoint: @pytest.mark.asyncio async def test_key_info_valid(self, client): - resp = await client.get('/auth/key-info', - headers={'Authorization': f'Bearer {TEST_API_KEY}'}) + resp = await client.get("/auth/key-info", headers={"Authorization": f"Bearer {TEST_API_KEY}"}) assert resp.status == 200 data = await resp.json() - assert data['key_id'] == 'test_key_1' + assert data["key_id"] == "test_key_1" @pytest.mark.asyncio async def test_key_info_invalid(self, client): - resp = await client.get('/auth/key-info', - headers={'Authorization': 'Bearer pm_bad'}) + resp = await client.get("/auth/key-info", headers={"Authorization": "Bearer pm_bad"}) assert resp.status == 401 @pytest.mark.asyncio async def test_key_info_no_auth(self, client): - resp = await client.get('/auth/key-info') + resp = await client.get("/auth/key-info") assert resp.status == 401 @@ -145,19 +147,19 @@ class TestSecurityHeaders: @pytest.mark.asyncio async def test_health_has_security_headers(self, client): - resp = await client.get('/health') - assert resp.headers.get('X-Content-Type-Options') == 'nosniff' - assert resp.headers.get('X-Frame-Options') == 'DENY' - assert resp.headers.get('X-XSS-Protection') == '1; mode=block' - assert 'Referrer-Policy' in resp.headers - assert 'Permissions-Policy' in resp.headers + resp = await client.get("/health") + assert resp.headers.get("X-Content-Type-Options") == "nosniff" + assert resp.headers.get("X-Frame-Options") == "DENY" + assert resp.headers.get("X-XSS-Protection") == "1; mode=block" + assert "Referrer-Policy" in resp.headers + assert "Permissions-Policy" in resp.headers @pytest.mark.asyncio async def test_admin_has_csp(self, client): # Admin login page should have CSP - resp = await client.get('/admin/login') + resp = await client.get("/admin/login") assert resp.status == 200 - csp = resp.headers.get('Content-Security-Policy', '') + csp = resp.headers.get("Content-Security-Policy", "") assert "frame-ancestors 'none'" in csp @@ -166,51 +168,61 @@ class TestAdminEndpoints: @pytest.mark.asyncio async def test_admin_redirects_to_login(self, client): - resp = await client.get('/admin', allow_redirects=False) + resp = await client.get("/admin", allow_redirects=False) assert resp.status == 302 - assert '/admin/login' in resp.headers['Location'] + assert "/admin/login" in resp.headers["Location"] @pytest.mark.asyncio async def test_login_page_loads(self, client): - resp = await client.get('/admin/login') + resp = await client.get("/admin/login") assert resp.status == 200 text = await resp.text() - assert 'password' in text.lower() - assert 'csrf_token' in text + assert "password" in text.lower() + assert "csrf_token" in text @pytest.mark.asyncio async def test_login_wrong_password(self, client): # First get a CSRF token - resp = await client.get('/admin/login') + resp = await client.get("/admin/login") text = await resp.text() # Extract CSRF token from form import re + match = re.search(r'name="csrf_token" value="([^"]+)"', text) assert match, "CSRF token not found in login form" csrf_token = match.group(1) - resp = await client.post('/admin/login', data={ - 'password': 'wrong', - 'csrf_token': csrf_token, - }, allow_redirects=False) + resp = await client.post( + "/admin/login", + data={ + "password": "wrong", + "csrf_token": csrf_token, + }, + allow_redirects=False, + ) assert resp.status == 302 - assert 'error' in resp.headers.get('Location', '') + assert "error" in resp.headers.get("Location", "") @pytest.mark.asyncio async def test_login_success(self, client): # Get CSRF token - resp = await client.get('/admin/login') + resp = await client.get("/admin/login") text = await resp.text() import re + match = re.search(r'name="csrf_token" value="([^"]+)"', text) csrf_token = match.group(1) - resp = await client.post('/admin/login', data={ - 'password': 'test-admin-password', - 'csrf_token': csrf_token, - }, allow_redirects=False) + resp = await client.post( + "/admin/login", + data={ + "password": "test-admin-password", + "csrf_token": csrf_token, + }, + allow_redirects=False, + ) assert resp.status == 302 - assert resp.headers.get('Location') == '/admin' + assert resp.headers.get("Location") == "/admin" # Should have session cookie - cookies = resp.cookies - assert 'admin_session' in {c.key for c in resp.cookies.values()} or resp.headers.get('Set-Cookie', '') + + assert "admin_session" in {c.key for c in resp.cookies.values()} or resp.headers.get("Set-Cookie", "") diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 0bfeb32..ccf85b6 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -6,8 +6,7 @@ import os import sys - -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy")) from server import ( detect_endpoint_type, @@ -68,41 +67,37 @@ def test_chat_completion_response(self): "prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30, - } + }, } body = json.dumps(response).encode() usage = extract_usage_from_response(body, "chat") - assert usage['model'] == "gpt-oss-120b" - assert usage['prompt_tokens'] == 10 - assert usage['completion_tokens'] == 20 - assert usage['total_tokens'] == 30 + assert usage["model"] == "gpt-oss-120b" + assert usage["prompt_tokens"] == 10 + assert usage["completion_tokens"] == 20 + assert usage["total_tokens"] == 30 def test_embedding_response(self): - response = { - "model": "qwen3-embedding-4b", - "data": [{"embedding": [0.1, 0.2]}], - "usage": {"total_tokens": 50} - } + response = {"model": "qwen3-embedding-4b", "data": [{"embedding": [0.1, 0.2]}], "usage": {"total_tokens": 50}} body = json.dumps(response).encode() usage = extract_usage_from_response(body, "embeddings") - assert usage['model'] == "qwen3-embedding-4b" - assert usage['total_tokens'] == 50 + assert usage["model"] == "qwen3-embedding-4b" + assert usage["total_tokens"] == 50 def test_no_usage_field(self): response = {"model": "gpt-oss-120b", "choices": []} body = json.dumps(response).encode() usage = extract_usage_from_response(body, "chat") - assert usage['prompt_tokens'] == 0 - assert usage['total_tokens'] == 0 + assert usage["prompt_tokens"] == 0 + assert usage["total_tokens"] == 0 def test_invalid_response_body(self): usage = extract_usage_from_response(b"not json", "chat") - assert usage['model'] == 'unknown' - assert usage['total_tokens'] == 0 + assert usage["model"] == "unknown" + assert usage["total_tokens"] == 0 def test_empty_response(self): usage = extract_usage_from_response(b"", "chat") - assert usage['model'] == 'unknown' + assert usage["model"] == "unknown" diff --git a/tests/test_rate_limiting.py b/tests/test_rate_limiting.py index 98320a9..e4e7950 100644 --- a/tests/test_rate_limiting.py +++ b/tests/test_rate_limiting.py @@ -9,15 +9,15 @@ import pytest -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy")) from server import ( check_global_rate_limit, - check_per_key_rate_limit, check_ip_rate_limit, + check_per_key_rate_limit, global_rate_limit_store, - rate_limit_store, ip_rate_limit_store, + rate_limit_store, ) @@ -37,12 +37,15 @@ class TestGlobalRateLimit: """Test global rate limiting (shared across all keys).""" def test_allows_within_limit(self): - with patch('server.get_rate_limit_settings', return_value={ - 'rate_limit_requests': 10, - 'rate_limit_window': 60, - 'ip_rate_limit_requests': 1000, - 'ip_rate_limit_window': 60, - }): + with patch( + "server.get_rate_limit_settings", + return_value={ + "rate_limit_requests": 10, + "rate_limit_window": 60, + "ip_rate_limit_requests": 1000, + "ip_rate_limit_window": 60, + }, + ): allowed, remaining, limit, window = check_global_rate_limit() assert allowed is True assert remaining == 9 # limit(10) - count(0) - 1 = 9, then append @@ -50,41 +53,48 @@ def test_allows_within_limit(self): assert window == 60 def test_blocks_at_limit(self): - with patch('server.get_rate_limit_settings', return_value={ - 'rate_limit_requests': 3, - 'rate_limit_window': 60, - 'ip_rate_limit_requests': 1000, - 'ip_rate_limit_window': 60, - }): + with patch( + "server.get_rate_limit_settings", + return_value={ + "rate_limit_requests": 3, + "rate_limit_window": 60, + "ip_rate_limit_requests": 1000, + "ip_rate_limit_window": 60, + }, + ): # Use up the limit check_global_rate_limit() # 1 check_global_rate_limit() # 2 check_global_rate_limit() # 3 # Should be blocked now - allowed, remaining, limit, window = check_global_rate_limit() + allowed, remaining, _limit, _window = check_global_rate_limit() assert allowed is False assert remaining == 0 def test_cleans_old_entries(self): # The global store is reassigned in the function, so we need to test # by calling through the function itself - with patch('server.get_rate_limit_settings', return_value={ - 'rate_limit_requests': 3, - 'rate_limit_window': 60, - 'ip_rate_limit_requests': 1000, - 'ip_rate_limit_window': 60, - }): + with patch( + "server.get_rate_limit_settings", + return_value={ + "rate_limit_requests": 3, + "rate_limit_window": 60, + "ip_rate_limit_requests": 1000, + "ip_rate_limit_window": 60, + }, + ): # Use up 2 of 3 check_global_rate_limit() check_global_rate_limit() # Manually age the entries by modifying the store import server + server.global_rate_limit_store[:] = [time.time() - 120, time.time() - 120] # Should be allowed again since old entries get cleaned - allowed, remaining, limit, window = check_global_rate_limit() + allowed, _remaining, _limit, _window = check_global_rate_limit() assert allowed is True @@ -93,44 +103,53 @@ class TestPerKeyRateLimit: def test_no_limit_set(self): """Keys without a limit should always be allowed.""" - allowed, remaining, limit = check_per_key_rate_limit("key1", None) + allowed, remaining, _limit = check_per_key_rate_limit("key1", None) assert allowed is True assert remaining == -1 # Indicates no limit def test_allows_within_limit(self): - with patch('server.get_rate_limit_settings', return_value={ - 'rate_limit_requests': 100, - 'rate_limit_window': 60, - 'ip_rate_limit_requests': 1000, - 'ip_rate_limit_window': 60, - }): + with patch( + "server.get_rate_limit_settings", + return_value={ + "rate_limit_requests": 100, + "rate_limit_window": 60, + "ip_rate_limit_requests": 1000, + "ip_rate_limit_window": 60, + }, + ): allowed, remaining, limit = check_per_key_rate_limit("key1", 5) assert allowed is True assert remaining == 4 # limit(5) - count(0) - 1 = 4 assert limit == 5 def test_blocks_at_limit(self): - with patch('server.get_rate_limit_settings', return_value={ - 'rate_limit_requests': 100, - 'rate_limit_window': 60, - 'ip_rate_limit_requests': 1000, - 'ip_rate_limit_window': 60, - }): + with patch( + "server.get_rate_limit_settings", + return_value={ + "rate_limit_requests": 100, + "rate_limit_window": 60, + "ip_rate_limit_requests": 1000, + "ip_rate_limit_window": 60, + }, + ): for _ in range(5): check_per_key_rate_limit("key1", 5) - allowed, remaining, limit = check_per_key_rate_limit("key1", 5) + allowed, remaining, _limit = check_per_key_rate_limit("key1", 5) assert allowed is False assert remaining == 0 def test_separate_key_stores(self): """Rate limits should be independent per key.""" - with patch('server.get_rate_limit_settings', return_value={ - 'rate_limit_requests': 100, - 'rate_limit_window': 60, - 'ip_rate_limit_requests': 1000, - 'ip_rate_limit_window': 60, - }): + with patch( + "server.get_rate_limit_settings", + return_value={ + "rate_limit_requests": 100, + "rate_limit_window": 60, + "ip_rate_limit_requests": 1000, + "ip_rate_limit_window": 60, + }, + ): # Exhaust key1 for _ in range(3): check_per_key_rate_limit("key1", 3) @@ -146,38 +165,47 @@ class TestIPRateLimit: """Test IP-based rate limiting.""" def test_allows_within_limit(self): - with patch('server.get_rate_limit_settings', return_value={ - 'rate_limit_requests': 100, - 'rate_limit_window': 60, - 'ip_rate_limit_requests': 10, - 'ip_rate_limit_window': 60, - }): - allowed, remaining, limit, window = check_ip_rate_limit("192.168.1.1") + with patch( + "server.get_rate_limit_settings", + return_value={ + "rate_limit_requests": 100, + "rate_limit_window": 60, + "ip_rate_limit_requests": 10, + "ip_rate_limit_window": 60, + }, + ): + allowed, _remaining, limit, _window = check_ip_rate_limit("192.168.1.1") assert allowed is True assert limit == 10 def test_blocks_at_limit(self): - with patch('server.get_rate_limit_settings', return_value={ - 'rate_limit_requests': 100, - 'rate_limit_window': 60, - 'ip_rate_limit_requests': 3, - 'ip_rate_limit_window': 60, - }): + with patch( + "server.get_rate_limit_settings", + return_value={ + "rate_limit_requests": 100, + "rate_limit_window": 60, + "ip_rate_limit_requests": 3, + "ip_rate_limit_window": 60, + }, + ): for _ in range(3): check_ip_rate_limit("192.168.1.1") - allowed, remaining, limit, window = check_ip_rate_limit("192.168.1.1") + allowed, remaining, _limit, _window = check_ip_rate_limit("192.168.1.1") assert allowed is False assert remaining == 0 def test_separate_ip_stores(self): """Different IPs should have separate rate limit buckets.""" - with patch('server.get_rate_limit_settings', return_value={ - 'rate_limit_requests': 100, - 'rate_limit_window': 60, - 'ip_rate_limit_requests': 2, - 'ip_rate_limit_window': 60, - }): + with patch( + "server.get_rate_limit_settings", + return_value={ + "rate_limit_requests": 100, + "rate_limit_window": 60, + "ip_rate_limit_requests": 2, + "ip_rate_limit_window": 60, + }, + ): # Exhaust IP 1 check_ip_rate_limit("10.0.0.1") check_ip_rate_limit("10.0.0.1") diff --git a/tests/test_usage_tracker.py b/tests/test_usage_tracker.py index eba79ff..8246a28 100644 --- a/tests/test_usage_tracker.py +++ b/tests/test_usage_tracker.py @@ -8,7 +8,7 @@ import pytest -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy")) from usage_tracker import UsageTracker, get_time_range @@ -65,8 +65,8 @@ def test_record_basic_usage(self, tmp_path): ) summary = tracker.get_usage_summary() - assert summary['total_tokens'] == 30 - assert summary['requests'] == 1 + assert summary["total_tokens"] == 30 + assert summary["requests"] == 1 def test_auto_calculates_total_tokens(self, tmp_path): usage_file = os.path.join(str(tmp_path), "usage.json") @@ -82,38 +82,42 @@ def test_auto_calculates_total_tokens(self, tmp_path): ) summary = tracker.get_usage_summary() - assert summary['total_tokens'] == 30 + assert summary["total_tokens"] == 30 def test_persistence(self, tmp_path): usage_file = os.path.join(str(tmp_path), "usage.json") tracker1 = UsageTracker(usage_file) # Record enough to trigger auto-save (every 10 records) - for i in range(10): + for _i in range(10): tracker1.record_usage( - key_id="key1", model="gpt-oss-120b", endpoint="chat", + key_id="key1", + model="gpt-oss-120b", + endpoint="chat", total_tokens=100, ) # Create a new tracker from the same file tracker2 = UsageTracker(usage_file) summary = tracker2.get_usage_summary() - assert summary['requests'] == 10 - assert summary['total_tokens'] == 1000 + assert summary["requests"] == 10 + assert summary["total_tokens"] == 1000 def test_flush(self, tmp_path): usage_file = os.path.join(str(tmp_path), "usage.json") tracker = UsageTracker(usage_file) tracker.record_usage( - key_id="key1", model="gpt-oss-120b", endpoint="chat", + key_id="key1", + model="gpt-oss-120b", + endpoint="chat", total_tokens=100, ) tracker.flush() # Verify file exists and has data tracker2 = UsageTracker(usage_file) - assert tracker2.get_usage_summary()['requests'] == 1 + assert tracker2.get_usage_summary()["requests"] == 1 class TestUsageSummary: @@ -127,8 +131,8 @@ def test_filter_by_key(self, tmp_path): tracker.record_usage(key_id="key2", model="gpt-oss-120b", endpoint="chat", total_tokens=200) summary = tracker.get_usage_summary(key_id="key1") - assert summary['total_tokens'] == 100 - assert summary['requests'] == 1 + assert summary["total_tokens"] == 100 + assert summary["requests"] == 1 def test_filter_by_time(self, tmp_path): usage_file = os.path.join(str(tmp_path), "usage.json") @@ -140,7 +144,7 @@ def test_filter_by_time(self, tmp_path): # Filter to future time (should exclude current records) future = time.time() + 3600 summary = tracker.get_usage_summary(start_time=future) - assert summary['requests'] == 0 + assert summary["requests"] == 0 def test_by_model_breakdown(self, tmp_path): usage_file = os.path.join(str(tmp_path), "usage.json") @@ -150,10 +154,10 @@ def test_by_model_breakdown(self, tmp_path): tracker.record_usage(key_id="key1", model="gemma-3-27b", endpoint="chat", total_tokens=200) summary = tracker.get_usage_summary() - assert "gpt-oss-120b" in summary['by_model'] - assert "gemma-3-27b" in summary['by_model'] - assert summary['by_model']['gpt-oss-120b']['tokens'] == 100 - assert summary['by_model']['gemma-3-27b']['tokens'] == 200 + assert "gpt-oss-120b" in summary["by_model"] + assert "gemma-3-27b" in summary["by_model"] + assert summary["by_model"]["gpt-oss-120b"]["tokens"] == 100 + assert summary["by_model"]["gemma-3-27b"]["tokens"] == 200 def test_by_endpoint_breakdown(self, tmp_path): usage_file = os.path.join(str(tmp_path), "usage.json") @@ -163,8 +167,8 @@ def test_by_endpoint_breakdown(self, tmp_path): tracker.record_usage(key_id="key1", model="qwen3-embedding-4b", endpoint="embeddings", total_tokens=50) summary = tracker.get_usage_summary() - assert summary['by_endpoint']['chat']['requests'] == 1 - assert summary['by_endpoint']['embeddings']['requests'] == 1 + assert summary["by_endpoint"]["chat"]["requests"] == 1 + assert summary["by_endpoint"]["embeddings"]["requests"] == 1 def test_usage_by_key(self, tmp_path): usage_file = os.path.join(str(tmp_path), "usage.json") @@ -175,9 +179,9 @@ def test_usage_by_key(self, tmp_path): tracker.record_usage(key_id="key2", model="gpt-oss-120b", endpoint="chat", total_tokens=50) by_key = tracker.get_usage_by_key() - assert by_key['key1']['tokens'] == 300 - assert by_key['key1']['requests'] == 2 - assert by_key['key2']['tokens'] == 50 + assert by_key["key1"]["tokens"] == 300 + assert by_key["key1"]["requests"] == 2 + assert by_key["key2"]["tokens"] == 50 def test_daily_breakdown(self, tmp_path): usage_file = os.path.join(str(tmp_path), "usage.json") @@ -187,7 +191,7 @@ def test_daily_breakdown(self, tmp_path): daily = tracker.get_daily_breakdown(days=1) assert len(daily) >= 1 - assert daily[0]['tokens'] == 100 + assert daily[0]["tokens"] == 100 class TestTimeRanges: diff --git a/tests/test_utils.py b/tests/test_utils.py index 928d5f9..5b76321 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,8 +6,7 @@ import sys from unittest.mock import MagicMock, patch - -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy")) from utils import get_client_ip @@ -19,60 +18,52 @@ def test_direct_connection(self): request = MagicMock() request.remote = "192.168.1.1" request.headers = MagicMock() - request.headers.get = lambda key, default='': '' + request.headers.get = lambda key, default="": "" - with patch('utils.TRUST_PROXY', False): + with patch("utils.TRUST_PROXY", False): assert get_client_ip(request) == "192.168.1.1" def test_ignores_forwarded_when_untrusted(self): request = MagicMock() request.remote = "10.0.0.1" request.headers = MagicMock() - request.headers.get = lambda key, default='': ( - "1.2.3.4, 5.6.7.8" if key == 'X-Forwarded-For' else '' - ) + request.headers.get = lambda key, default="": "1.2.3.4, 5.6.7.8" if key == "X-Forwarded-For" else "" - with patch('utils.TRUST_PROXY', False): + with patch("utils.TRUST_PROXY", False): assert get_client_ip(request) == "10.0.0.1" def test_uses_forwarded_when_trusted(self): request = MagicMock() request.remote = "10.0.0.1" request.headers = MagicMock() - request.headers.get = lambda key, default='': ( - "1.2.3.4, 5.6.7.8" if key == 'X-Forwarded-For' else '' - ) + request.headers.get = lambda key, default="": "1.2.3.4, 5.6.7.8" if key == "X-Forwarded-For" else "" - with patch('utils.TRUST_PROXY', True): + with patch("utils.TRUST_PROXY", True): assert get_client_ip(request) == "1.2.3.4" def test_single_forwarded_ip(self): request = MagicMock() request.remote = "10.0.0.1" request.headers = MagicMock() - request.headers.get = lambda key, default='': ( - "203.0.113.50" if key == 'X-Forwarded-For' else '' - ) + request.headers.get = lambda key, default="": "203.0.113.50" if key == "X-Forwarded-For" else "" - with patch('utils.TRUST_PROXY', True): + with patch("utils.TRUST_PROXY", True): assert get_client_ip(request) == "203.0.113.50" def test_no_remote_returns_unknown(self): request = MagicMock() request.remote = None request.headers = MagicMock() - request.headers.get = lambda key, default='': '' + request.headers.get = lambda key, default="": "" - with patch('utils.TRUST_PROXY', False): + with patch("utils.TRUST_PROXY", False): assert get_client_ip(request) == "unknown" def test_forwarded_with_spaces(self): request = MagicMock() request.remote = "10.0.0.1" request.headers = MagicMock() - request.headers.get = lambda key, default='': ( - " 1.2.3.4 , 5.6.7.8 " if key == 'X-Forwarded-For' else '' - ) + request.headers.get = lambda key, default="": " 1.2.3.4 , 5.6.7.8 " if key == "X-Forwarded-For" else "" - with patch('utils.TRUST_PROXY', True): + with patch("utils.TRUST_PROXY", True): assert get_client_ip(request) == "1.2.3.4"