Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions src/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import json
import os
import re
import time
import traceback
from types import SimpleNamespace
from typing import Any, Dict
Expand All @@ -52,6 +53,10 @@

_SENTRY_INITIALIZED = False
_SENTRY_DSN: str = ""
_AUTH_RATE_LIMIT_STATE: Dict[str, Dict[str, int]] = {}
_AUTH_RATE_LIMIT_WINDOW_SECONDS = 60
_AUTH_RATE_LIMIT_MAX_ATTEMPTS = 5
_AUTH_RATE_LIMIT_MAX_KEYS = 10000


def init_sentry(env):
Expand Down Expand Up @@ -418,6 +423,74 @@ def _unauthorized_basic(realm: str = "Alpha One Labs Admin"):
)


def _too_many_requests(retry_after: int):
headers = {"Content-Type": "application/json", **_CORS, "Retry-After": str(retry_after)}
return Response(json.dumps({"error": "Too many requests"}), status=429, headers=headers)


def _auth_rate_limit_env_value(env, name: str, default: int) -> int:
try:
env_dict = getattr(env, "__dict__", None)
if isinstance(env_dict, dict) and name in env_dict:
return int(env_dict[name])
except Exception:
pass

try:
return int(getattr(env, name))
except Exception:
return default


def _check_auth_rate_limit(req, env, route: str):
window_seconds = max(1, _auth_rate_limit_env_value(
env, "AUTH_RATE_LIMIT_WINDOW_SECONDS", _AUTH_RATE_LIMIT_WINDOW_SECONDS
))
max_attempts = max(1, _auth_rate_limit_env_value(
env, "AUTH_RATE_LIMIT_MAX_ATTEMPTS", _AUTH_RATE_LIMIT_MAX_ATTEMPTS
))
max_keys = max(1, _auth_rate_limit_env_value(
env, "AUTH_RATE_LIMIT_MAX_KEYS", _AUTH_RATE_LIMIT_MAX_KEYS
))

# Only CF-Connecting-IP is trusted in Cloudflare Workers.
client_ip = (req.headers.get("CF-Connecting-IP") or "").strip()
if not client_ip:
print(json.dumps({"level": "warn", "where": "auth_rate_limit", "error": "missing_cf_connecting_ip"}))
return _too_many_requests(1)

Comment thread
Flashl3opard marked this conversation as resolved.
key = f"{route}:{client_ip}"
now = int(time.time())

# Keep in-memory state bounded by pruning expired entries and evicting oldest keys.
stale_before = now - window_seconds
for stale_key, stale_state in list(_AUTH_RATE_LIMIT_STATE.items()):
if int(stale_state.get("window_start", 0)) < stale_before:
_AUTH_RATE_LIMIT_STATE.pop(stale_key, None)

if len(_AUTH_RATE_LIMIT_STATE) > max_keys:
overflow = len(_AUTH_RATE_LIMIT_STATE) - max_keys
oldest_keys = sorted(
_AUTH_RATE_LIMIT_STATE.items(),
key=lambda item: int(item[1].get("window_start", 0)),
)[:overflow]
for oldest_key, _ in oldest_keys:
_AUTH_RATE_LIMIT_STATE.pop(oldest_key, None)

state = _AUTH_RATE_LIMIT_STATE.get(key)
if not state or now - int(state.get("window_start", 0)) >= window_seconds:
_AUTH_RATE_LIMIT_STATE[key] = {"window_start": now, "count": 1}
Comment thread
Flashl3opard marked this conversation as resolved.
return None

count = int(state.get("count", 0))
if count >= max_attempts:
retry_after = max(1, window_seconds - (now - int(state.get("window_start", now))))
return _too_many_requests(retry_after)

state["count"] = count + 1
return None


def _is_basic_auth_valid(req, env) -> bool:
username = (getattr(env, "ADMIN_BASIC_USER", "") or "").strip()
password = (getattr(env, "ADMIN_BASIC_PASS", "") or "").strip()
Expand Down Expand Up @@ -742,6 +815,10 @@ async def seed_db(env, enc_key: str):
# ---------------------------------------------------------------------------

async def api_register(req, env):
rate_limit_resp = _check_auth_rate_limit(req, env, "register")
if rate_limit_resp:
return rate_limit_resp

body, bad_resp = await parse_json_object(req)
if bad_resp:
return bad_resp
Expand Down Expand Up @@ -790,6 +867,10 @@ async def api_register(req, env):


async def api_login(req, env):
rate_limit_resp = _check_auth_rate_limit(req, env, "login")
if rate_limit_resp:
return rate_limit_resp

body, bad_resp = await parse_json_object(req)
if bad_resp:
return bad_resp
Expand Down
7 changes: 6 additions & 1 deletion tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def make_env(db=None, enc_key="test-encryption-key", jwt_secret="test-jwt-secret
env = MagicMock()
env.ENCRYPTION_KEY = enc_key
env.JWT_SECRET = jwt_secret
env.AUTH_RATE_LIMIT_WINDOW_SECONDS = 60
env.AUTH_RATE_LIMIT_MAX_ATTEMPTS = 5
env.ADMIN_BASIC_USER = admin_user
env.ADMIN_BASIC_PASS = admin_pass
env.ADMIN_URL = admin_url
Expand Down Expand Up @@ -157,7 +159,10 @@ def basic_auth_header(user: str, password: str) -> str:
# ---------------------------------------------------------------------------

def json_request(path: str, payload: dict, headers=None, method="POST") -> MockRequest:
h = {"Content-Type": "application/json"}
h = {
"Content-Type": "application/json",
"CF-Connecting-IP": "127.0.0.1",
}
if headers:
h.update(headers)
return MockRequest(method=method, url=f"http://localhost{path}", headers=h,
Expand Down
120 changes: 117 additions & 3 deletions tests/test_api_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import base64
import json
import pytest
from tests.helpers import load_worker, MockRequest, MockRow, MockDB, make_env, make_stmt, json_request

worker = load_worker()
Expand All @@ -12,6 +13,13 @@
JWT = "test-jwt-secret"


@pytest.fixture(autouse=True)
def clear_auth_rate_limit_state():
worker._AUTH_RATE_LIMIT_STATE.clear()
yield
worker._AUTH_RATE_LIMIT_STATE.clear()


def _parse(resp):
return json.loads(resp.body)

Expand All @@ -35,7 +43,13 @@ def _enc(val: str) -> str:

class TestApiRegister:
def _req(self, payload):
return json_request("/api/register", payload)
return json_request("/api/register", payload, headers={"CF-Connecting-IP": "127.0.0.1"})

def _rate_limited_env(self):
env = make_env(db=MockDB([make_stmt()]))
env.AUTH_RATE_LIMIT_WINDOW_SECONDS = 60
env.AUTH_RATE_LIMIT_MAX_ATTEMPTS = 2
return env

async def test_missing_username_returns_400(self):
env = make_env()
Expand Down Expand Up @@ -124,18 +138,49 @@ async def test_token_is_verifiable(self):

async def test_invalid_json_returns_400(self):
req = MockRequest(method="POST", url="http://localhost/api/register",
headers={"CF-Connecting-IP": "127.0.0.1"},
body="not-json")
r = await worker.api_register(req, make_env())
assert r.status == 400

async def test_register_is_rate_limited_per_ip(self):
env = self._rate_limited_env()
ip = "203.0.113.10"

first_req = self._req({"username": "alice1", "email": "alice1@example.com", "password": "password123"})
first_req.headers["CF-Connecting-IP"] = ip
second_req = self._req({"username": "alice2", "email": "alice2@example.com", "password": "password123"})
second_req.headers["CF-Connecting-IP"] = ip
third_req = self._req({"username": "alice3", "email": "alice3@example.com", "password": "password123"})
third_req.headers["CF-Connecting-IP"] = ip

first = await worker.api_register(first_req, env)
second = await worker.api_register(second_req, env)
third = await worker.api_register(third_req, env)

assert first.status == 200
assert second.status == 200
assert third.status == 429
Comment thread
Flashl3opard marked this conversation as resolved.
assert "Retry-After" in third.headers
retry_after = third.headers["Retry-After"]
assert retry_after.isdigit()
assert int(retry_after) > 0
assert _parse(third).get("error") == "Too many requests"


# ---------------------------------------------------------------------------
# api_login()
# ---------------------------------------------------------------------------

class TestApiLogin:
def _req(self, payload):
return json_request("/api/login", payload)
return json_request("/api/login", payload, headers={"CF-Connecting-IP": "127.0.0.1"})

def _rate_limited_env(self):
env = make_env(db=MockDB([make_stmt(first=None)]))
env.AUTH_RATE_LIMIT_WINDOW_SECONDS = 60
env.AUTH_RATE_LIMIT_MAX_ATTEMPTS = 2
return env

def _make_user_row(self, username="alice", password="password123", role="member", name="Alice"):
pw_hash = worker.hash_password(password, username)
Expand Down Expand Up @@ -192,6 +237,75 @@ async def test_login_token_is_verifiable(self):
assert payload is not None

async def test_invalid_json_returns_400(self):
req = MockRequest(method="POST", url="http://localhost/api/login", body="bad-json")
req = MockRequest(method="POST", url="http://localhost/api/login", headers={"CF-Connecting-IP": "127.0.0.1"}, body="bad-json")
r = await worker.api_login(req, make_env())
assert r.status == 400

async def test_login_is_rate_limited_per_ip(self):
row = self._make_user_row()
env = self._rate_limited_env()
env.DB = MockDB([
make_stmt(first=row),
make_stmt(first=row),
make_stmt(first=row),
])

req1 = self._req({"username": "alice", "password": "password123"})
req1.headers["CF-Connecting-IP"] = "203.0.113.10"
req2 = self._req({"username": "alice", "password": "password123"})
req2.headers["CF-Connecting-IP"] = "203.0.113.10"
req3 = self._req({"username": "alice", "password": "password123"})
req3.headers["CF-Connecting-IP"] = "203.0.113.10"

first = await worker.api_login(req1, env)
second = await worker.api_login(req2, env)
third = await worker.api_login(req3, env)

Comment thread
Flashl3opard marked this conversation as resolved.
assert first.status == 200
assert second.status == 200
assert third.status == 429
assert "Retry-After" in third.headers
retry_after = third.headers["Retry-After"]
assert retry_after.isdigit()
assert int(retry_after) > 0
assert _parse(third).get("error") == "Too many requests"

async def test_login_rate_limit_resets_after_window(self, monkeypatch):
row = self._make_user_row()
env = self._rate_limited_env()
env.DB = MockDB([
make_stmt(first=row),
make_stmt(first=row),
make_stmt(first=row),
make_stmt(first=row),
])

worker._AUTH_RATE_LIMIT_STATE.clear()
monkeypatch.setattr(worker.time, "time", lambda: 1000)

req1 = self._req({"username": "alice", "password": "password123"})
req1.headers["CF-Connecting-IP"] = "198.51.100.10"
req2 = self._req({"username": "alice", "password": "password123"})
req2.headers["CF-Connecting-IP"] = "198.51.100.10"

assert (await worker.api_login(req1, env)).status == 200
assert (await worker.api_login(req2, env)).status == 200

monkeypatch.setattr(worker.time, "time", lambda: 1000 + 61)
req3 = self._req({"username": "alice", "password": "password123"})
req3.headers["CF-Connecting-IP"] = "198.51.100.10"
assert (await worker.api_login(req3, env)).status == 200

req4 = self._req({"username": "alice", "password": "password123"})
req4.headers["CF-Connecting-IP"] = "198.51.100.10"
assert (await worker.api_login(req4, env)).status == 200

req5 = self._req({"username": "alice", "password": "password123"})
req5.headers["CF-Connecting-IP"] = "198.51.100.10"
limited = await worker.api_login(req5, env)
assert limited.status == 429
assert "Retry-After" in limited.headers
limited_retry_after = limited.headers["Retry-After"]
assert limited_retry_after.isdigit()
assert int(limited_retry_after) > 0
assert _parse(limited).get("error") == "Too many requests"
Loading