diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml
new file mode 100644
index 0000000..6eda8bd
--- /dev/null
+++ b/.github/workflows/test-lint.yml
@@ -0,0 +1,120 @@
+name: Test & Lint
+
+on:
+ push:
+ branches: ["main"]
+ pull_request:
+ branches: ["main"]
+ workflow_dispatch:
+
+permissions:
+ contents: read
+
+jobs:
+ test:
+ name: Python Tests
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v6
+
+ - uses: actions/setup-python@v5
+ with:
+ python-version: '3.12'
+ cache: 'pip'
+ cache-dependency-path: |
+ auth-proxy/requirements.txt
+ requirements-test.txt
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r auth-proxy/requirements.txt
+ pip install -r requirements-test.txt
+
+ - name: Run tests
+ working-directory: .
+ run: |
+ python -m pytest tests/ -v --tb=short
+
+ lint:
+ name: Ruff Lint & Format
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v6
+
+ - uses: actions/setup-python@v5
+ with:
+ python-version: '3.12'
+ cache: 'pip'
+
+ - name: Install ruff
+ run: pip install ruff
+
+ - name: Run ruff check
+ run: ruff check auth-proxy/ scripts/ tests/
+
+ - name: Run ruff format check
+ run: ruff format --check auth-proxy/ scripts/ tests/
+
+ typecheck:
+ name: Mypy Type Check
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v6
+
+ - uses: actions/setup-python@v5
+ with:
+ python-version: '3.12'
+ cache: 'pip'
+ cache-dependency-path: |
+ auth-proxy/requirements.txt
+ requirements-test.txt
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r auth-proxy/requirements.txt
+ pip install -r requirements-test.txt
+ pip install mypy
+
+ - name: Run mypy
+ run: |
+ echo "Mypy is advisory — errors are reported but do not fail CI."
+ echo "Fix errors incrementally as you touch files."
+ mypy auth-proxy/ scripts/ --ignore-missing-imports --check-untyped-defs --no-strict-optional || true
+
+ docker:
+ name: Docker Build Check
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v6
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
+ - name: Build Docker image
+ uses: docker/build-push-action@v6
+ with:
+ context: .
+ push: false
+ load: true
+ tags: privatemode-proxy:test
+ cache-from: type=gha
+ cache-to: type=gha,mode=max
+
+ - name: Verify container starts
+ run: |
+ docker run --rm -d --name test-container \
+ -e API_KEYS_FILE=/app/secrets/keys.json \
+ -e UPSTREAM_URL=http://localhost:8081 \
+ -e PORT=8080 \
+ -e PRIVATEMODE_API_KEY=test-key \
+ -e FORCE_HTTPS=false \
+ privatemode-proxy:test
+ sleep 5
+ docker logs test-container
+ docker stop test-container
diff --git a/auth-proxy/admin.py b/auth-proxy/admin.py
index 46b4238..8501a44 100644
--- a/auth-proxy/admin.py
+++ b/auth-proxy/admin.py
@@ -11,24 +11,32 @@
Protected by admin password.
"""
-import os
-import json
-import time
+import base64
+import contextlib
import hashlib
+import json
+import os
import secrets
-import base64
+import time
from collections import defaultdict
-from html import escape
from datetime import datetime
+from html import escape
from pathlib import Path
+
from aiohttp import web
from cryptography.fernet import Fernet
-from usage_tracker import get_tracker, get_time_range
+
from config import (
- API_KEYS_FILE, SETTINGS_FILE, ADMIN_PASSWORD, PRIVATEMODE_API_KEY,
- DEFAULT_RATE_LIMIT_REQUESTS, DEFAULT_RATE_LIMIT_WINDOW,
- DEFAULT_IP_RATE_LIMIT_REQUESTS, DEFAULT_IP_RATE_LIMIT_WINDOW
+ ADMIN_PASSWORD,
+ API_KEYS_FILE,
+ DEFAULT_IP_RATE_LIMIT_REQUESTS,
+ DEFAULT_IP_RATE_LIMIT_WINDOW,
+ DEFAULT_RATE_LIMIT_REQUESTS,
+ DEFAULT_RATE_LIMIT_WINDOW,
+ PRIVATEMODE_API_KEY,
+ SETTINGS_FILE,
)
+from usage_tracker import get_time_range, get_tracker
from utils import get_client_ip
# Aliases for backwards compatibility
@@ -62,9 +70,7 @@ def validate_session(token: str, ip: str) -> bool:
del _sessions[token]
return False
# Validate IP matches the session's original IP
- if ip != session_ip:
- return False
- return True
+ return ip == session_ip
def delete_session(token: str) -> None:
@@ -83,10 +89,10 @@ def _cleanup_sessions() -> None:
def get_default_settings() -> dict:
"""Return default settings with rate limits from config."""
return {
- 'rate_limit_requests': DEFAULT_RATE_LIMIT_REQUESTS,
- 'rate_limit_window': DEFAULT_RATE_LIMIT_WINDOW,
- 'ip_rate_limit_requests': DEFAULT_IP_RATE_LIMIT_REQUESTS,
- 'ip_rate_limit_window': DEFAULT_IP_RATE_LIMIT_WINDOW,
+ "rate_limit_requests": DEFAULT_RATE_LIMIT_REQUESTS,
+ "rate_limit_window": DEFAULT_RATE_LIMIT_WINDOW,
+ "ip_rate_limit_requests": DEFAULT_IP_RATE_LIMIT_REQUESTS,
+ "ip_rate_limit_window": DEFAULT_IP_RATE_LIMIT_WINDOW,
}
@@ -100,14 +106,14 @@ def load_settings() -> dict:
saved = json.load(f)
# Merge saved settings with defaults (saved takes precedence)
return {**defaults, **saved}
- except (json.JSONDecodeError, IOError):
+ except (OSError, json.JSONDecodeError):
return defaults
def save_settings(data: dict) -> None:
"""Save settings to file."""
Path(SETTINGS_FILE).parent.mkdir(parents=True, exist_ok=True)
- with open(SETTINGS_FILE, 'w') as f:
+ with open(SETTINGS_FILE, "w") as f:
json.dump(data, f, indent=2)
@@ -115,11 +121,12 @@ def get_privatemode_key_status() -> tuple[str, str]:
"""Get status of Privatemode API key. Returns (status_text, css_class)."""
if PRIVATEMODE_API_KEY:
# Only show prefix to minimize exposure
- if PRIVATEMODE_API_KEY.startswith('pm_'):
+ if PRIVATEMODE_API_KEY.startswith("pm_"):
return "Configured (pm_****)", "status-active"
return "Configured (****)", "status-active"
return "Not configured", "status-revoked"
+
# Temporary storage for newly generated keys (in-memory, short-lived)
# Maps key_id -> (encrypted_key, timestamp)
_pending_keys: dict[str, tuple[str, float]] = {}
@@ -130,7 +137,7 @@ def get_privatemode_key_status() -> tuple[str, str]:
CSRF_TTL = 3600 # 1 hour
-PBKDF2_SALT = os.environ.get('PBKDF2_SALT', '').encode()
+PBKDF2_SALT = os.environ.get("PBKDF2_SALT", "").encode()
if len(PBKDF2_SALT) < 16:
raise ValueError("PBKDF2_SALT environment variable must be at least 16 bytes")
@@ -151,11 +158,11 @@ def _get_fernet_key() -> bytes:
return _FERNET_KEY
# Use PBKDF2 with 600,000 iterations (OWASP recommended for HMAC-SHA256)
key_material = hashlib.pbkdf2_hmac(
- 'sha256',
+ "sha256",
ADMIN_PASSWORD.encode(),
PBKDF2_SALT,
iterations=600_000,
- dklen=32 # Fernet requires 32 bytes
+ dklen=32, # Fernet requires 32 bytes
)
_FERNET_KEY = base64.urlsafe_b64encode(key_material)
return _FERNET_KEY
@@ -234,7 +241,7 @@ def check_admin_auth(request: web.Request) -> bool:
if not ADMIN_PASSWORD:
return False
- token = request.cookies.get('admin_session')
+ token = request.cookies.get("admin_session")
if not token:
return False
@@ -253,19 +260,19 @@ def load_keys() -> dict:
def save_keys(data: dict) -> None:
"""Save keys to file."""
Path(KEYS_FILE).parent.mkdir(parents=True, exist_ok=True)
- with open(KEYS_FILE, 'w') as f:
+ with open(KEYS_FILE, "w") as f:
json.dump(data, f, indent=2)
def update_key_rate_limit(key_id: str, rate_limit: int | None) -> bool:
"""Update the rate limit for a specific key. Returns True if successful."""
keys_data = load_keys()
- for key in keys_data['keys']:
- if key['key_id'] == key_id:
+ for key in keys_data["keys"]:
+ if key["key_id"] == key_id:
if rate_limit is None:
- key.pop('rate_limit', None)
+ key.pop("rate_limit", None)
else:
- key['rate_limit'] = rate_limit
+ key["rate_limit"] = rate_limit
save_keys(keys_data)
return True
return False
@@ -280,9 +287,9 @@ def format_timestamp(ts: float | None) -> str:
def get_key_status(key: dict) -> tuple[str, str]:
"""Get status and CSS class for a key."""
- if not key.get('enabled', True):
+ if not key.get("enabled", True):
return "Revoked", "status-revoked"
- expires_at = key.get('expires_at')
+ expires_at = key.get("expires_at")
if expires_at and time.time() > expires_at:
return "Expired", "status-expired"
return "Active", "status-active"
@@ -634,23 +641,18 @@ def get_key_status(key: dict) -> tuple[str, str]:
async def admin_login_page(request: web.Request) -> web.Response:
"""Show login page."""
if not ADMIN_PASSWORD:
- return web.Response(
- text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.",
- status=403
- )
+ return web.Response(text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.", status=403)
if check_admin_auth(request):
- raise web.HTTPFound('/admin')
+ raise web.HTTPFound("/admin")
- error = request.query.get('error', '')
+ error = request.query.get("error", "")
# Escape error message to prevent XSS attacks
- error_html = f'<div class="error">{escape(error)}</div>' if error else ''
+ error_html = f'<div class="error">{escape(error)}</div>' if error else ""
csrf_token = generate_csrf_token()
- html = HTML_TEMPLATE.format(
- content=LOGIN_CONTENT.format(error=error_html, csrf_token=csrf_token)
- )
- return web.Response(text=html, content_type='text/html')
+ html = HTML_TEMPLATE.format(content=LOGIN_CONTENT.format(error=error_html, csrf_token=csrf_token))
+ return web.Response(text=html, content_type="text/html")
async def admin_login_post(request: web.Request) -> web.Response:
@@ -662,14 +664,11 @@ async def admin_login_post(request: web.Request) -> web.Response:
# Check rate limit BEFORE password validation
if not check_login_rate_limit(client_ip):
- return web.Response(
- text="Too many login attempts. Please try again later.",
- status=429
- )
+ return web.Response(text="Too many login attempts. Please try again later.", status=429)
data = await request.post()
- csrf_token = data.get('csrf_token', '')
- password = data.get('password', '')
+ csrf_token = data.get("csrf_token", "")
+ password = data.get("password", "")
# Validate CSRF token
if not validate_csrf_token(csrf_token):
@@ -678,56 +677,53 @@ async def admin_login_post(request: web.Request) -> web.Response:
if secrets.compare_digest(password, ADMIN_PASSWORD):
# Create a new random session token
token = create_session(client_ip)
- response = web.HTTPFound('/admin')
+ response = web.HTTPFound("/admin")
# Detect if running behind HTTPS (Fly.io, nginx, etc.)
- is_https = request.headers.get('X-Forwarded-Proto', '').lower() == 'https'
+ is_https = request.headers.get("X-Forwarded-Proto", "").lower() == "https"
response.set_cookie(
- 'admin_session',
+ "admin_session",
token,
max_age=86400, # 24 hours
httponly=True,
- samesite='Strict',
- secure=is_https # Only send cookie over HTTPS in production
+ samesite="Strict",
+ secure=is_https, # Only send cookie over HTTPS in production
)
return response
# Record failed login attempt
check_login_rate_limit(client_ip, record_attempt=True)
- raise web.HTTPFound('/admin/login?error=Invalid password')
+ raise web.HTTPFound("/admin/login?error=Invalid password")
async def admin_logout(request: web.Request) -> web.Response:
"""Handle logout."""
# Delete the session from store
- token = request.cookies.get('admin_session')
+ token = request.cookies.get("admin_session")
if token:
delete_session(token)
- response = web.HTTPFound('/admin/login')
- response.del_cookie('admin_session')
+ response = web.HTTPFound("/admin/login")
+ response.del_cookie("admin_session")
return response
async def admin_dashboard(request: web.Request) -> web.Response:
"""Show admin dashboard."""
if not ADMIN_PASSWORD:
- return web.Response(
- text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.",
- status=403
- )
+ return web.Response(text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.", status=403)
if not check_admin_auth(request):
- raise web.HTTPFound('/admin/login')
+ raise web.HTTPFound("/admin/login")
# Check for newly generated key to display (one-time retrieval)
- new_key = ''
- show_new_key = ''
- show_key_id = request.query.get('show_key', '')
+ new_key = ""
+ show_new_key = ""
+ show_key_id = request.query.get("show_key", "")
if show_key_id and show_key_id in _pending_keys:
encrypted_key, timestamp = _pending_keys.pop(show_key_id) # Remove after retrieval
if time.time() - timestamp < PENDING_KEY_TTL:
new_key = _decrypt_key_for_display(encrypted_key)
- show_new_key = 'show'
+ show_new_key = "show"
# Load keys
data = load_keys()
@@ -736,30 +732,30 @@ async def admin_dashboard(request: web.Request) -> web.Response:
csrf_token = generate_csrf_token()
# Build keys table
- if not data['keys']:
+ if not data["keys"]:
keys_table = 'No API keys configured. Generate one above.
'
else:
rows = []
- for key in data['keys']:
+ for key in data["keys"]:
status, status_class = get_key_status(key)
- if key.get('enabled', True):
+ if key.get("enabled", True):
actions = f'''
-
'''
else:
actions = f'''
-
'''
actions += f'''
-
'''
- rows.append(KEY_ROW.format(
- key_id=escape(key['key_id']),
- description=escape(key.get('description', '-')),
- status=status,
- status_class=status_class,
- rate_limit_display=rate_limit_display,
- created=format_timestamp(key.get('created_at')),
- expires=format_timestamp(key.get('expires_at')),
- actions=actions
- ))
-
- keys_table = KEYS_TABLE.format(rows=''.join(rows))
+ rows.append(
+ KEY_ROW.format(
+ key_id=escape(key["key_id"]),
+ description=escape(key.get("description", "-")),
+ status=status,
+ status_class=status_class,
+ rate_limit_display=rate_limit_display,
+ created=format_timestamp(key.get("created_at")),
+ expires=format_timestamp(key.get("expires_at")),
+ actions=actions,
+ )
+ )
+
+ keys_table = KEYS_TABLE.format(rows="".join(rows))
# Determine base URL for usage examples
- scheme = request.headers.get('X-Forwarded-Proto', request.scheme)
- host = request.headers.get('X-Forwarded-Host', request.host)
+ scheme = request.headers.get("X-Forwarded-Proto", request.scheme)
+ host = request.headers.get("X-Forwarded-Host", request.host)
base_url = escape(f"{scheme}://{host}")
content = DASHBOARD_CONTENT.format(
- keys_table=keys_table,
- new_key=new_key,
- show_new_key=show_new_key,
- base_url=base_url,
- csrf_token=csrf_token
+ keys_table=keys_table, new_key=new_key, show_new_key=show_new_key, base_url=base_url, csrf_token=csrf_token
)
html = HTML_TEMPLATE.format(content=content)
- return web.Response(text=html, content_type='text/html')
+ return web.Response(text=html, content_type="text/html")
async def admin_generate_key(request: web.Request) -> web.Response:
"""Generate a new API key."""
if not check_admin_auth(request):
- raise web.HTTPFound('/admin/login')
+ raise web.HTTPFound("/admin/login")
data = await request.post()
- csrf_token = data.get('csrf_token', '')
+ csrf_token = data.get("csrf_token", "")
# Validate CSRF token
if not validate_csrf_token(csrf_token):
return web.Response(text="Invalid or expired CSRF token", status=403)
- description = data.get('description', '')
- expires_days = data.get('expires_days', '')
- rate_limit = data.get('rate_limit', '')
+ description = data.get("description", "")
+ expires_days = data.get("expires_days", "")
+ rate_limit = data.get("rate_limit", "")
# Generate key
new_key = f"pm_{secrets.token_urlsafe(32)}"
@@ -838,28 +832,24 @@ async def admin_generate_key(request: web.Request) -> web.Response:
key_id = f"key_{secrets.token_hex(4)}"
entry = {
- 'key_id': key_id,
- 'key_hash': key_hash,
- 'created_at': time.time(),
- 'description': description,
- 'enabled': True
+ "key_id": key_id,
+ "key_hash": key_hash,
+ "created_at": time.time(),
+ "description": description,
+ "enabled": True,
}
if expires_days:
- try:
- entry['expires_at'] = time.time() + (int(expires_days) * 86400)
- except ValueError:
- pass
+ with contextlib.suppress(ValueError):
+ entry["expires_at"] = time.time() + (int(expires_days) * 86400)
if rate_limit:
- try:
- entry['rate_limit'] = int(rate_limit)
- except ValueError:
- pass
+ with contextlib.suppress(ValueError):
+ entry["rate_limit"] = int(rate_limit)
# Save
keys_data = load_keys()
- keys_data['keys'].append(entry)
+ keys_data["keys"].append(entry)
save_keys(keys_data)
# Store encrypted key temporarily for one-time display
@@ -867,100 +857,100 @@ async def admin_generate_key(request: web.Request) -> web.Response:
_pending_keys[key_id] = (_encrypt_key_for_display(new_key), time.time())
# Redirect with just the key_id (not the actual key)
- raise web.HTTPFound(f'/admin?show_key={key_id}')
+ raise web.HTTPFound(f"/admin?show_key={key_id}")
async def admin_revoke_key(request: web.Request) -> web.Response:
"""Revoke an API key."""
if not check_admin_auth(request):
- raise web.HTTPFound('/admin/login')
+ raise web.HTTPFound("/admin/login")
data = await request.post()
- csrf_token = data.get('csrf_token', '')
+ csrf_token = data.get("csrf_token", "")
# Validate CSRF token
if not validate_csrf_token(csrf_token):
return web.Response(text="Invalid or expired CSRF token", status=403)
- key_id = request.match_info['key_id']
+ key_id = request.match_info["key_id"]
keys_data = load_keys()
- for key in keys_data['keys']:
- if key['key_id'] == key_id:
- key['enabled'] = False
- key['revoked_at'] = time.time()
+ for key in keys_data["keys"]:
+ if key["key_id"] == key_id:
+ key["enabled"] = False
+ key["revoked_at"] = time.time()
break
save_keys(keys_data)
- raise web.HTTPFound('/admin')
+ raise web.HTTPFound("/admin")
async def admin_enable_key(request: web.Request) -> web.Response:
"""Re-enable an API key."""
if not check_admin_auth(request):
- raise web.HTTPFound('/admin/login')
+ raise web.HTTPFound("/admin/login")
data = await request.post()
- csrf_token = data.get('csrf_token', '')
+ csrf_token = data.get("csrf_token", "")
# Validate CSRF token
if not validate_csrf_token(csrf_token):
return web.Response(text="Invalid or expired CSRF token", status=403)
- key_id = request.match_info['key_id']
+ key_id = request.match_info["key_id"]
keys_data = load_keys()
- for key in keys_data['keys']:
- if key['key_id'] == key_id:
- key['enabled'] = True
- if 'revoked_at' in key:
- del key['revoked_at']
+ for key in keys_data["keys"]:
+ if key["key_id"] == key_id:
+ key["enabled"] = True
+ if "revoked_at" in key:
+ del key["revoked_at"]
break
save_keys(keys_data)
- raise web.HTTPFound('/admin')
+ raise web.HTTPFound("/admin")
async def admin_delete_key(request: web.Request) -> web.Response:
"""Delete an API key permanently."""
if not check_admin_auth(request):
- raise web.HTTPFound('/admin/login')
+ raise web.HTTPFound("/admin/login")
data = await request.post()
- csrf_token = data.get('csrf_token', '')
+ csrf_token = data.get("csrf_token", "")
# Validate CSRF token
if not validate_csrf_token(csrf_token):
return web.Response(text="Invalid or expired CSRF token", status=403)
- key_id = request.match_info['key_id']
+ key_id = request.match_info["key_id"]
keys_data = load_keys()
- keys_data['keys'] = [k for k in keys_data['keys'] if k['key_id'] != key_id]
+ keys_data["keys"] = [k for k in keys_data["keys"] if k["key_id"] != key_id]
save_keys(keys_data)
- raise web.HTTPFound('/admin')
+ raise web.HTTPFound("/admin")
async def admin_update_key_rate_limit(request: web.Request) -> web.Response:
"""Update rate limit for a specific API key."""
if not check_admin_auth(request):
- raise web.HTTPFound('/admin/login')
+ raise web.HTTPFound("/admin/login")
data = await request.post()
- csrf_token = data.get('csrf_token', '')
+ csrf_token = data.get("csrf_token", "")
# Validate CSRF token
if not validate_csrf_token(csrf_token):
return web.Response(text="Invalid or expired CSRF token", status=403)
- key_id = request.match_info['key_id']
+ key_id = request.match_info["key_id"]
# Check if clearing the rate limit
- if data.get('clear'):
+ if data.get("clear"):
update_key_rate_limit(key_id, None)
else:
- rate_limit_str = data.get('rate_limit', '')
+ rate_limit_str = data.get("rate_limit", "")
if rate_limit_str:
try:
rate_limit = int(rate_limit_str)
@@ -969,16 +959,16 @@ async def admin_update_key_rate_limit(request: web.Request) -> web.Response:
except ValueError:
pass
- raise web.HTTPFound('/admin')
+ raise web.HTTPFound("/admin")
async def admin_save_rate_limits(request: web.Request) -> web.Response:
"""Save global rate limit settings."""
if not check_admin_auth(request):
- raise web.HTTPFound('/admin/login')
+ raise web.HTTPFound("/admin/login")
data = await request.post()
- csrf_token = data.get('csrf_token', '')
+ csrf_token = data.get("csrf_token", "")
# Validate CSRF token
if not validate_csrf_token(csrf_token):
@@ -989,20 +979,20 @@ async def admin_save_rate_limits(request: web.Request) -> web.Response:
# Update rate limit settings
try:
- if data.get('rate_limit_requests'):
- settings['rate_limit_requests'] = int(data['rate_limit_requests'])
- if data.get('rate_limit_window'):
- settings['rate_limit_window'] = int(data['rate_limit_window'])
- if data.get('ip_rate_limit_requests'):
- settings['ip_rate_limit_requests'] = int(data['ip_rate_limit_requests'])
- if data.get('ip_rate_limit_window'):
- settings['ip_rate_limit_window'] = int(data['ip_rate_limit_window'])
+ if data.get("rate_limit_requests"):
+ settings["rate_limit_requests"] = int(data["rate_limit_requests"])
+ if data.get("rate_limit_window"):
+ settings["rate_limit_window"] = int(data["rate_limit_window"])
+ if data.get("ip_rate_limit_requests"):
+ settings["ip_rate_limit_requests"] = int(data["ip_rate_limit_requests"])
+ if data.get("ip_rate_limit_window"):
+ settings["ip_rate_limit_window"] = int(data["ip_rate_limit_window"])
except ValueError:
pass
save_settings(settings)
- raise web.HTTPFound('/admin/settings?success=rate_limits')
+ raise web.HTTPFound("/admin/settings?success=rate_limits")
USAGE_CONTENT = """
@@ -1186,18 +1176,15 @@ async def admin_save_rate_limits(request: web.Request) -> web.Response:
async def admin_settings(request: web.Request) -> web.Response:
"""Show settings page."""
if not ADMIN_PASSWORD:
- return web.Response(
- text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.",
- status=403
- )
+ return web.Response(text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.", status=403)
if not check_admin_auth(request):
- raise web.HTTPFound('/admin/login')
+ raise web.HTTPFound("/admin/login")
# Check for success message
- success = request.query.get('success', '')
- success_message = ''
- if success == 'rate_limits':
+ success = request.query.get("success", "")
+ success_message = ""
+ if success == "rate_limits":
success_message = 'Rate limit settings saved successfully.
'
# Get Privatemode key status
@@ -1206,10 +1193,10 @@ async def admin_settings(request: web.Request) -> web.Response:
# Build message and dot color based on key status
if PRIVATEMODE_API_KEY:
pm_key_message = 'E2E encryption active
'
- pm_dot_color = '#22c55e'
+ pm_dot_color = "#22c55e"
else:
pm_key_message = 'Not configured - set PRIVATEMODE_API_KEY
'
- pm_dot_color = '#ef4444'
+ pm_dot_color = "#ef4444"
# Load current rate limit settings
settings = load_settings()
@@ -1222,48 +1209,45 @@ async def admin_settings(request: web.Request) -> web.Response:
pm_dot_color=pm_dot_color,
csrf_token=csrf_token,
success_message=success_message,
- rate_limit_requests=settings.get('rate_limit_requests', 100),
- rate_limit_window=settings.get('rate_limit_window', 60),
- ip_rate_limit_requests=settings.get('ip_rate_limit_requests', 1000),
- ip_rate_limit_window=settings.get('ip_rate_limit_window', 60),
+ rate_limit_requests=settings.get("rate_limit_requests", 100),
+ rate_limit_window=settings.get("rate_limit_window", 60),
+ ip_rate_limit_requests=settings.get("ip_rate_limit_requests", 1000),
+ ip_rate_limit_window=settings.get("ip_rate_limit_window", 60),
)
html = HTML_TEMPLATE.format(content=content)
- return web.Response(text=html, content_type='text/html')
+ return web.Response(text=html, content_type="text/html")
async def admin_usage(request: web.Request) -> web.Response:
"""Show usage dashboard."""
if not ADMIN_PASSWORD:
- return web.Response(
- text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.",
- status=403
- )
+ return web.Response(text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.", status=403)
if not check_admin_auth(request):
- raise web.HTTPFound('/admin/login')
+ raise web.HTTPFound("/admin/login")
# Get time period from query params
- period = request.query.get('period', 'month')
+ period = request.query.get("period", "month")
start_time, end_time = get_time_range(period)
# Period labels and active states
period_labels = {
- 'today': 'Today',
- 'week': 'Last 7 Days',
- 'month': 'Last 30 Days',
- 'year': 'Last Year',
- 'all': 'All Time'
+ "today": "Today",
+ "week": "Last 7 Days",
+ "month": "Last 30 Days",
+ "year": "Last Year",
+ "all": "All Time",
}
- period_label = period_labels.get(period, 'Last 30 Days')
+ period_label = period_labels.get(period, "Last 30 Days")
# Active button states
active_states = {
- 'active_today': 'btn-primary' if period == 'today' else '',
- 'active_week': 'btn-primary' if period == 'week' else '',
- 'active_month': 'btn-primary' if period == 'month' else '',
- 'active_year': 'btn-primary' if period == 'year' else '',
- 'active_all': 'btn-primary' if period == 'all' else ''
+ "active_today": "btn-primary" if period == "today" else "",
+ "active_week": "btn-primary" if period == "week" else "",
+ "active_month": "btn-primary" if period == "month" else "",
+ "active_year": "btn-primary" if period == "year" else "",
+ "active_all": "btn-primary" if period == "all" else "",
}
tracker = get_tracker()
@@ -1276,7 +1260,7 @@ async def admin_usage(request: web.Request) -> web.Response:
# Load keys to get descriptions
keys_data = load_keys()
- key_descriptions = {k['key_id']: k.get('description', '-') for k in keys_data.get('keys', [])}
+ key_descriptions = {k["key_id"]: k.get("description", "-") for k in keys_data.get("keys", [])}
# Build usage by key table
if not usage_by_key:
@@ -1284,44 +1268,45 @@ async def admin_usage(request: web.Request) -> web.Response:
else:
rows = []
# Sort by cost descending
- sorted_keys = sorted(usage_by_key.items(), key=lambda x: x[1]['cost_eur'], reverse=True)
+ sorted_keys = sorted(usage_by_key.items(), key=lambda x: x[1]["cost_eur"], reverse=True)
for key_id, data in sorted_keys:
- rows.append(USAGE_BY_KEY_ROW.format(
- description=escape(key_descriptions.get(key_id, 'Unknown Key')),
- tokens=data['tokens'],
- requests=data['requests'],
- cost=data['cost_eur']
- ))
- usage_by_key_table = USAGE_BY_KEY_TABLE.format(rows=''.join(rows))
+ rows.append(
+ USAGE_BY_KEY_ROW.format(
+ description=escape(key_descriptions.get(key_id, "Unknown Key")),
+ tokens=data["tokens"],
+ requests=data["requests"],
+ cost=data["cost_eur"],
+ )
+ )
+ usage_by_key_table = USAGE_BY_KEY_TABLE.format(rows="".join(rows))
# Build usage by model table
- if not summary['by_model']:
+ if not summary["by_model"]:
usage_by_model_table = 'No usage data for this period.
'
else:
rows = []
# Sort by cost descending
- sorted_models = sorted(summary['by_model'].items(), key=lambda x: x[1]['cost'], reverse=True)
+ sorted_models = sorted(summary["by_model"].items(), key=lambda x: x[1]["cost"], reverse=True)
for model, data in sorted_models:
- rows.append(USAGE_BY_MODEL_ROW.format(
- model=escape(model),
- tokens=data['tokens'],
- requests=data['requests'],
- cost=data['cost']
- ))
- usage_by_model_table = USAGE_BY_MODEL_TABLE.format(rows=''.join(rows))
+ rows.append(
+ USAGE_BY_MODEL_ROW.format(
+ model=escape(model), tokens=data["tokens"], requests=data["requests"], cost=data["cost"]
+ )
+ )
+ usage_by_model_table = USAGE_BY_MODEL_TABLE.format(rows="".join(rows))
content = USAGE_CONTENT.format(
period_label=period_label,
- total_cost=summary['total_cost_eur'],
- total_tokens=summary['total_tokens'],
- total_requests=summary['requests'],
+ total_cost=summary["total_cost_eur"],
+ total_tokens=summary["total_tokens"],
+ total_requests=summary["requests"],
usage_by_key_table=usage_by_key_table,
usage_by_model_table=usage_by_model_table,
- **active_states
+ **active_states,
)
html = HTML_TEMPLATE.format(content=content)
- return web.Response(text=html, content_type='text/html')
+ return web.Response(text=html, content_type="text/html")
ABOUT_CONTENT = """
@@ -1502,28 +1487,25 @@ async def admin_usage(request: web.Request) -> web.Response:
async def admin_about(request: web.Request) -> web.Response:
"""Show about page."""
if not ADMIN_PASSWORD:
- return web.Response(
- text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.",
- status=403
- )
+ return web.Response(text="Admin UI disabled. Set ADMIN_PASSWORD environment variable to enable.", status=403)
if not check_admin_auth(request):
- raise web.HTTPFound('/admin/login')
+ raise web.HTTPFound("/admin/login")
html = HTML_TEMPLATE.format(content=ABOUT_CONTENT)
- return web.Response(text=html, content_type='text/html')
+ return web.Response(text=html, content_type="text/html")
async def admin_static(request: web.Request) -> web.Response:
"""Serve static files (logo, etc.)."""
- filename = request.match_info.get('filename', '')
+ filename = request.match_info.get("filename", "")
# Security: allowlist maps filenames to (relative path, content type).
# The user-controlled value is only used as a dict key — the filesystem
# path is constructed entirely from hardcoded values, breaking taint flow.
- static_dir = Path(__file__).parent / 'static'
+ static_dir = Path(__file__).parent / "static"
allowed_files = {
- 'logo.png': (static_dir / 'logo.png', 'image/png'),
+ "logo.png": (static_dir / "logo.png", "image/png"),
}
entry = allowed_files.get(filename)
@@ -1534,23 +1516,23 @@ async def admin_static(request: web.Request) -> web.Response:
if not file_path.exists():
raise web.HTTPNotFound()
- with open(file_path, 'rb') as f:
+ with open(file_path, "rb") as f:
return web.Response(body=f.read(), content_type=content_type)
def setup_admin_routes(app: web.Application) -> None:
"""Add admin routes to the app."""
- app.router.add_get('/admin', admin_dashboard)
- app.router.add_get('/admin/settings', admin_settings)
- app.router.add_post('/admin/settings/rate-limits', admin_save_rate_limits)
- app.router.add_get('/admin/usage', admin_usage)
- app.router.add_get('/admin/about', admin_about)
- app.router.add_get('/admin/static/{filename}', admin_static)
- app.router.add_get('/admin/login', admin_login_page)
- app.router.add_post('/admin/login', admin_login_post)
- app.router.add_get('/admin/logout', admin_logout)
- app.router.add_post('/admin/keys/generate', admin_generate_key)
- app.router.add_post('/admin/keys/{key_id}/revoke', admin_revoke_key)
- app.router.add_post('/admin/keys/{key_id}/enable', admin_enable_key)
- app.router.add_post('/admin/keys/{key_id}/delete', admin_delete_key)
- app.router.add_post('/admin/keys/{key_id}/rate-limit', admin_update_key_rate_limit)
+ app.router.add_get("/admin", admin_dashboard)
+ app.router.add_get("/admin/settings", admin_settings)
+ app.router.add_post("/admin/settings/rate-limits", admin_save_rate_limits)
+ app.router.add_get("/admin/usage", admin_usage)
+ app.router.add_get("/admin/about", admin_about)
+ app.router.add_get("/admin/static/{filename}", admin_static)
+ app.router.add_get("/admin/login", admin_login_page)
+ app.router.add_post("/admin/login", admin_login_post)
+ app.router.add_get("/admin/logout", admin_logout)
+ app.router.add_post("/admin/keys/generate", admin_generate_key)
+ app.router.add_post("/admin/keys/{key_id}/revoke", admin_revoke_key)
+ app.router.add_post("/admin/keys/{key_id}/enable", admin_enable_key)
+ app.router.add_post("/admin/keys/{key_id}/delete", admin_delete_key)
+ app.router.add_post("/admin/keys/{key_id}/rate-limit", admin_update_key_rate_limit)
diff --git a/auth-proxy/config.py b/auth-proxy/config.py
index 3605a06..9472914 100644
--- a/auth-proxy/config.py
+++ b/auth-proxy/config.py
@@ -8,38 +8,38 @@
import os
# File paths
-API_KEYS_FILE = os.environ.get('API_KEYS_FILE', '/app/secrets/api_keys.json')
-SETTINGS_FILE = os.environ.get('SETTINGS_FILE', '/app/secrets/settings.json')
-USAGE_FILE = os.environ.get('USAGE_FILE', '/app/data/usage.json')
+API_KEYS_FILE = os.environ.get("API_KEYS_FILE", "/app/secrets/api_keys.json")
+SETTINGS_FILE = os.environ.get("SETTINGS_FILE", "/app/secrets/settings.json")
+USAGE_FILE = os.environ.get("USAGE_FILE", "/app/data/usage.json")
# Server configuration
# Default assumes both services run in same container (supervisord setup)
# Override with UPSTREAM_URL env var for different deployments
-UPSTREAM_URL = os.environ.get('UPSTREAM_URL', 'http://localhost:8081')
-PORT = int(os.environ.get('PORT', '8080'))
-PRIVATEMODE_API_KEY = os.environ.get('PRIVATEMODE_API_KEY', '')
-ADMIN_PASSWORD = os.environ.get('ADMIN_PASSWORD', '')
+UPSTREAM_URL = os.environ.get("UPSTREAM_URL", "http://localhost:8081")
+PORT = int(os.environ.get("PORT", "8080"))
+PRIVATEMODE_API_KEY = os.environ.get("PRIVATEMODE_API_KEY", "")
+ADMIN_PASSWORD = os.environ.get("ADMIN_PASSWORD", "")
# Trust proxy headers (X-Forwarded-For, X-Forwarded-Proto, etc.)
# Set to 'true' when running behind a trusted reverse proxy (nginx, Fly.io, etc.)
# When false, X-Forwarded-For headers are ignored to prevent IP spoofing
-TRUST_PROXY = os.environ.get('TRUST_PROXY', 'false').lower() in ('true', '1', 'yes')
+TRUST_PROXY = os.environ.get("TRUST_PROXY", "false").lower() in ("true", "1", "yes")
# Default rate limits
-DEFAULT_RATE_LIMIT_REQUESTS = int(os.environ.get('RATE_LIMIT_REQUESTS', '100'))
-DEFAULT_RATE_LIMIT_WINDOW = int(os.environ.get('RATE_LIMIT_WINDOW', '60'))
-DEFAULT_IP_RATE_LIMIT_REQUESTS = int(os.environ.get('IP_RATE_LIMIT_REQUESTS', '1000'))
-DEFAULT_IP_RATE_LIMIT_WINDOW = int(os.environ.get('IP_RATE_LIMIT_WINDOW', '60'))
+DEFAULT_RATE_LIMIT_REQUESTS = int(os.environ.get("RATE_LIMIT_REQUESTS", "100"))
+DEFAULT_RATE_LIMIT_WINDOW = int(os.environ.get("RATE_LIMIT_WINDOW", "60"))
+DEFAULT_IP_RATE_LIMIT_REQUESTS = int(os.environ.get("IP_RATE_LIMIT_REQUESTS", "1000"))
+DEFAULT_IP_RATE_LIMIT_WINDOW = int(os.environ.get("IP_RATE_LIMIT_WINDOW", "60"))
# TLS configuration
# Set TLS_CERT_FILE and TLS_KEY_FILE to enable HTTPS
# If not set, server runs in HTTP mode
-TLS_CERT_FILE = os.environ.get('TLS_CERT_FILE', '')
-TLS_KEY_FILE = os.environ.get('TLS_KEY_FILE', '')
+TLS_CERT_FILE = os.environ.get("TLS_CERT_FILE", "")
+TLS_KEY_FILE = os.environ.get("TLS_KEY_FILE", "")
TLS_ENABLED = bool(TLS_CERT_FILE and TLS_KEY_FILE)
# Force HTTPS - reject non-HTTPS requests (default: true when TLS is enabled)
# When behind a trusted proxy, checks X-Forwarded-Proto header
# Set to 'false' to allow HTTP (not recommended for production)
-_force_https_default = 'true' if TLS_ENABLED else 'false'
-FORCE_HTTPS = os.environ.get('FORCE_HTTPS', _force_https_default).lower() in ('true', '1', 'yes')
+_force_https_default = "true" if TLS_ENABLED else "false"
+FORCE_HTTPS = os.environ.get("FORCE_HTTPS", _force_https_default).lower() in ("true", "1", "yes")
diff --git a/auth-proxy/key_manager.py b/auth-proxy/key_manager.py
index 69f893a..d5a1e99 100644
--- a/auth-proxy/key_manager.py
+++ b/auth-proxy/key_manager.py
@@ -9,24 +9,24 @@
- Description/owner info
"""
-import json
import hashlib
+import json
+import os
import secrets
-import time
import threading
-import os
+import time
from dataclasses import dataclass
-from typing import Optional
@dataclass
class APIKey:
"""Represents an API key with metadata."""
+
key_id: str
key_hash: str # We store hash, not plaintext
created_at: float
- expires_at: Optional[float] = None
- rate_limit: Optional[int] = None
+ expires_at: float | None = None
+ rate_limit: int | None = None
description: str = ""
enabled: bool = True
@@ -34,9 +34,7 @@ def is_valid(self) -> bool:
"""Check if key is valid (enabled and not expired)."""
if not self.enabled:
return False
- if self.expires_at and time.time() > self.expires_at:
- return False
- return True
+ return not (self.expires_at and time.time() > self.expires_at)
class KeyManager:
@@ -56,12 +54,12 @@ def _hash_key(self, key: str) -> str:
def _load_keys_from_env(self) -> None:
"""Load keys from environment variables (for cloud deployment)."""
# API_KEYS env var: comma-separated list of keys
- env_keys = os.environ.get('API_KEYS', '')
+ env_keys = os.environ.get("API_KEYS", "")
if not env_keys:
return
with self._lock:
- for i, key in enumerate(env_keys.split(',')):
+ for i, key in enumerate(env_keys.split(",")):
key = key.strip()
if not key:
continue
@@ -71,7 +69,7 @@ def _load_keys_from_env(self) -> None:
key_hash=key_hash,
created_at=time.time(),
description="From API_KEYS environment variable",
- enabled=True
+ enabled=True,
)
self.keys[key_hash] = api_key
@@ -94,28 +92,28 @@ def _load_keys(self) -> None:
if mtime <= self._last_modified:
return
- with open(self.keys_file, 'r') as f:
+ with open(self.keys_file) as f:
data = json.load(f)
with self._lock:
self.keys.clear()
- for key_data in data.get('keys', []):
+ for key_data in data.get("keys", []):
# Support both hashed and plaintext keys in config
- if 'key_hash' in key_data:
- key_hash = key_data['key_hash']
- elif 'key' in key_data:
- key_hash = self._hash_key(key_data['key'])
+ if "key_hash" in key_data:
+ key_hash = key_data["key_hash"]
+ elif "key" in key_data:
+ key_hash = self._hash_key(key_data["key"])
else:
continue
api_key = APIKey(
- key_id=key_data.get('key_id', key_hash[:8]),
+ key_id=key_data.get("key_id", key_hash[:8]),
key_hash=key_hash,
- created_at=key_data.get('created_at', time.time()),
- expires_at=key_data.get('expires_at'),
- rate_limit=key_data.get('rate_limit'),
- description=key_data.get('description', ''),
- enabled=key_data.get('enabled', True)
+ created_at=key_data.get("created_at", time.time()),
+ expires_at=key_data.get("expires_at"),
+ rate_limit=key_data.get("rate_limit"),
+ description=key_data.get("description", ""),
+ enabled=key_data.get("enabled", True),
)
self.keys[key_hash] = api_key
@@ -135,7 +133,7 @@ def reload_if_changed(self) -> None:
except Exception as e:
print(f"Error checking keys file: {e}")
- def validate_key(self, key: str) -> tuple[bool, Optional[APIKey]]:
+ def validate_key(self, key: str) -> tuple[bool, APIKey | None]:
"""
Validate an API key.
Returns (is_valid, api_key_obj or None).
@@ -149,17 +147,17 @@ def validate_key(self, key: str) -> tuple[bool, Optional[APIKey]]:
return True, api_key
return False, None
- def get_key_info(self, key: str) -> Optional[dict]:
+ def get_key_info(self, key: str) -> dict | None:
"""Get non-sensitive info about a key."""
valid, api_key = self.validate_key(key)
if not valid or not api_key:
return None
return {
- 'key_id': api_key.key_id,
- 'created_at': api_key.created_at,
- 'expires_at': api_key.expires_at,
- 'rate_limit': api_key.rate_limit,
- 'description': api_key.description
+ "key_id": api_key.key_id,
+ "created_at": api_key.created_at,
+ "expires_at": api_key.expires_at,
+ "rate_limit": api_key.rate_limit,
+ "description": api_key.description,
}
@@ -172,27 +170,27 @@ def generate_api_key(prefix: str = "pm") -> str:
def create_key_entry(
key: str,
description: str = "",
- expires_in_days: Optional[int] = None,
- rate_limit: Optional[int] = None,
- store_hash_only: bool = True
+ expires_in_days: int | None = None,
+ rate_limit: int | None = None,
+ store_hash_only: bool = True,
) -> dict:
"""Create a key entry for the keys file."""
entry = {
- 'key_id': f"key_{secrets.token_hex(4)}",
- 'created_at': time.time(),
- 'description': description,
- 'enabled': True
+ "key_id": f"key_{secrets.token_hex(4)}",
+ "created_at": time.time(),
+ "description": description,
+ "enabled": True,
}
if store_hash_only:
- entry['key_hash'] = hashlib.sha256(key.encode()).hexdigest()
+ entry["key_hash"] = hashlib.sha256(key.encode()).hexdigest()
else:
- entry['key'] = key
+ entry["key"] = key
if expires_in_days:
- entry['expires_at'] = time.time() + (expires_in_days * 86400)
+ entry["expires_at"] = time.time() + (expires_in_days * 86400)
if rate_limit:
- entry['rate_limit'] = rate_limit
+ entry["rate_limit"] = rate_limit
return entry
diff --git a/auth-proxy/server.py b/auth-proxy/server.py
index 0d93318..53a9ef9 100644
--- a/auth-proxy/server.py
+++ b/auth-proxy/server.py
@@ -9,21 +9,31 @@
- Key rotation support via hot-reload
"""
+import json
import ssl
import time
-import json
-import asyncio
from collections import defaultdict
-from aiohttp import web, ClientSession, ClientTimeout
-from key_manager import KeyManager
-from admin import setup_admin_routes, load_settings
-from usage_tracker import get_tracker
+
+from aiohttp import ClientSession, ClientTimeout, web
+
+from admin import load_settings, setup_admin_routes
from config import (
- API_KEYS_FILE, UPSTREAM_URL, PORT, PRIVATEMODE_API_KEY,
- DEFAULT_RATE_LIMIT_REQUESTS, DEFAULT_RATE_LIMIT_WINDOW,
- DEFAULT_IP_RATE_LIMIT_REQUESTS, DEFAULT_IP_RATE_LIMIT_WINDOW,
- TLS_ENABLED, TLS_CERT_FILE, TLS_KEY_FILE, FORCE_HTTPS, TRUST_PROXY
+ API_KEYS_FILE,
+ DEFAULT_IP_RATE_LIMIT_REQUESTS,
+ DEFAULT_IP_RATE_LIMIT_WINDOW,
+ DEFAULT_RATE_LIMIT_REQUESTS,
+ DEFAULT_RATE_LIMIT_WINDOW,
+ FORCE_HTTPS,
+ PORT,
+ PRIVATEMODE_API_KEY,
+ TLS_CERT_FILE,
+ TLS_ENABLED,
+ TLS_KEY_FILE,
+ TRUST_PROXY,
+ UPSTREAM_URL,
)
+from key_manager import KeyManager
+from usage_tracker import get_tracker
from utils import get_client_ip
@@ -31,12 +41,13 @@ def get_rate_limit_settings() -> dict:
"""Get current rate limit settings from settings file or defaults."""
settings = load_settings()
return {
- 'rate_limit_requests': settings.get('rate_limit_requests', DEFAULT_RATE_LIMIT_REQUESTS),
- 'rate_limit_window': settings.get('rate_limit_window', DEFAULT_RATE_LIMIT_WINDOW),
- 'ip_rate_limit_requests': settings.get('ip_rate_limit_requests', DEFAULT_IP_RATE_LIMIT_REQUESTS),
- 'ip_rate_limit_window': settings.get('ip_rate_limit_window', DEFAULT_IP_RATE_LIMIT_WINDOW),
+ "rate_limit_requests": settings.get("rate_limit_requests", DEFAULT_RATE_LIMIT_REQUESTS),
+ "rate_limit_window": settings.get("rate_limit_window", DEFAULT_RATE_LIMIT_WINDOW),
+ "ip_rate_limit_requests": settings.get("ip_rate_limit_requests", DEFAULT_IP_RATE_LIMIT_REQUESTS),
+ "ip_rate_limit_window": settings.get("ip_rate_limit_window", DEFAULT_IP_RATE_LIMIT_WINDOW),
}
+
# Rate limiting storage: key_id -> list of request timestamps
rate_limit_store: dict[str, list[float]] = defaultdict(list)
@@ -50,12 +61,12 @@ def get_rate_limit_settings() -> dict:
def extract_api_key(request: web.Request) -> str | None:
"""Extract API key from request headers."""
# Check Authorization header (Bearer token)
- auth_header = request.headers.get('Authorization', '')
- if auth_header.startswith('Bearer '):
+ auth_header = request.headers.get("Authorization", "")
+ if auth_header.startswith("Bearer "):
return auth_header[7:]
# Check X-API-Key header
- api_key = request.headers.get('X-API-Key')
+ api_key = request.headers.get("X-API-Key")
if api_key:
return api_key
@@ -69,8 +80,8 @@ def check_global_rate_limit() -> tuple[bool, int, int, int]:
"""
global global_rate_limit_store
settings = get_rate_limit_settings()
- limit = settings['rate_limit_requests']
- window = settings['rate_limit_window']
+ limit = settings["rate_limit_requests"]
+ window = settings["rate_limit_window"]
now = time.time()
window_start = now - window
@@ -99,15 +110,13 @@ def check_per_key_rate_limit(key_id: str, limit: int | None) -> tuple[bool, int,
return True, -1, 0
settings = get_rate_limit_settings()
- window = settings['rate_limit_window']
+ window = settings["rate_limit_window"]
now = time.time()
window_start = now - window
# Clean old entries
- rate_limit_store[key_id] = [
- ts for ts in rate_limit_store[key_id] if ts > window_start
- ]
+ rate_limit_store[key_id] = [ts for ts in rate_limit_store[key_id] if ts > window_start]
current_count = len(rate_limit_store[key_id])
remaining = max(0, limit - current_count)
@@ -122,8 +131,8 @@ def check_per_key_rate_limit(key_id: str, limit: int | None) -> tuple[bool, int,
def check_ip_rate_limit(ip: str) -> tuple[bool, int, int, int]:
"""Check if IP is within global rate limit. Returns (allowed, remaining, limit, window)."""
settings = get_rate_limit_settings()
- ip_limit = settings['ip_rate_limit_requests']
- ip_window = settings['ip_rate_limit_window']
+ ip_limit = settings["ip_rate_limit_requests"]
+ ip_window = settings["ip_rate_limit_window"]
now = time.time()
window_start = now - ip_window
@@ -143,24 +152,24 @@ def check_ip_rate_limit(ip: str) -> tuple[bool, int, int, int]:
def detect_endpoint_type(path: str) -> str:
"""Detect the API endpoint type from path."""
- if '/chat/completions' in path:
- return 'chat'
- elif '/embeddings' in path:
- return 'embeddings'
- elif '/audio/transcriptions' in path:
- return 'transcriptions'
- elif '/completions' in path:
- return 'completions'
- return 'other'
+ if "/chat/completions" in path:
+ return "chat"
+ elif "/embeddings" in path:
+ return "embeddings"
+ elif "/audio/transcriptions" in path:
+ return "transcriptions"
+ elif "/completions" in path:
+ return "completions"
+ return "other"
def extract_model_from_request(body: bytes) -> str:
"""Extract model name from request body."""
try:
data = json.loads(body)
- return data.get('model', 'unknown')
+ return data.get("model", "unknown")
except (json.JSONDecodeError, UnicodeDecodeError):
- return 'unknown'
+ return "unknown"
def extract_usage_from_response(response_body: bytes, endpoint: str) -> dict:
@@ -168,31 +177,26 @@ def extract_usage_from_response(response_body: bytes, endpoint: str) -> dict:
Extract token usage from response body.
Only extracts usage metadata, never the actual content.
"""
- usage = {
- 'prompt_tokens': 0,
- 'completion_tokens': 0,
- 'total_tokens': 0,
- 'model': 'unknown'
- }
+ usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "model": "unknown"}
try:
data = json.loads(response_body)
# Get model from response
- usage['model'] = data.get('model', 'unknown')
+ usage["model"] = data.get("model", "unknown")
# Chat completions and completions have usage object
- if 'usage' in data:
- usage_data = data['usage']
- usage['prompt_tokens'] = usage_data.get('prompt_tokens', 0)
- usage['completion_tokens'] = usage_data.get('completion_tokens', 0)
- usage['total_tokens'] = usage_data.get('total_tokens', 0)
+ if "usage" in data:
+ usage_data = data["usage"]
+ usage["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
+ usage["completion_tokens"] = usage_data.get("completion_tokens", 0)
+ usage["total_tokens"] = usage_data.get("total_tokens", 0)
# Embeddings also have usage
- elif endpoint == 'embeddings' and 'usage' in data:
- usage_data = data['usage']
- usage['total_tokens'] = usage_data.get('total_tokens', 0)
- usage['prompt_tokens'] = usage_data.get('prompt_tokens', usage['total_tokens'])
+ elif endpoint == "embeddings" and "usage" in data:
+ usage_data = data["usage"]
+ usage["total_tokens"] = usage_data.get("total_tokens", 0)
+ usage["prompt_tokens"] = usage_data.get("prompt_tokens", usage["total_tokens"])
except (json.JSONDecodeError, UnicodeDecodeError, KeyError):
pass
@@ -210,40 +214,42 @@ async def proxy_request(request: web.Request, session: ClientSession) -> web.Res
# Forward headers (excluding hop-by-hop headers)
headers = {}
- hop_by_hop = {'connection', 'keep-alive', 'proxy-authenticate',
- 'proxy-authorization', 'te', 'trailers', 'transfer-encoding',
- 'upgrade', 'host'}
+ hop_by_hop = {
+ "connection",
+ "keep-alive",
+ "proxy-authenticate",
+ "proxy-authorization",
+ "te",
+ "trailers",
+ "transfer-encoding",
+ "upgrade",
+ "host",
+ }
for key, value in request.headers.items():
- if key.lower() not in hop_by_hop:
- # Don't forward our auth headers to upstream
- if key.lower() not in ('authorization', 'x-api-key'):
- headers[key] = value
+ if key.lower() not in hop_by_hop and key.lower() not in ("authorization", "x-api-key"):
+ headers[key] = value
# Add Privatemode API key for upstream authentication
if PRIVATEMODE_API_KEY:
- headers['Authorization'] = f'Bearer {PRIVATEMODE_API_KEY}'
+ headers["Authorization"] = f"Bearer {PRIVATEMODE_API_KEY}"
# Read request body
body = await request.read()
# Detect endpoint type and model for usage tracking
endpoint = detect_endpoint_type(path)
- request_model = extract_model_from_request(body) if body else 'unknown'
+ request_model = extract_model_from_request(body) if body else "unknown"
# Track audio file size for transcription requests
audio_bytes = 0
- if endpoint == 'transcriptions':
+ if endpoint == "transcriptions":
audio_bytes = len(body)
try:
# Forward request to upstream
async with session.request(
- method=request.method,
- url=upstream_url,
- headers=headers,
- data=body,
- allow_redirects=False
+ method=request.method, url=upstream_url, headers=headers, data=body, allow_redirects=False
) as upstream_response:
# Read response
response_body = await upstream_response.read()
@@ -255,38 +261,28 @@ async def proxy_request(request: web.Request, session: ClientSession) -> web.Res
response_headers[key] = value
# Track usage if request was successful and we have a key_id
- if upstream_response.status == 200 and 'key_id' in request:
+ if upstream_response.status == 200 and "key_id" in request:
usage = extract_usage_from_response(response_body, endpoint)
- model = usage['model'] if usage['model'] != 'unknown' else request_model
+ model = usage["model"] if usage["model"] != "unknown" else request_model
tracker = get_tracker()
tracker.record_usage(
- key_id=request['key_id'],
+ key_id=request["key_id"],
model=model,
endpoint=endpoint,
- prompt_tokens=usage['prompt_tokens'],
- completion_tokens=usage['completion_tokens'],
- total_tokens=usage['total_tokens'],
- audio_bytes=audio_bytes
+ prompt_tokens=usage["prompt_tokens"],
+ completion_tokens=usage["completion_tokens"],
+ total_tokens=usage["total_tokens"],
+ audio_bytes=audio_bytes,
)
- return web.Response(
- status=upstream_response.status,
- headers=response_headers,
- body=response_body
- )
+ return web.Response(status=upstream_response.status, headers=response_headers, body=response_body)
- except asyncio.TimeoutError:
- return web.json_response(
- {'error': 'Upstream timeout'},
- status=504
- )
+ except TimeoutError:
+ return web.json_response({"error": "Upstream timeout"}, status=504)
except Exception as e:
print(f"Proxy error: {e}")
- return web.json_response(
- {'error': 'Proxy error'},
- status=502
- )
+ return web.json_response({"error": "Proxy error"}, status=502)
@web.middleware
@@ -300,18 +296,15 @@ async def https_enforcement_middleware(request: web.Request, handler):
# If behind a trusted proxy, also check X-Forwarded-Proto
if not is_https and TRUST_PROXY:
- forwarded_proto = request.headers.get('X-Forwarded-Proto', '')
- is_https = forwarded_proto.lower() == 'https'
+ forwarded_proto = request.headers.get("X-Forwarded-Proto", "")
+ is_https = forwarded_proto.lower() == "https"
# Allow health checks over HTTP for internal load balancer probes
- if request.path == '/health':
+ if request.path == "/health":
return await handler(request)
if not is_https:
- return web.json_response(
- {'error': 'HTTPS required. Please use a secure connection.'},
- status=403
- )
+ return web.json_response({"error": "HTTPS required. Please use a secure connection."}, status=403)
return await handler(request)
@@ -322,8 +315,8 @@ async def security_headers_middleware(request: web.Request, handler):
response = await handler(request)
# Content Security Policy - restrictive for admin UI
- if request.path.startswith('/admin'):
- response.headers['Content-Security-Policy'] = (
+ if request.path.startswith("/admin"):
+ response.headers["Content-Security-Policy"] = (
"default-src 'self'; "
"script-src 'self' 'unsafe-inline'; "
"style-src 'self' 'unsafe-inline'; "
@@ -333,27 +326,23 @@ async def security_headers_middleware(request: web.Request, handler):
)
# Prevent MIME type sniffing
- response.headers['X-Content-Type-Options'] = 'nosniff'
+ response.headers["X-Content-Type-Options"] = "nosniff"
# Prevent clickjacking
- response.headers['X-Frame-Options'] = 'DENY'
+ response.headers["X-Frame-Options"] = "DENY"
# XSS protection (legacy, but still useful for older browsers)
- response.headers['X-XSS-Protection'] = '1; mode=block'
+ response.headers["X-XSS-Protection"] = "1; mode=block"
# Referrer policy
- response.headers['Referrer-Policy'] = 'strict-origin-when-cross-origin'
+ response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
# Permissions policy (restrict sensitive APIs)
- response.headers['Permissions-Policy'] = (
- 'geolocation=(), microphone=(), camera=()'
- )
+ response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()"
# HSTS - enforce HTTPS for 1 year when TLS is enabled
if TLS_ENABLED or FORCE_HTTPS:
- response.headers['Strict-Transport-Security'] = (
- 'max-age=31536000; includeSubDomains'
- )
+ response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
return response
@@ -364,49 +353,45 @@ def create_auth_middleware(key_manager: KeyManager):
@web.middleware
async def auth_middleware(request: web.Request, handler):
# Skip auth for health check and admin routes
- if request.path == '/health' or request.path.startswith('/admin'):
+ if request.path == "/health" or request.path.startswith("/admin"):
return await handler(request)
# Check global IP rate limit first
client_ip = get_client_ip(request)
- ip_allowed, ip_remaining, ip_limit, ip_window = check_ip_rate_limit(client_ip)
+ ip_allowed, _ip_remaining, ip_limit, ip_window = check_ip_rate_limit(client_ip)
if not ip_allowed:
return web.json_response(
- {'error': 'Global rate limit exceeded'},
+ {"error": "Global rate limit exceeded"},
status=429,
headers={
- 'X-RateLimit-Limit': str(ip_limit),
- 'X-RateLimit-Remaining': '0',
- 'X-RateLimit-Reset': str(int(time.time()) + ip_window)
- }
+ "X-RateLimit-Limit": str(ip_limit),
+ "X-RateLimit-Remaining": "0",
+ "X-RateLimit-Reset": str(int(time.time()) + ip_window),
+ },
)
# Extract and validate API key
api_key = extract_api_key(request)
if not api_key:
return web.json_response(
- {'error': 'Missing API key. Use Authorization: Bearer or X-API-Key header'},
- status=401
+ {"error": "Missing API key. Use Authorization: Bearer or X-API-Key header"}, status=401
)
valid, key_obj = key_manager.validate_key(api_key)
if not valid:
- return web.json_response(
- {'error': 'Invalid or expired API key'},
- status=401
- )
+ return web.json_response({"error": "Invalid or expired API key"}, status=401)
# Check global rate limit first (shared across ALL keys)
global_allowed, global_remaining, global_limit, global_window = check_global_rate_limit()
if not global_allowed:
return web.json_response(
- {'error': 'Global rate limit exceeded'},
+ {"error": "Global rate limit exceeded"},
status=429,
headers={
- 'X-RateLimit-Limit': str(global_limit),
- 'X-RateLimit-Remaining': '0',
- 'X-RateLimit-Reset': str(int(time.time()) + global_window)
- }
+ "X-RateLimit-Limit": str(global_limit),
+ "X-RateLimit-Remaining": "0",
+ "X-RateLimit-Reset": str(int(time.time()) + global_window),
+ },
)
# Check per-key rate limit (only if key has a specific limit set)
@@ -415,24 +400,24 @@ async def auth_middleware(request: web.Request, handler):
if not key_allowed:
return web.json_response(
- {'error': 'Per-key rate limit exceeded'},
+ {"error": "Per-key rate limit exceeded"},
status=429,
headers={
- 'X-RateLimit-Limit': str(key_limit),
- 'X-RateLimit-Remaining': '0',
- 'X-RateLimit-Reset': str(int(time.time()) + global_window)
- }
+ "X-RateLimit-Limit": str(key_limit),
+ "X-RateLimit-Remaining": "0",
+ "X-RateLimit-Reset": str(int(time.time()) + global_window),
+ },
)
# Add rate limit headers to request for handler
# Show per-key limit if set, otherwise show global
if per_key_limit:
- request['rate_limit_remaining'] = key_remaining
- request['rate_limit_limit'] = key_limit
+ request["rate_limit_remaining"] = key_remaining
+ request["rate_limit_limit"] = key_limit
else:
- request['rate_limit_remaining'] = global_remaining
- request['rate_limit_limit'] = global_limit
- request['key_id'] = key_obj.key_id
+ request["rate_limit_remaining"] = global_remaining
+ request["rate_limit_limit"] = global_limit
+ request["key_id"] = key_obj.key_id
return await handler(request)
@@ -441,34 +426,34 @@ async def auth_middleware(request: web.Request, handler):
async def health_handler(request: web.Request) -> web.Response:
"""Health check endpoint."""
- return web.json_response({'status': 'healthy'})
+ return web.json_response({"status": "healthy"})
async def key_info_handler(request: web.Request) -> web.Response:
"""Return info about the current API key (non-sensitive)."""
- key_manager = request.app['key_manager']
+ key_manager = request.app["key_manager"]
api_key = extract_api_key(request)
if not api_key:
- return web.json_response({'error': 'No API key provided'}, status=401)
+ return web.json_response({"error": "No API key provided"}, status=401)
info = key_manager.get_key_info(api_key)
if not info:
- return web.json_response({'error': 'Invalid API key'}, status=401)
+ return web.json_response({"error": "Invalid API key"}, status=401)
return web.json_response(info)
async def catch_all_handler(request: web.Request) -> web.Response:
"""Proxy all other requests to upstream."""
- session = request.app['client_session']
+ session = request.app["client_session"]
response = await proxy_request(request, session)
# Add rate limit headers
- if 'rate_limit_remaining' in request:
- response.headers['X-RateLimit-Remaining'] = str(request['rate_limit_remaining'])
- response.headers['X-RateLimit-Limit'] = str(request.get('rate_limit_limit', 100))
+ if "rate_limit_remaining" in request:
+ response.headers["X-RateLimit-Remaining"] = str(request["rate_limit_remaining"])
+ response.headers["X-RateLimit-Limit"] = str(request.get("rate_limit_limit", 100))
return response
@@ -476,13 +461,13 @@ async def catch_all_handler(request: web.Request) -> web.Response:
async def on_startup(app: web.Application):
"""Initialize client session on startup."""
timeout = ClientTimeout(total=300) # 5 minute timeout for LLM requests
- app['client_session'] = ClientSession(timeout=timeout)
+ app["client_session"] = ClientSession(timeout=timeout)
print(f"Auth proxy started, forwarding to {UPSTREAM_URL}")
async def on_cleanup(app: web.Application):
"""Cleanup client session and flush usage data."""
- await app['client_session'].close()
+ await app["client_session"].close()
# Flush usage data to disk
get_tracker().flush()
@@ -492,23 +477,19 @@ def create_app() -> web.Application:
key_manager = KeyManager(API_KEYS_FILE)
# Middleware order: HTTPS enforcement -> security headers -> auth
- middlewares = [
- https_enforcement_middleware,
- security_headers_middleware,
- create_auth_middleware(key_manager)
- ]
+ middlewares = [https_enforcement_middleware, security_headers_middleware, create_auth_middleware(key_manager)]
app = web.Application(middlewares=middlewares)
- app['key_manager'] = key_manager
+ app["key_manager"] = key_manager
# Routes
- app.router.add_get('/health', health_handler)
- app.router.add_get('/auth/key-info', key_info_handler)
+ app.router.add_get("/health", health_handler)
+ app.router.add_get("/auth/key-info", key_info_handler)
# Admin UI routes
setup_admin_routes(app)
# Catch-all for proxying (must be last)
- app.router.add_route('*', '/{path_info:.*}', catch_all_handler)
+ app.router.add_route("*", "/{path_info:.*}", catch_all_handler)
app.on_startup.append(on_startup)
app.on_cleanup.append(on_cleanup)
@@ -528,7 +509,7 @@ def create_ssl_context():
return ssl_context
-if __name__ == '__main__':
+if __name__ == "__main__":
app = create_app()
ssl_context = create_ssl_context()
@@ -537,4 +518,4 @@ def create_ssl_context():
else:
print(f"Starting HTTP server on port {PORT} (TLS not configured)")
- web.run_app(app, host='0.0.0.0', port=PORT, ssl_context=ssl_context)
+ web.run_app(app, host="0.0.0.0", port=PORT, ssl_context=ssl_context)
diff --git a/auth-proxy/usage_tracker.py b/auth-proxy/usage_tracker.py
index 7aaa9c8..14595ba 100644
--- a/auth-proxy/usage_tracker.py
+++ b/auth-proxy/usage_tracker.py
@@ -5,39 +5,39 @@
Calculates costs based on Privatemode pricing.
"""
-import os
import json
+import os
+import threading
import time
+from collections import defaultdict
+from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from pathlib import Path
-from dataclasses import dataclass, asdict
-from typing import Optional
-from collections import defaultdict
-import threading
from config import USAGE_FILE
# Privatemode pricing (EUR per unit)
PRICING = {
# Chat models: EUR per 1M tokens
- 'gpt-oss-120b': {'type': 'token', 'rate': 5.0, 'per': 1_000_000},
- 'llama-3.3-70b': {'type': 'token', 'rate': 5.0, 'per': 1_000_000},
- 'gemma-3-27b': {'type': 'token', 'rate': 5.0, 'per': 1_000_000},
- 'qwen3-coder-30b-a3b': {'type': 'token', 'rate': 5.0, 'per': 1_000_000},
+ "gpt-oss-120b": {"type": "token", "rate": 5.0, "per": 1_000_000},
+ "llama-3.3-70b": {"type": "token", "rate": 5.0, "per": 1_000_000},
+ "gemma-3-27b": {"type": "token", "rate": 5.0, "per": 1_000_000},
+ "qwen3-coder-30b-a3b": {"type": "token", "rate": 5.0, "per": 1_000_000},
# Embedding models: EUR per 1M tokens
- 'multilingual-e5': {'type': 'token', 'rate': 0.13, 'per': 1_000_000},
- 'qwen3-embedding-4b': {'type': 'token', 'rate': 0.13, 'per': 1_000_000},
+ "multilingual-e5": {"type": "token", "rate": 0.13, "per": 1_000_000},
+ "qwen3-embedding-4b": {"type": "token", "rate": 0.13, "per": 1_000_000},
# Speech-to-text: EUR per MB
- 'whisper-large-v3': {'type': 'audio', 'rate': 0.096, 'per': 1}, # per MB
+ "whisper-large-v3": {"type": "audio", "rate": 0.096, "per": 1}, # per MB
}
# Default pricing for unknown models
-DEFAULT_PRICING = {'type': 'token', 'rate': 5.0, 'per': 1_000_000}
+DEFAULT_PRICING = {"type": "token", "rate": 5.0, "per": 1_000_000}
@dataclass
class UsageRecord:
"""Single usage record."""
+
timestamp: float
key_id: str
model: str
@@ -52,7 +52,7 @@ class UsageRecord:
class UsageTracker:
"""Thread-safe usage tracker with file persistence."""
- def __init__(self, usage_file: str = None):
+ def __init__(self, usage_file: str | None = None):
self.usage_file = usage_file or USAGE_FILE
self._lock = threading.Lock()
self._records: list[dict] = []
@@ -62,32 +62,32 @@ def _load(self):
"""Load usage data from file."""
try:
if os.path.exists(self.usage_file):
- with open(self.usage_file, 'r') as f:
+ with open(self.usage_file) as f:
data = json.load(f)
- self._records = data.get('records', [])
- except (json.JSONDecodeError, IOError):
+ self._records = data.get("records", [])
+ except (OSError, json.JSONDecodeError):
self._records = []
def _save(self):
"""Save usage data to file."""
try:
Path(self.usage_file).parent.mkdir(parents=True, exist_ok=True)
- with open(self.usage_file, 'w') as f:
- json.dump({'records': self._records}, f)
- except IOError as e:
+ with open(self.usage_file, "w") as f:
+ json.dump({"records": self._records}, f)
+ except OSError as e:
print(f"Failed to save usage data: {e}")
def calculate_cost(self, model: str, tokens: int = 0, audio_bytes: int = 0) -> float:
"""Calculate cost in EUR based on model and usage."""
pricing = PRICING.get(model, DEFAULT_PRICING)
- if pricing['type'] == 'audio':
+ if pricing["type"] == "audio":
# Audio: cost per MB
mb = audio_bytes / (1024 * 1024)
- return mb * pricing['rate']
+ return mb * pricing["rate"]
else:
# Tokens: cost per million tokens
- return (tokens / pricing['per']) * pricing['rate']
+ return (tokens / pricing["per"]) * pricing["rate"]
def record_usage(
self,
@@ -97,7 +97,7 @@ def record_usage(
prompt_tokens: int = 0,
completion_tokens: int = 0,
total_tokens: int = 0,
- audio_bytes: int = 0
+ audio_bytes: int = 0,
):
"""Record a usage event."""
if total_tokens == 0 and prompt_tokens + completion_tokens > 0:
@@ -114,7 +114,7 @@ def record_usage(
completion_tokens=completion_tokens,
total_tokens=total_tokens,
audio_bytes=audio_bytes,
- cost_eur=cost
+ cost_eur=cost,
)
with self._lock:
@@ -129,10 +129,7 @@ def flush(self):
self._save()
def get_usage_summary(
- self,
- key_id: Optional[str] = None,
- start_time: Optional[float] = None,
- end_time: Optional[float] = None
+ self, key_id: str | None = None, start_time: float | None = None, end_time: float | None = None
) -> dict:
"""
Get usage summary for a key or all keys within a time range.
@@ -152,51 +149,47 @@ def get_usage_summary(
# Filter records
filtered = []
for r in records:
- if key_id and r['key_id'] != key_id:
+ if key_id and r["key_id"] != key_id:
continue
- if start_time and r['timestamp'] < start_time:
+ if start_time and r["timestamp"] < start_time:
continue
- if end_time and r['timestamp'] > end_time:
+ if end_time and r["timestamp"] > end_time:
continue
filtered.append(r)
# Aggregate
summary = {
- 'total_tokens': 0,
- 'total_audio_bytes': 0,
- 'total_cost_eur': 0.0,
- 'by_model': defaultdict(lambda: {'tokens': 0, 'audio_bytes': 0, 'cost': 0.0, 'requests': 0}),
- 'by_endpoint': defaultdict(lambda: {'tokens': 0, 'requests': 0, 'cost': 0.0}),
- 'requests': len(filtered)
+ "total_tokens": 0,
+ "total_audio_bytes": 0,
+ "total_cost_eur": 0.0,
+ "by_model": defaultdict(lambda: {"tokens": 0, "audio_bytes": 0, "cost": 0.0, "requests": 0}),
+ "by_endpoint": defaultdict(lambda: {"tokens": 0, "requests": 0, "cost": 0.0}),
+ "requests": len(filtered),
}
for r in filtered:
- summary['total_tokens'] += r['total_tokens']
- summary['total_audio_bytes'] += r['audio_bytes']
- summary['total_cost_eur'] += r['cost_eur']
+ summary["total_tokens"] += r["total_tokens"]
+ summary["total_audio_bytes"] += r["audio_bytes"]
+ summary["total_cost_eur"] += r["cost_eur"]
- model = r['model']
- summary['by_model'][model]['tokens'] += r['total_tokens']
- summary['by_model'][model]['audio_bytes'] += r['audio_bytes']
- summary['by_model'][model]['cost'] += r['cost_eur']
- summary['by_model'][model]['requests'] += 1
+ model = r["model"]
+ summary["by_model"][model]["tokens"] += r["total_tokens"]
+ summary["by_model"][model]["audio_bytes"] += r["audio_bytes"]
+ summary["by_model"][model]["cost"] += r["cost_eur"]
+ summary["by_model"][model]["requests"] += 1
- endpoint = r['endpoint']
- summary['by_endpoint'][endpoint]['tokens'] += r['total_tokens']
- summary['by_endpoint'][endpoint]['requests'] += 1
- summary['by_endpoint'][endpoint]['cost'] += r['cost_eur']
+ endpoint = r["endpoint"]
+ summary["by_endpoint"][endpoint]["tokens"] += r["total_tokens"]
+ summary["by_endpoint"][endpoint]["requests"] += 1
+ summary["by_endpoint"][endpoint]["cost"] += r["cost_eur"]
# Convert defaultdicts to regular dicts
- summary['by_model'] = dict(summary['by_model'])
- summary['by_endpoint'] = dict(summary['by_endpoint'])
+ summary["by_model"] = dict(summary["by_model"])
+ summary["by_endpoint"] = dict(summary["by_endpoint"])
return summary
- def get_usage_by_key(
- self,
- start_time: Optional[float] = None,
- end_time: Optional[float] = None
- ) -> dict[str, dict]:
+ def get_usage_by_key(self, start_time: float | None = None, end_time: float | None = None) -> dict[str, dict]:
"""Get usage breakdown by key."""
with self._lock:
records = self._records.copy()
@@ -204,34 +197,25 @@ def get_usage_by_key(
# Filter by time
filtered = []
for r in records:
- if start_time and r['timestamp'] < start_time:
+ if start_time and r["timestamp"] < start_time:
continue
- if end_time and r['timestamp'] > end_time:
+ if end_time and r["timestamp"] > end_time:
continue
filtered.append(r)
# Group by key
- by_key = defaultdict(lambda: {
- 'tokens': 0,
- 'audio_bytes': 0,
- 'cost_eur': 0.0,
- 'requests': 0
- })
+ by_key = defaultdict(lambda: {"tokens": 0, "audio_bytes": 0, "cost_eur": 0.0, "requests": 0})
for r in filtered:
- key_id = r['key_id']
- by_key[key_id]['tokens'] += r['total_tokens']
- by_key[key_id]['audio_bytes'] += r['audio_bytes']
- by_key[key_id]['cost_eur'] += r['cost_eur']
- by_key[key_id]['requests'] += 1
+ key_id = r["key_id"]
+ by_key[key_id]["tokens"] += r["total_tokens"]
+ by_key[key_id]["audio_bytes"] += r["audio_bytes"]
+ by_key[key_id]["cost_eur"] += r["cost_eur"]
+ by_key[key_id]["requests"] += 1
return dict(by_key)
- def get_daily_breakdown(
- self,
- key_id: Optional[str] = None,
- days: int = 30
- ) -> list[dict]:
+ def get_daily_breakdown(self, key_id: str | None = None, days: int = 30) -> list[dict]:
"""Get daily usage breakdown for the last N days."""
with self._lock:
records = self._records.copy()
@@ -244,28 +228,25 @@ def get_daily_breakdown(
# Filter records
filtered = []
for r in records:
- if r['timestamp'] < start_time:
+ if r["timestamp"] < start_time:
continue
- if key_id and r['key_id'] != key_id:
+ if key_id and r["key_id"] != key_id:
continue
filtered.append(r)
# Group by day
- daily = defaultdict(lambda: {'tokens': 0, 'cost_eur': 0.0, 'requests': 0})
+ daily = defaultdict(lambda: {"tokens": 0, "cost_eur": 0.0, "requests": 0})
for r in filtered:
- day = datetime.fromtimestamp(r['timestamp']).strftime('%Y-%m-%d')
- daily[day]['tokens'] += r['total_tokens']
- daily[day]['cost_eur'] += r['cost_eur']
- daily[day]['requests'] += 1
+ day = datetime.fromtimestamp(r["timestamp"]).strftime("%Y-%m-%d")
+ daily[day]["tokens"] += r["total_tokens"]
+ daily[day]["cost_eur"] += r["cost_eur"]
+ daily[day]["requests"] += 1
# Convert to sorted list
result = []
for date_str in sorted(daily.keys()):
- result.append({
- 'date': date_str,
- **daily[date_str]
- })
+ result.append({"date": date_str, **daily[date_str]})
return result
@@ -284,29 +265,29 @@ def get_time_range(period: str) -> tuple[float, float]:
now = datetime.now()
end = now.timestamp()
- if period == 'today':
+ if period == "today":
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
- elif period == 'yesterday':
+ elif period == "yesterday":
yesterday = now - timedelta(days=1)
start = yesterday.replace(hour=0, minute=0, second=0, microsecond=0)
end = now.replace(hour=0, minute=0, second=0, microsecond=0).timestamp()
- elif period == 'week':
+ elif period == "week":
start = now - timedelta(days=7)
- elif period == 'month':
+ elif period == "month":
start = now - timedelta(days=30)
- elif period == 'year':
+ elif period == "year":
start = now - timedelta(days=365)
- elif period == 'all':
+ elif period == "all":
return None, None
else:
# Default to all time
return None, None
- return start.timestamp() if hasattr(start, 'timestamp') else start, end
+ return start.timestamp() if hasattr(start, "timestamp") else start, end
# Global instance
-_tracker: Optional[UsageTracker] = None
+_tracker: UsageTracker | None = None
def get_tracker() -> UsageTracker:
diff --git a/auth-proxy/utils.py b/auth-proxy/utils.py
index 0d83f7f..8713dc9 100644
--- a/auth-proxy/utils.py
+++ b/auth-proxy/utils.py
@@ -3,6 +3,7 @@
"""
from aiohttp import web
+
from config import TRUST_PROXY
@@ -14,7 +15,7 @@ def get_client_ip(request: web.Request) -> str:
This prevents IP spoofing when not behind a trusted reverse proxy.
"""
if TRUST_PROXY:
- forwarded = request.headers.get('X-Forwarded-For', '')
+ forwarded = request.headers.get("X-Forwarded-For", "")
if forwarded:
- return forwarded.split(',')[0].strip()
- return request.remote or 'unknown'
+ return forwarded.split(",")[0].strip()
+ return request.remote or "unknown"
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..efe6fe2
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,41 @@
+[tool.ruff]
+target-version = "py312"
+line-length = 120
+
+[tool.ruff.lint]
+select = [
+ "E", # pycodestyle errors
+ "F", # pyflakes
+ "W", # pycodestyle warnings
+ "I", # isort
+ "UP", # pyupgrade
+ "B", # flake8-bugbear
+ "SIM", # flake8-simplify
+ "RUF", # ruff-specific rules
+]
+ignore = [
+ "B904", # allow raise without from in some cases
+ "SIM108", # allow if-else blocks for readability
+ "RUF001", # allow unicode characters (e.g. × symbol in admin UI)
+]
+
+[tool.ruff.lint.per-file-ignores]
+"auth-proxy/admin.py" = [
+ "E501", # HTML template strings are inherently long
+ "SIM103", # inline conditions would reduce readability in HTML templates
+ "SIM105", # contextlib.suppress less readable in HTML-heavy file
+]
+
+[tool.ruff.lint.isort]
+known-first-party = ["key_manager", "admin", "usage_tracker", "config", "utils"]
+
+[tool.mypy]
+python_version = "3.12"
+warn_return_any = true
+warn_unused_configs = true
+check_untyped_defs = true
+no_strict_optional = true
+warn_redundant_casts = true
+warn_unused_ignores = true
+warn_no_return = true
+strict_equality = true
diff --git a/scripts/manage_keys.py b/scripts/manage_keys.py
index f5e7d40..9e51d09 100755
--- a/scripts/manage_keys.py
+++ b/scripts/manage_keys.py
@@ -45,7 +45,7 @@ def load_keys() -> dict:
def save_keys(data: dict) -> None:
"""Save keys to file."""
KEYS_FILE.parent.mkdir(parents=True, exist_ok=True)
- with open(KEYS_FILE, 'w') as f:
+ with open(KEYS_FILE, "w") as f:
json.dump(data, f, indent=2)
print(f"Keys saved to {KEYS_FILE}")
@@ -82,7 +82,7 @@ def cmd_generate(args):
"key_hash": key_hash,
"created_at": time.time(),
"description": args.description or "",
- "enabled": True
+ "enabled": True,
}
if args.expires_days:
@@ -193,7 +193,7 @@ def cmd_rotate(args):
"created_at": time.time(),
"description": old_key.get("description", "") + " (rotated)",
"enabled": True,
- "rotated_from": args.key_id
+ "rotated_from": args.key_id,
}
if args.expires_days:
@@ -242,7 +242,7 @@ def main():
parser = argparse.ArgumentParser(
description="Manage API keys for Privatemode proxy",
formatter_class=argparse.RawDescriptionHelpFormatter,
- epilog=__doc__
+ epilog=__doc__,
)
subparsers = parser.add_subparsers(dest="command", required=True)
diff --git a/scripts/scrape_docs.py b/scripts/scrape_docs.py
index 8fb2d44..188cf9b 100644
--- a/scripts/scrape_docs.py
+++ b/scripts/scrape_docs.py
@@ -6,11 +6,11 @@
import os
import re
from pathlib import Path
+from urllib.parse import urljoin, urlparse
import requests
from bs4 import BeautifulSoup
from markdownify import markdownify as md
-from urllib.parse import urljoin, urlparse
BASE_URL = "https://docs.privatemode.ai"
DOCS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "docs")
@@ -40,7 +40,7 @@ def get_page(url: str) -> BeautifulSoup | None:
print(f"Fetching: {full_url}")
response = requests.get(full_url, timeout=10)
response.raise_for_status()
- return BeautifulSoup(response.text, 'html.parser')
+ return BeautifulSoup(response.text, "html.parser")
except Exception as e:
print(f"Error fetching {url}: {e}")
return None
@@ -50,45 +50,52 @@ def extract_nav_links(soup: BeautifulSoup) -> list[str]:
"""Extract navigation links from the page."""
links = []
# Look for sidebar navigation
- for a in soup.find_all('a', href=True):
- href = a['href']
- if href.startswith('/') and not href.startswith('//'):
- if not any(x in href for x in ['#', 'mailto:', 'javascript:']):
- links.append(href)
+ for a in soup.find_all("a", href=True):
+ href = a["href"]
+ if (
+ href.startswith("/")
+ and not href.startswith("//")
+ and not any(x in href for x in ["#", "mailto:", "javascript:"])
+ ):
+ links.append(href)
return list(set(links))
def extract_content(soup: BeautifulSoup) -> tuple[str, str]:
"""Extract main content from the page."""
# Try to find main content area
- main = soup.find('main') or soup.find('article') or soup.find(class_=re.compile(r'content|docs|markdown'))
+ main = soup.find("main") or soup.find("article") or soup.find(class_=re.compile(r"content|docs|markdown"))
if not main:
# Fallback: try to find the largest div with text
- main = soup.find('body')
+ main = soup.find("body")
if not main:
return "", ""
# Get title
title = ""
- h1 = main.find('h1')
+ h1 = main.find("h1")
if h1:
title = h1.get_text(strip=True)
else:
- title_tag = soup.find('title')
+ title_tag = soup.find("title")
if title_tag:
- title = title_tag.get_text(strip=True).split('|')[0].strip()
+ title = title_tag.get_text(strip=True).split("|")[0].strip()
# Remove navigation, footer, etc.
- for tag in main.find_all(['nav', 'footer', 'header', 'script', 'style']):
+ for tag in main.find_all(["nav", "footer", "header", "script", "style"]):
tag.decompose()
# Convert to markdown
- content = md(str(main), heading_style="ATX", code_language_callback=lambda el: "python" if "python" in str(el).lower() else "bash")
+ content = md(
+ str(main),
+ heading_style="ATX",
+ code_language_callback=lambda el: "python" if "python" in str(el).lower() else "bash",
+ )
# Clean up markdown
- content = re.sub(r'\n{3,}', '\n\n', content)
+ content = re.sub(r"\n{3,}", "\n\n", content)
content = content.strip()
return title, content
@@ -96,14 +103,14 @@ def extract_content(soup: BeautifulSoup) -> tuple[str, str]:
def url_to_filename(url: str) -> str:
"""Convert URL to filename safely, preventing path traversal attacks."""
- path = urlparse(url).path.strip('/')
+ path = urlparse(url).path.strip("/")
if not path:
return "index.md"
# Replace slashes with underscores
- name = path.replace('/', '_')
+ name = path.replace("/", "_")
# Remove any path traversal attempts and dangerous characters
# Only allow alphanumeric, underscore, and hyphen (block backslashes too)
- name = re.sub(r'[^a-zA-Z0-9_-]', '', name)
+ name = re.sub(r"[^a-zA-Z0-9_-]", "", name)
if not name:
return "index.md"
return f"{name}.md"
@@ -138,25 +145,25 @@ def scrape_all():
try:
filepath = safe_join_path(DOCS_DIR, filename)
- with open(filepath, 'w') as f:
+ with open(filepath, "w") as f:
if title:
f.write(f"# {title}\n\n")
f.write(content)
            print(f"  Saved: {filepath}")
- pages[url] = {'title': title, 'file': filename}
+ pages[url] = {"title": title, "file": filename}
except (ValueError, requests.RequestException) as e:
print(f" Skipping {url}: {e}")
continue
# Discover more links
for link in extract_nav_links(soup):
- if link not in visited and urljoin(BASE_URL, link) not in visited:
- if link.startswith('/') and not link.startswith('//'):
- to_visit.append(link)
+ full_url = urljoin(BASE_URL, link)
+ if link not in visited and full_url not in visited and link.startswith("/") and not link.startswith("//"):
+ to_visit.append(link)
# Create index
- with open(safe_join_path(DOCS_DIR, "README.md"), 'w') as f:
+ with open(safe_join_path(DOCS_DIR, "README.md"), "w") as f:
f.write("# Privatemode Documentation\n\n")
f.write("Scraped from https://docs.privatemode.ai\n\n")
f.write("## Pages\n\n")
diff --git a/tests/conftest.py b/tests/conftest.py
index e424cab..5675e47 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -12,22 +12,22 @@
import pytest
# Add auth-proxy to path so imports work
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy"))
# ── Environment setup (must happen before importing app modules) ──
# Provide required env vars before any module-level code runs.
# PBKDF2_SALT must be >= 16 bytes or admin.py raises ValueError at import time.
-os.environ.setdefault('PBKDF2_SALT', 'test-salt-value-1234567890')
-os.environ.setdefault('ADMIN_PASSWORD', 'test-admin-password')
-os.environ.setdefault('PRIVATEMODE_API_KEY', 'pm_test_upstream_key')
-os.environ.setdefault('API_KEYS_FILE', '')
-os.environ.setdefault('SETTINGS_FILE', '')
-os.environ.setdefault('USAGE_FILE', '')
-os.environ.setdefault('UPSTREAM_URL', 'http://localhost:19999')
-os.environ.setdefault('TRUST_PROXY', 'false')
-os.environ.setdefault('FORCE_HTTPS', 'false')
+os.environ.setdefault("PBKDF2_SALT", "test-salt-value-1234567890")
+os.environ.setdefault("ADMIN_PASSWORD", "test-admin-password")
+os.environ.setdefault("PRIVATEMODE_API_KEY", "pm_test_upstream_key")
+os.environ.setdefault("API_KEYS_FILE", "")
+os.environ.setdefault("SETTINGS_FILE", "")
+os.environ.setdefault("USAGE_FILE", "")
+os.environ.setdefault("UPSTREAM_URL", "http://localhost:19999")
+os.environ.setdefault("TRUST_PROXY", "false")
+os.environ.setdefault("FORCE_HTTPS", "false")
from tests.helpers import make_keys_file
@@ -63,15 +63,15 @@ def app_env(keys_file, settings_file, usage_file):
Returns a dict of the env vars set.
"""
env = {
- 'API_KEYS_FILE': keys_file,
- 'SETTINGS_FILE': settings_file,
- 'USAGE_FILE': usage_file,
- 'ADMIN_PASSWORD': 'test-admin-password',
- 'PRIVATEMODE_API_KEY': 'pm_test_upstream_key',
- 'UPSTREAM_URL': 'http://localhost:19999',
- 'TRUST_PROXY': 'false',
- 'FORCE_HTTPS': 'false',
- 'PBKDF2_SALT': 'test-salt-value-1234567890',
+ "API_KEYS_FILE": keys_file,
+ "SETTINGS_FILE": settings_file,
+ "USAGE_FILE": usage_file,
+ "ADMIN_PASSWORD": "test-admin-password",
+ "PRIVATEMODE_API_KEY": "pm_test_upstream_key",
+ "UPSTREAM_URL": "http://localhost:19999",
+ "TRUST_PROXY": "false",
+ "FORCE_HTTPS": "false",
+ "PBKDF2_SALT": "test-salt-value-1234567890",
}
with patch.dict(os.environ, env):
yield env
diff --git a/tests/helpers.py b/tests/helpers.py
index 62fae19..c995964 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -7,7 +7,6 @@
import os
import time
-
# ── Test API key constants ──
TEST_API_KEY = "pm_test-key-for-unit-tests-12345"
diff --git a/tests/test_admin.py b/tests/test_admin.py
index 0582557..7dc5085 100644
--- a/tests/test_admin.py
+++ b/tests/test_admin.py
@@ -9,25 +9,25 @@
import pytest
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy"))
from admin import (
- create_session,
- validate_session,
- delete_session,
+ _csrf_tokens,
+ _login_attempts,
_sessions,
check_login_rate_limit,
- _login_attempts,
+ create_session,
+ delete_session,
+ format_timestamp,
generate_csrf_token,
- validate_csrf_token,
- _csrf_tokens,
+ get_key_status,
load_keys,
- save_keys,
load_settings,
+ save_keys,
save_settings,
- get_key_status,
- format_timestamp,
update_key_rate_limit,
+ validate_csrf_token,
+ validate_session,
)
from tests.helpers import make_keys_file
@@ -145,48 +145,48 @@ class TestKeysCRUD:
def test_load_keys(self, tmp_path):
keys_file = make_keys_file(tmp_path)
- with patch('admin.KEYS_FILE', keys_file):
+ with patch("admin.KEYS_FILE", keys_file):
data = load_keys()
- assert 'keys' in data
- assert len(data['keys']) == 4
+ assert "keys" in data
+ assert len(data["keys"]) == 4
def test_load_missing_file(self, tmp_path):
- with patch('admin.KEYS_FILE', os.path.join(str(tmp_path), 'nope.json')):
+ with patch("admin.KEYS_FILE", os.path.join(str(tmp_path), "nope.json")):
data = load_keys()
assert data == {"keys": []}
def test_save_and_reload(self, tmp_path):
- keys_file = os.path.join(str(tmp_path), 'keys.json')
- with patch('admin.KEYS_FILE', keys_file):
+ keys_file = os.path.join(str(tmp_path), "keys.json")
+ with patch("admin.KEYS_FILE", keys_file):
save_keys({"keys": [{"key_id": "x", "enabled": True}]})
data = load_keys()
- assert len(data['keys']) == 1
- assert data['keys'][0]['key_id'] == 'x'
+ assert len(data["keys"]) == 1
+ assert data["keys"][0]["key_id"] == "x"
def test_update_key_rate_limit(self, tmp_path):
keys_file = make_keys_file(tmp_path)
- with patch('admin.KEYS_FILE', keys_file):
+ with patch("admin.KEYS_FILE", keys_file):
result = update_key_rate_limit("test_key_1", 50)
assert result is True
# Verify it was saved
data = load_keys()
- key = next(k for k in data['keys'] if k['key_id'] == 'test_key_1')
- assert key['rate_limit'] == 50
+ key = next(k for k in data["keys"] if k["key_id"] == "test_key_1")
+ assert key["rate_limit"] == 50
def test_update_key_rate_limit_clear(self, tmp_path):
keys_file = make_keys_file(tmp_path)
- with patch('admin.KEYS_FILE', keys_file):
+ with patch("admin.KEYS_FILE", keys_file):
# key_2 has rate_limit=5
update_key_rate_limit("test_key_2", None)
data = load_keys()
- key = next(k for k in data['keys'] if k['key_id'] == 'test_key_2')
- assert 'rate_limit' not in key
+ key = next(k for k in data["keys"] if k["key_id"] == "test_key_2")
+ assert "rate_limit" not in key
def test_update_nonexistent_key(self, tmp_path):
keys_file = make_keys_file(tmp_path)
- with patch('admin.KEYS_FILE', keys_file):
+ with patch("admin.KEYS_FILE", keys_file):
result = update_key_rate_limit("nonexistent", 10)
assert result is False
@@ -195,39 +195,39 @@ class TestSettings:
"""Test settings load/save."""
def test_load_defaults(self, tmp_path):
- with patch('admin.SETTINGS_FILE', os.path.join(str(tmp_path), 'nope.json')):
+ with patch("admin.SETTINGS_FILE", os.path.join(str(tmp_path), "nope.json")):
settings = load_settings()
- assert 'rate_limit_requests' in settings
- assert 'ip_rate_limit_requests' in settings
+ assert "rate_limit_requests" in settings
+ assert "ip_rate_limit_requests" in settings
def test_save_and_load(self, tmp_path):
- settings_file = os.path.join(str(tmp_path), 'settings.json')
- with patch('admin.SETTINGS_FILE', settings_file):
- save_settings({'rate_limit_requests': 200, 'custom': 'value'})
+ settings_file = os.path.join(str(tmp_path), "settings.json")
+ with patch("admin.SETTINGS_FILE", settings_file):
+ save_settings({"rate_limit_requests": 200, "custom": "value"})
settings = load_settings()
- assert settings['rate_limit_requests'] == 200
- assert settings['custom'] == 'value'
+ assert settings["rate_limit_requests"] == 200
+ assert settings["custom"] == "value"
# Defaults should still be present
- assert 'ip_rate_limit_requests' in settings
+ assert "ip_rate_limit_requests" in settings
class TestHelpers:
"""Test admin helper functions."""
def test_get_key_status_active(self):
- key = {'enabled': True}
+ key = {"enabled": True}
status, css = get_key_status(key)
assert status == "Active"
assert css == "status-active"
def test_get_key_status_revoked(self):
- key = {'enabled': False}
+ key = {"enabled": False}
status, css = get_key_status(key)
assert status == "Revoked"
assert css == "status-revoked"
def test_get_key_status_expired(self):
- key = {'enabled': True, 'expires_at': time.time() - 3600}
+ key = {"enabled": True, "expires_at": time.time() - 3600}
status, css = get_key_status(key)
assert status == "Expired"
assert css == "status-expired"
diff --git a/tests/test_auth.py b/tests/test_auth.py
index b848621..141e1bf 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -9,10 +9,13 @@
import pytest
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy"))
from tests.helpers import (
- TEST_API_KEY, TEST_API_KEY_2, EXPIRED_API_KEY, DISABLED_API_KEY,
+ DISABLED_API_KEY,
+ EXPIRED_API_KEY,
+ TEST_API_KEY,
+ TEST_API_KEY_2,
make_keys_file,
)
@@ -21,54 +24,58 @@ class TestExtractApiKey:
"""Test API key extraction from request headers."""
def test_bearer_token(self):
- from server import extract_api_key
from unittest.mock import MagicMock
+ from server import extract_api_key
+
request = MagicMock()
- request.headers = {'Authorization': f'Bearer {TEST_API_KEY}'}
+ request.headers = {"Authorization": f"Bearer {TEST_API_KEY}"}
assert extract_api_key(request) == TEST_API_KEY
def test_x_api_key_header(self):
- from server import extract_api_key
from unittest.mock import MagicMock
+ from server import extract_api_key
+
request = MagicMock()
request.headers = MagicMock()
- request.headers.get = MagicMock(side_effect=lambda key, default='': (
- '' if key == 'Authorization' else
- TEST_API_KEY if key == 'X-API-Key' else default
- ))
+ request.headers.get = MagicMock(
+ side_effect=lambda key, default="": (
+ "" if key == "Authorization" else TEST_API_KEY if key == "X-API-Key" else default
+ )
+ )
assert extract_api_key(request) == TEST_API_KEY
def test_missing_key(self):
- from server import extract_api_key
from unittest.mock import MagicMock
+ from server import extract_api_key
+
request = MagicMock()
request.headers = MagicMock()
- request.headers.get = MagicMock(side_effect=lambda key, default='': '')
+ request.headers.get = MagicMock(side_effect=lambda key, default="": "")
assert extract_api_key(request) is None
def test_bearer_prefix_only(self):
- from server import extract_api_key
from unittest.mock import MagicMock
+ from server import extract_api_key
+
request = MagicMock()
request.headers = MagicMock()
- request.headers.get = lambda key, default='': 'Bearer ' if key == 'Authorization' else None
+ request.headers.get = lambda key, default="": "Bearer " if key == "Authorization" else None
# "Bearer " with empty key returns empty string
result = extract_api_key(request)
- assert result == ''
+ assert result == ""
def test_non_bearer_auth_header(self):
- from server import extract_api_key
from unittest.mock import MagicMock
+ from server import extract_api_key
+
request = MagicMock()
request.headers = MagicMock()
- request.headers.get = lambda key, default='': (
- 'Basic dXNlcjpwYXNz' if key == 'Authorization' else None
- )
+ request.headers.get = lambda key, default="": "Basic dXNlcjpwYXNz" if key == "Authorization" else None
# Basic auth should not be extracted as API key
assert extract_api_key(request) is None
@@ -103,7 +110,7 @@ def test_validate_expired_key(self, tmp_path):
keys_file = make_keys_file(tmp_path)
km = KeyManager(keys_file)
- valid, key_obj = km.validate_key(EXPIRED_API_KEY)
+ valid, _key_obj = km.validate_key(EXPIRED_API_KEY)
assert valid is False
def test_validate_disabled_key(self, tmp_path):
@@ -112,7 +119,7 @@ def test_validate_disabled_key(self, tmp_path):
keys_file = make_keys_file(tmp_path)
km = KeyManager(keys_file)
- valid, key_obj = km.validate_key(DISABLED_API_KEY)
+ valid, _key_obj = km.validate_key(DISABLED_API_KEY)
assert valid is False
def test_key_with_rate_limit(self, tmp_path):
@@ -133,9 +140,9 @@ def test_get_key_info(self, tmp_path):
info = km.get_key_info(TEST_API_KEY)
assert info is not None
- assert info['key_id'] == "test_key_1"
- assert info['description'] == "Test key 1"
- assert 'key_hash' not in info # Shouldn't leak the hash
+ assert info["key_id"] == "test_key_1"
+ assert info["description"] == "Test key 1"
+ assert "key_hash" not in info # Shouldn't leak the hash
def test_get_key_info_invalid(self, tmp_path):
from key_manager import KeyManager
@@ -148,6 +155,7 @@ def test_get_key_info_invalid(self, tmp_path):
def test_hot_reload(self, tmp_path):
import json
+
from key_manager import KeyManager
keys_file = make_keys_file(tmp_path)
@@ -229,10 +237,10 @@ def test_create_key_entry_hash_only(self):
from key_manager import create_key_entry
entry = create_key_entry("pm_test", description="Test", store_hash_only=True)
- assert 'key_hash' in entry
- assert 'key' not in entry
- assert entry['description'] == "Test"
- assert entry['enabled'] is True
+ assert "key_hash" in entry
+ assert "key" not in entry
+ assert entry["description"] == "Test"
+ assert entry["enabled"] is True
def test_create_key_entry_with_expiry(self):
from key_manager import create_key_entry
@@ -240,5 +248,5 @@ def test_create_key_entry_with_expiry(self):
before = time.time()
entry = create_key_entry("pm_test", expires_in_days=30)
expected_expiry = before + (30 * 86400)
- assert entry['expires_at'] >= expected_expiry - 1
- assert entry['expires_at'] <= expected_expiry + 2
+ assert entry["expires_at"] >= expected_expiry - 1
+ assert entry["expires_at"] <= expected_expiry + 2
diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py
index ea647c2..1ecad65 100644
--- a/tests/test_endpoints.py
+++ b/tests/test_endpoints.py
@@ -11,7 +11,7 @@
import pytest
from aiohttp import web
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy"))
from tests.helpers import TEST_API_KEY, make_keys_file
@@ -24,24 +24,27 @@ def proxy_app(tmp_path):
usage_file = os.path.join(str(tmp_path), "usage.json")
env = {
- 'API_KEYS_FILE': keys_file,
- 'SETTINGS_FILE': settings_file,
- 'USAGE_FILE': usage_file,
- 'ADMIN_PASSWORD': 'test-admin-password',
- 'PRIVATEMODE_API_KEY': 'pm_test_upstream_key',
- 'UPSTREAM_URL': 'http://localhost:19999',
- 'TRUST_PROXY': 'false',
- 'FORCE_HTTPS': 'false',
- 'PBKDF2_SALT': 'test-salt-value-1234567890',
+ "API_KEYS_FILE": keys_file,
+ "SETTINGS_FILE": settings_file,
+ "USAGE_FILE": usage_file,
+ "ADMIN_PASSWORD": "test-admin-password",
+ "PRIVATEMODE_API_KEY": "pm_test_upstream_key",
+ "UPSTREAM_URL": "http://localhost:19999",
+ "TRUST_PROXY": "false",
+ "FORCE_HTTPS": "false",
+ "PBKDF2_SALT": "test-salt-value-1234567890",
}
with patch.dict(os.environ, env):
# Re-import to pick up new config values
import importlib
+
import config
+
importlib.reload(config)
# Need to reload server module too since it imports config at module level
import server
+
importlib.reload(server)
application = server.create_app()
@@ -59,15 +62,15 @@ class TestHealthEndpoint:
@pytest.mark.asyncio
async def test_health_returns_ok(self, client):
- resp = await client.get('/health')
+ resp = await client.get("/health")
assert resp.status == 200
data = await resp.json()
- assert data['status'] == 'healthy'
+ assert data["status"] == "healthy"
@pytest.mark.asyncio
async def test_health_no_auth_required(self, client):
# No auth headers at all
- resp = await client.get('/health')
+ resp = await client.get("/health")
assert resp.status == 200
@@ -76,30 +79,31 @@ class TestAuthMiddleware:
@pytest.mark.asyncio
async def test_missing_api_key(self, client):
- resp = await client.post('/v1/chat/completions',
- json={"model": "test", "messages": []})
+ resp = await client.post("/v1/chat/completions", json={"model": "test", "messages": []})
assert resp.status == 401
data = await resp.json()
- assert 'Missing API key' in data['error']
+ assert "Missing API key" in data["error"]
@pytest.mark.asyncio
async def test_invalid_api_key(self, client):
- resp = await client.post('/v1/chat/completions',
- headers={'Authorization': 'Bearer pm_invalid_key'},
- json={"model": "test", "messages": []})
+ resp = await client.post(
+ "/v1/chat/completions",
+ headers={"Authorization": "Bearer pm_invalid_key"},
+ json={"model": "test", "messages": []},
+ )
assert resp.status == 401
data = await resp.json()
- assert 'Invalid' in data['error']
+ assert "Invalid" in data["error"]
@pytest.mark.asyncio
async def test_valid_api_key_bearer(self, client):
# This will try to proxy to upstream (which doesn't exist in test)
# but should get past auth
- with patch('server.proxy_request', new_callable=AsyncMock) as mock_proxy:
+ with patch("server.proxy_request", new_callable=AsyncMock) as mock_proxy:
mock_proxy.return_value = web.json_response({"ok": True})
resp = await client.post(
- '/v1/chat/completions',
- headers={'Authorization': f'Bearer {TEST_API_KEY}'},
+ "/v1/chat/completions",
+ headers={"Authorization": f"Bearer {TEST_API_KEY}"},
json={"model": "test", "messages": []},
)
# Should get past auth (200 from mocked proxy or 502 if proxy fails)
@@ -107,11 +111,11 @@ async def test_valid_api_key_bearer(self, client):
@pytest.mark.asyncio
async def test_valid_api_key_x_header(self, client):
- with patch('server.proxy_request', new_callable=AsyncMock) as mock_proxy:
+ with patch("server.proxy_request", new_callable=AsyncMock) as mock_proxy:
mock_proxy.return_value = web.json_response({"ok": True})
resp = await client.post(
- '/v1/chat/completions',
- headers={'X-API-Key': TEST_API_KEY},
+ "/v1/chat/completions",
+ headers={"X-API-Key": TEST_API_KEY},
json={"model": "test", "messages": []},
)
assert resp.status in (200, 502)
@@ -122,21 +126,19 @@ class TestKeyInfoEndpoint:
@pytest.mark.asyncio
async def test_key_info_valid(self, client):
- resp = await client.get('/auth/key-info',
- headers={'Authorization': f'Bearer {TEST_API_KEY}'})
+ resp = await client.get("/auth/key-info", headers={"Authorization": f"Bearer {TEST_API_KEY}"})
assert resp.status == 200
data = await resp.json()
- assert data['key_id'] == 'test_key_1'
+ assert data["key_id"] == "test_key_1"
@pytest.mark.asyncio
async def test_key_info_invalid(self, client):
- resp = await client.get('/auth/key-info',
- headers={'Authorization': 'Bearer pm_bad'})
+ resp = await client.get("/auth/key-info", headers={"Authorization": "Bearer pm_bad"})
assert resp.status == 401
@pytest.mark.asyncio
async def test_key_info_no_auth(self, client):
- resp = await client.get('/auth/key-info')
+ resp = await client.get("/auth/key-info")
assert resp.status == 401
@@ -145,19 +147,19 @@ class TestSecurityHeaders:
@pytest.mark.asyncio
async def test_health_has_security_headers(self, client):
- resp = await client.get('/health')
- assert resp.headers.get('X-Content-Type-Options') == 'nosniff'
- assert resp.headers.get('X-Frame-Options') == 'DENY'
- assert resp.headers.get('X-XSS-Protection') == '1; mode=block'
- assert 'Referrer-Policy' in resp.headers
- assert 'Permissions-Policy' in resp.headers
+ resp = await client.get("/health")
+ assert resp.headers.get("X-Content-Type-Options") == "nosniff"
+ assert resp.headers.get("X-Frame-Options") == "DENY"
+ assert resp.headers.get("X-XSS-Protection") == "1; mode=block"
+ assert "Referrer-Policy" in resp.headers
+ assert "Permissions-Policy" in resp.headers
@pytest.mark.asyncio
async def test_admin_has_csp(self, client):
# Admin login page should have CSP
- resp = await client.get('/admin/login')
+ resp = await client.get("/admin/login")
assert resp.status == 200
- csp = resp.headers.get('Content-Security-Policy', '')
+ csp = resp.headers.get("Content-Security-Policy", "")
assert "frame-ancestors 'none'" in csp
@@ -166,51 +168,61 @@ class TestAdminEndpoints:
@pytest.mark.asyncio
async def test_admin_redirects_to_login(self, client):
- resp = await client.get('/admin', allow_redirects=False)
+ resp = await client.get("/admin", allow_redirects=False)
assert resp.status == 302
- assert '/admin/login' in resp.headers['Location']
+ assert "/admin/login" in resp.headers["Location"]
@pytest.mark.asyncio
async def test_login_page_loads(self, client):
- resp = await client.get('/admin/login')
+ resp = await client.get("/admin/login")
assert resp.status == 200
text = await resp.text()
- assert 'password' in text.lower()
- assert 'csrf_token' in text
+ assert "password" in text.lower()
+ assert "csrf_token" in text
@pytest.mark.asyncio
async def test_login_wrong_password(self, client):
# First get a CSRF token
- resp = await client.get('/admin/login')
+ resp = await client.get("/admin/login")
text = await resp.text()
# Extract CSRF token from form
import re
+
match = re.search(r'name="csrf_token" value="([^"]+)"', text)
assert match, "CSRF token not found in login form"
csrf_token = match.group(1)
- resp = await client.post('/admin/login', data={
- 'password': 'wrong',
- 'csrf_token': csrf_token,
- }, allow_redirects=False)
+ resp = await client.post(
+ "/admin/login",
+ data={
+ "password": "wrong",
+ "csrf_token": csrf_token,
+ },
+ allow_redirects=False,
+ )
assert resp.status == 302
- assert 'error' in resp.headers.get('Location', '')
+ assert "error" in resp.headers.get("Location", "")
@pytest.mark.asyncio
async def test_login_success(self, client):
# Get CSRF token
- resp = await client.get('/admin/login')
+ resp = await client.get("/admin/login")
text = await resp.text()
import re
+
match = re.search(r'name="csrf_token" value="([^"]+)"', text)
csrf_token = match.group(1)
- resp = await client.post('/admin/login', data={
- 'password': 'test-admin-password',
- 'csrf_token': csrf_token,
- }, allow_redirects=False)
+ resp = await client.post(
+ "/admin/login",
+ data={
+ "password": "test-admin-password",
+ "csrf_token": csrf_token,
+ },
+ allow_redirects=False,
+ )
assert resp.status == 302
- assert resp.headers.get('Location') == '/admin'
+ assert resp.headers.get("Location") == "/admin"
# Should have session cookie
- cookies = resp.cookies
- assert 'admin_session' in {c.key for c in resp.cookies.values()} or resp.headers.get('Set-Cookie', '')
+
+ assert "admin_session" in {c.key for c in resp.cookies.values()} or resp.headers.get("Set-Cookie", "")
diff --git a/tests/test_proxy.py b/tests/test_proxy.py
index 0bfeb32..ccf85b6 100644
--- a/tests/test_proxy.py
+++ b/tests/test_proxy.py
@@ -6,8 +6,7 @@
import os
import sys
-
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy"))
from server import (
detect_endpoint_type,
@@ -68,41 +67,37 @@ def test_chat_completion_response(self):
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30,
- }
+ },
}
body = json.dumps(response).encode()
usage = extract_usage_from_response(body, "chat")
- assert usage['model'] == "gpt-oss-120b"
- assert usage['prompt_tokens'] == 10
- assert usage['completion_tokens'] == 20
- assert usage['total_tokens'] == 30
+ assert usage["model"] == "gpt-oss-120b"
+ assert usage["prompt_tokens"] == 10
+ assert usage["completion_tokens"] == 20
+ assert usage["total_tokens"] == 30
def test_embedding_response(self):
- response = {
- "model": "qwen3-embedding-4b",
- "data": [{"embedding": [0.1, 0.2]}],
- "usage": {"total_tokens": 50}
- }
+ response = {"model": "qwen3-embedding-4b", "data": [{"embedding": [0.1, 0.2]}], "usage": {"total_tokens": 50}}
body = json.dumps(response).encode()
usage = extract_usage_from_response(body, "embeddings")
- assert usage['model'] == "qwen3-embedding-4b"
- assert usage['total_tokens'] == 50
+ assert usage["model"] == "qwen3-embedding-4b"
+ assert usage["total_tokens"] == 50
def test_no_usage_field(self):
response = {"model": "gpt-oss-120b", "choices": []}
body = json.dumps(response).encode()
usage = extract_usage_from_response(body, "chat")
- assert usage['prompt_tokens'] == 0
- assert usage['total_tokens'] == 0
+ assert usage["prompt_tokens"] == 0
+ assert usage["total_tokens"] == 0
def test_invalid_response_body(self):
usage = extract_usage_from_response(b"not json", "chat")
- assert usage['model'] == 'unknown'
- assert usage['total_tokens'] == 0
+ assert usage["model"] == "unknown"
+ assert usage["total_tokens"] == 0
def test_empty_response(self):
usage = extract_usage_from_response(b"", "chat")
- assert usage['model'] == 'unknown'
+ assert usage["model"] == "unknown"
diff --git a/tests/test_rate_limiting.py b/tests/test_rate_limiting.py
index 98320a9..e4e7950 100644
--- a/tests/test_rate_limiting.py
+++ b/tests/test_rate_limiting.py
@@ -9,15 +9,15 @@
import pytest
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy"))
from server import (
check_global_rate_limit,
- check_per_key_rate_limit,
check_ip_rate_limit,
+ check_per_key_rate_limit,
global_rate_limit_store,
- rate_limit_store,
ip_rate_limit_store,
+ rate_limit_store,
)
@@ -37,12 +37,15 @@ class TestGlobalRateLimit:
"""Test global rate limiting (shared across all keys)."""
def test_allows_within_limit(self):
- with patch('server.get_rate_limit_settings', return_value={
- 'rate_limit_requests': 10,
- 'rate_limit_window': 60,
- 'ip_rate_limit_requests': 1000,
- 'ip_rate_limit_window': 60,
- }):
+ with patch(
+ "server.get_rate_limit_settings",
+ return_value={
+ "rate_limit_requests": 10,
+ "rate_limit_window": 60,
+ "ip_rate_limit_requests": 1000,
+ "ip_rate_limit_window": 60,
+ },
+ ):
allowed, remaining, limit, window = check_global_rate_limit()
assert allowed is True
assert remaining == 9 # limit(10) - count(0) - 1 = 9, then append
@@ -50,41 +53,48 @@ def test_allows_within_limit(self):
assert window == 60
def test_blocks_at_limit(self):
- with patch('server.get_rate_limit_settings', return_value={
- 'rate_limit_requests': 3,
- 'rate_limit_window': 60,
- 'ip_rate_limit_requests': 1000,
- 'ip_rate_limit_window': 60,
- }):
+ with patch(
+ "server.get_rate_limit_settings",
+ return_value={
+ "rate_limit_requests": 3,
+ "rate_limit_window": 60,
+ "ip_rate_limit_requests": 1000,
+ "ip_rate_limit_window": 60,
+ },
+ ):
# Use up the limit
check_global_rate_limit() # 1
check_global_rate_limit() # 2
check_global_rate_limit() # 3
# Should be blocked now
- allowed, remaining, limit, window = check_global_rate_limit()
+ allowed, remaining, _limit, _window = check_global_rate_limit()
assert allowed is False
assert remaining == 0
def test_cleans_old_entries(self):
# The global store is reassigned in the function, so we need to test
# by calling through the function itself
- with patch('server.get_rate_limit_settings', return_value={
- 'rate_limit_requests': 3,
- 'rate_limit_window': 60,
- 'ip_rate_limit_requests': 1000,
- 'ip_rate_limit_window': 60,
- }):
+ with patch(
+ "server.get_rate_limit_settings",
+ return_value={
+ "rate_limit_requests": 3,
+ "rate_limit_window": 60,
+ "ip_rate_limit_requests": 1000,
+ "ip_rate_limit_window": 60,
+ },
+ ):
# Use up 2 of 3
check_global_rate_limit()
check_global_rate_limit()
# Manually age the entries by modifying the store
import server
+
server.global_rate_limit_store[:] = [time.time() - 120, time.time() - 120]
# Should be allowed again since old entries get cleaned
- allowed, remaining, limit, window = check_global_rate_limit()
+ allowed, _remaining, _limit, _window = check_global_rate_limit()
assert allowed is True
@@ -93,44 +103,53 @@ class TestPerKeyRateLimit:
def test_no_limit_set(self):
"""Keys without a limit should always be allowed."""
- allowed, remaining, limit = check_per_key_rate_limit("key1", None)
+ allowed, remaining, _limit = check_per_key_rate_limit("key1", None)
assert allowed is True
assert remaining == -1 # Indicates no limit
def test_allows_within_limit(self):
- with patch('server.get_rate_limit_settings', return_value={
- 'rate_limit_requests': 100,
- 'rate_limit_window': 60,
- 'ip_rate_limit_requests': 1000,
- 'ip_rate_limit_window': 60,
- }):
+ with patch(
+ "server.get_rate_limit_settings",
+ return_value={
+ "rate_limit_requests": 100,
+ "rate_limit_window": 60,
+ "ip_rate_limit_requests": 1000,
+ "ip_rate_limit_window": 60,
+ },
+ ):
allowed, remaining, limit = check_per_key_rate_limit("key1", 5)
assert allowed is True
assert remaining == 4 # limit(5) - count(0) - 1 = 4
assert limit == 5
def test_blocks_at_limit(self):
- with patch('server.get_rate_limit_settings', return_value={
- 'rate_limit_requests': 100,
- 'rate_limit_window': 60,
- 'ip_rate_limit_requests': 1000,
- 'ip_rate_limit_window': 60,
- }):
+ with patch(
+ "server.get_rate_limit_settings",
+ return_value={
+ "rate_limit_requests": 100,
+ "rate_limit_window": 60,
+ "ip_rate_limit_requests": 1000,
+ "ip_rate_limit_window": 60,
+ },
+ ):
for _ in range(5):
check_per_key_rate_limit("key1", 5)
- allowed, remaining, limit = check_per_key_rate_limit("key1", 5)
+ allowed, remaining, _limit = check_per_key_rate_limit("key1", 5)
assert allowed is False
assert remaining == 0
def test_separate_key_stores(self):
"""Rate limits should be independent per key."""
- with patch('server.get_rate_limit_settings', return_value={
- 'rate_limit_requests': 100,
- 'rate_limit_window': 60,
- 'ip_rate_limit_requests': 1000,
- 'ip_rate_limit_window': 60,
- }):
+ with patch(
+ "server.get_rate_limit_settings",
+ return_value={
+ "rate_limit_requests": 100,
+ "rate_limit_window": 60,
+ "ip_rate_limit_requests": 1000,
+ "ip_rate_limit_window": 60,
+ },
+ ):
# Exhaust key1
for _ in range(3):
check_per_key_rate_limit("key1", 3)
@@ -146,38 +165,47 @@ class TestIPRateLimit:
"""Test IP-based rate limiting."""
def test_allows_within_limit(self):
- with patch('server.get_rate_limit_settings', return_value={
- 'rate_limit_requests': 100,
- 'rate_limit_window': 60,
- 'ip_rate_limit_requests': 10,
- 'ip_rate_limit_window': 60,
- }):
- allowed, remaining, limit, window = check_ip_rate_limit("192.168.1.1")
+ with patch(
+ "server.get_rate_limit_settings",
+ return_value={
+ "rate_limit_requests": 100,
+ "rate_limit_window": 60,
+ "ip_rate_limit_requests": 10,
+ "ip_rate_limit_window": 60,
+ },
+ ):
+ allowed, _remaining, limit, _window = check_ip_rate_limit("192.168.1.1")
assert allowed is True
assert limit == 10
def test_blocks_at_limit(self):
- with patch('server.get_rate_limit_settings', return_value={
- 'rate_limit_requests': 100,
- 'rate_limit_window': 60,
- 'ip_rate_limit_requests': 3,
- 'ip_rate_limit_window': 60,
- }):
+ with patch(
+ "server.get_rate_limit_settings",
+ return_value={
+ "rate_limit_requests": 100,
+ "rate_limit_window": 60,
+ "ip_rate_limit_requests": 3,
+ "ip_rate_limit_window": 60,
+ },
+ ):
for _ in range(3):
check_ip_rate_limit("192.168.1.1")
- allowed, remaining, limit, window = check_ip_rate_limit("192.168.1.1")
+ allowed, remaining, _limit, _window = check_ip_rate_limit("192.168.1.1")
assert allowed is False
assert remaining == 0
def test_separate_ip_stores(self):
"""Different IPs should have separate rate limit buckets."""
- with patch('server.get_rate_limit_settings', return_value={
- 'rate_limit_requests': 100,
- 'rate_limit_window': 60,
- 'ip_rate_limit_requests': 2,
- 'ip_rate_limit_window': 60,
- }):
+ with patch(
+ "server.get_rate_limit_settings",
+ return_value={
+ "rate_limit_requests": 100,
+ "rate_limit_window": 60,
+ "ip_rate_limit_requests": 2,
+ "ip_rate_limit_window": 60,
+ },
+ ):
# Exhaust IP 1
check_ip_rate_limit("10.0.0.1")
check_ip_rate_limit("10.0.0.1")
diff --git a/tests/test_usage_tracker.py b/tests/test_usage_tracker.py
index eba79ff..8246a28 100644
--- a/tests/test_usage_tracker.py
+++ b/tests/test_usage_tracker.py
@@ -8,7 +8,7 @@
import pytest
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy"))
from usage_tracker import UsageTracker, get_time_range
@@ -65,8 +65,8 @@ def test_record_basic_usage(self, tmp_path):
)
summary = tracker.get_usage_summary()
- assert summary['total_tokens'] == 30
- assert summary['requests'] == 1
+ assert summary["total_tokens"] == 30
+ assert summary["requests"] == 1
def test_auto_calculates_total_tokens(self, tmp_path):
usage_file = os.path.join(str(tmp_path), "usage.json")
@@ -82,38 +82,42 @@ def test_auto_calculates_total_tokens(self, tmp_path):
)
summary = tracker.get_usage_summary()
- assert summary['total_tokens'] == 30
+ assert summary["total_tokens"] == 30
def test_persistence(self, tmp_path):
usage_file = os.path.join(str(tmp_path), "usage.json")
tracker1 = UsageTracker(usage_file)
# Record enough to trigger auto-save (every 10 records)
- for i in range(10):
+ for _i in range(10):
tracker1.record_usage(
- key_id="key1", model="gpt-oss-120b", endpoint="chat",
+ key_id="key1",
+ model="gpt-oss-120b",
+ endpoint="chat",
total_tokens=100,
)
# Create a new tracker from the same file
tracker2 = UsageTracker(usage_file)
summary = tracker2.get_usage_summary()
- assert summary['requests'] == 10
- assert summary['total_tokens'] == 1000
+ assert summary["requests"] == 10
+ assert summary["total_tokens"] == 1000
def test_flush(self, tmp_path):
usage_file = os.path.join(str(tmp_path), "usage.json")
tracker = UsageTracker(usage_file)
tracker.record_usage(
- key_id="key1", model="gpt-oss-120b", endpoint="chat",
+ key_id="key1",
+ model="gpt-oss-120b",
+ endpoint="chat",
total_tokens=100,
)
tracker.flush()
# Verify file exists and has data
tracker2 = UsageTracker(usage_file)
- assert tracker2.get_usage_summary()['requests'] == 1
+ assert tracker2.get_usage_summary()["requests"] == 1
class TestUsageSummary:
@@ -127,8 +131,8 @@ def test_filter_by_key(self, tmp_path):
tracker.record_usage(key_id="key2", model="gpt-oss-120b", endpoint="chat", total_tokens=200)
summary = tracker.get_usage_summary(key_id="key1")
- assert summary['total_tokens'] == 100
- assert summary['requests'] == 1
+ assert summary["total_tokens"] == 100
+ assert summary["requests"] == 1
def test_filter_by_time(self, tmp_path):
usage_file = os.path.join(str(tmp_path), "usage.json")
@@ -140,7 +144,7 @@ def test_filter_by_time(self, tmp_path):
# Filter to future time (should exclude current records)
future = time.time() + 3600
summary = tracker.get_usage_summary(start_time=future)
- assert summary['requests'] == 0
+ assert summary["requests"] == 0
def test_by_model_breakdown(self, tmp_path):
usage_file = os.path.join(str(tmp_path), "usage.json")
@@ -150,10 +154,10 @@ def test_by_model_breakdown(self, tmp_path):
tracker.record_usage(key_id="key1", model="gemma-3-27b", endpoint="chat", total_tokens=200)
summary = tracker.get_usage_summary()
- assert "gpt-oss-120b" in summary['by_model']
- assert "gemma-3-27b" in summary['by_model']
- assert summary['by_model']['gpt-oss-120b']['tokens'] == 100
- assert summary['by_model']['gemma-3-27b']['tokens'] == 200
+ assert "gpt-oss-120b" in summary["by_model"]
+ assert "gemma-3-27b" in summary["by_model"]
+ assert summary["by_model"]["gpt-oss-120b"]["tokens"] == 100
+ assert summary["by_model"]["gemma-3-27b"]["tokens"] == 200
def test_by_endpoint_breakdown(self, tmp_path):
usage_file = os.path.join(str(tmp_path), "usage.json")
@@ -163,8 +167,8 @@ def test_by_endpoint_breakdown(self, tmp_path):
tracker.record_usage(key_id="key1", model="qwen3-embedding-4b", endpoint="embeddings", total_tokens=50)
summary = tracker.get_usage_summary()
- assert summary['by_endpoint']['chat']['requests'] == 1
- assert summary['by_endpoint']['embeddings']['requests'] == 1
+ assert summary["by_endpoint"]["chat"]["requests"] == 1
+ assert summary["by_endpoint"]["embeddings"]["requests"] == 1
def test_usage_by_key(self, tmp_path):
usage_file = os.path.join(str(tmp_path), "usage.json")
@@ -175,9 +179,9 @@ def test_usage_by_key(self, tmp_path):
tracker.record_usage(key_id="key2", model="gpt-oss-120b", endpoint="chat", total_tokens=50)
by_key = tracker.get_usage_by_key()
- assert by_key['key1']['tokens'] == 300
- assert by_key['key1']['requests'] == 2
- assert by_key['key2']['tokens'] == 50
+ assert by_key["key1"]["tokens"] == 300
+ assert by_key["key1"]["requests"] == 2
+ assert by_key["key2"]["tokens"] == 50
def test_daily_breakdown(self, tmp_path):
usage_file = os.path.join(str(tmp_path), "usage.json")
@@ -187,7 +191,7 @@ def test_daily_breakdown(self, tmp_path):
daily = tracker.get_daily_breakdown(days=1)
assert len(daily) >= 1
- assert daily[0]['tokens'] == 100
+ assert daily[0]["tokens"] == 100
class TestTimeRanges:
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 928d5f9..5b76321 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -6,8 +6,7 @@
import sys
from unittest.mock import MagicMock, patch
-
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'auth-proxy'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "auth-proxy"))
from utils import get_client_ip
@@ -19,60 +18,52 @@ def test_direct_connection(self):
request = MagicMock()
request.remote = "192.168.1.1"
request.headers = MagicMock()
- request.headers.get = lambda key, default='': ''
+ request.headers.get = lambda key, default="": ""
- with patch('utils.TRUST_PROXY', False):
+ with patch("utils.TRUST_PROXY", False):
assert get_client_ip(request) == "192.168.1.1"
def test_ignores_forwarded_when_untrusted(self):
request = MagicMock()
request.remote = "10.0.0.1"
request.headers = MagicMock()
- request.headers.get = lambda key, default='': (
- "1.2.3.4, 5.6.7.8" if key == 'X-Forwarded-For' else ''
- )
+ request.headers.get = lambda key, default="": "1.2.3.4, 5.6.7.8" if key == "X-Forwarded-For" else ""
- with patch('utils.TRUST_PROXY', False):
+ with patch("utils.TRUST_PROXY", False):
assert get_client_ip(request) == "10.0.0.1"
def test_uses_forwarded_when_trusted(self):
request = MagicMock()
request.remote = "10.0.0.1"
request.headers = MagicMock()
- request.headers.get = lambda key, default='': (
- "1.2.3.4, 5.6.7.8" if key == 'X-Forwarded-For' else ''
- )
+ request.headers.get = lambda key, default="": "1.2.3.4, 5.6.7.8" if key == "X-Forwarded-For" else ""
- with patch('utils.TRUST_PROXY', True):
+ with patch("utils.TRUST_PROXY", True):
assert get_client_ip(request) == "1.2.3.4"
def test_single_forwarded_ip(self):
request = MagicMock()
request.remote = "10.0.0.1"
request.headers = MagicMock()
- request.headers.get = lambda key, default='': (
- "203.0.113.50" if key == 'X-Forwarded-For' else ''
- )
+ request.headers.get = lambda key, default="": "203.0.113.50" if key == "X-Forwarded-For" else ""
- with patch('utils.TRUST_PROXY', True):
+ with patch("utils.TRUST_PROXY", True):
assert get_client_ip(request) == "203.0.113.50"
def test_no_remote_returns_unknown(self):
request = MagicMock()
request.remote = None
request.headers = MagicMock()
- request.headers.get = lambda key, default='': ''
+ request.headers.get = lambda key, default="": ""
- with patch('utils.TRUST_PROXY', False):
+ with patch("utils.TRUST_PROXY", False):
assert get_client_ip(request) == "unknown"
def test_forwarded_with_spaces(self):
request = MagicMock()
request.remote = "10.0.0.1"
request.headers = MagicMock()
- request.headers.get = lambda key, default='': (
- " 1.2.3.4 , 5.6.7.8 " if key == 'X-Forwarded-For' else ''
- )
+ request.headers.get = lambda key, default="": " 1.2.3.4 , 5.6.7.8 " if key == "X-Forwarded-For" else ""
- with patch('utils.TRUST_PROXY', True):
+ with patch("utils.TRUST_PROXY", True):
assert get_client_ip(request) == "1.2.3.4"