# Module-level, in-memory fixed-window counters keyed by "<route>:<client ip>".
# Each value holds the window start (epoch seconds) and the attempts counted
# in that window.
# NOTE(review): this state lives only in this process' memory; presumably each
# Worker instance keeps its own counters, so the limit is best-effort rather
# than global -- confirm against the deployment model.
_AUTH_RATE_LIMIT_STATE: Dict[str, Dict[str, int]] = {}
_AUTH_RATE_LIMIT_WINDOW_SECONDS = 60
_AUTH_RATE_LIMIT_MAX_ATTEMPTS = 5
_AUTH_RATE_LIMIT_MAX_KEYS = 10000


def _too_many_requests(retry_after: int):
    """Build the JSON 429 response for a rate-limited caller.

    Args:
        retry_after: Seconds the client should wait; sent as ``Retry-After``.

    Returns:
        A ``Response`` with status 429 and body ``{"error": "Too many requests"}``.
    """
    headers = {"Content-Type": "application/json", **_CORS, "Retry-After": str(retry_after)}
    return Response(json.dumps({"error": "Too many requests"}), status=429, headers=headers)


def _auth_rate_limit_coerce_int(raw: Any):
    """Return *raw* as an int, or ``None`` if it is not a plain number or numeric string."""
    # Only real numbers and strings count as configuration. Arbitrary objects
    # that merely implement ``__int__`` (auto-created mock attributes convert
    # to 1) must not silently override the default, and neither should bools.
    if isinstance(raw, bool) or not isinstance(raw, (int, float, str)):
        return None
    try:
        return int(raw)
    except (TypeError, ValueError, OverflowError):
        return None


def _auth_rate_limit_env_value(env, name: str, default: int) -> int:
    """Read an integer rate-limit setting from *env*.

    Values stored directly in ``env.__dict__`` are tried first, then regular
    attribute access. Anything missing or not convertible to an int yields
    *default*.

    Args:
        env: Environment/bindings object the settings are read from.
        name: Setting name, e.g. ``"AUTH_RATE_LIMIT_MAX_ATTEMPTS"``.
        default: Value returned when the setting is absent or invalid.

    Returns:
        The configured integer, or *default*.
    """
    # ``env`` may be a foreign proxy object whose attribute access can fail in
    # arbitrary ways, so these lookups stay best-effort.
    raw = None
    try:
        env_dict = getattr(env, "__dict__", None)
        if isinstance(env_dict, dict):
            raw = env_dict.get(name)
    except Exception:
        raw = None
    value = _auth_rate_limit_coerce_int(raw)
    if value is not None:
        return value

    try:
        raw = getattr(env, name)
    except Exception:
        return default
    value = _auth_rate_limit_coerce_int(raw)
    return default if value is None else value


def _auth_rate_limit_make_room(now: int, window_seconds: int, max_keys: int) -> None:
    """Free at least one slot in the counter table before a new key is stored."""
    stale_before = now - window_seconds
    expired = [
        key
        for key, entry in _AUTH_RATE_LIMIT_STATE.items()
        if entry.get("window_start", 0) < stale_before
    ]
    for key in expired:
        del _AUTH_RATE_LIMIT_STATE[key]

    # Still full after dropping expired windows: evict the oldest live windows.
    # This resets those callers' counters, which is the price of bounded memory.
    overflow = len(_AUTH_RATE_LIMIT_STATE) - max_keys + 1
    if overflow > 0:
        oldest_first = sorted(
            _AUTH_RATE_LIMIT_STATE,
            key=lambda key: _AUTH_RATE_LIMIT_STATE[key].get("window_start", 0),
        )
        for key in oldest_first[:overflow]:
            del _AUTH_RATE_LIMIT_STATE[key]


def _check_auth_rate_limit(req, env, route: str):
    """Apply a fixed-window, per-IP attempt limit to an auth route.

    Window length, attempts per window and the maximum number of tracked keys
    come from ``AUTH_RATE_LIMIT_WINDOW_SECONDS``, ``AUTH_RATE_LIMIT_MAX_ATTEMPTS``
    and ``AUTH_RATE_LIMIT_MAX_KEYS`` on *env*, falling back to the module
    defaults; each is clamped to at least 1.

    Args:
        req: Incoming request; only ``req.headers.get`` is used.
        env: Environment/bindings object carrying the optional settings.
        route: Route label used to namespace counters (e.g. ``"login"``).

    Returns:
        ``None`` when the request may proceed, otherwise a 429 response.
        Requests without a ``CF-Connecting-IP`` header are rejected with 429.
    """
    window_seconds = max(1, _auth_rate_limit_env_value(
        env, "AUTH_RATE_LIMIT_WINDOW_SECONDS", _AUTH_RATE_LIMIT_WINDOW_SECONDS
    ))
    max_attempts = max(1, _auth_rate_limit_env_value(
        env, "AUTH_RATE_LIMIT_MAX_ATTEMPTS", _AUTH_RATE_LIMIT_MAX_ATTEMPTS
    ))
    max_keys = max(1, _auth_rate_limit_env_value(
        env, "AUTH_RATE_LIMIT_MAX_KEYS", _AUTH_RATE_LIMIT_MAX_KEYS
    ))

    # Only CF-Connecting-IP is trusted in Cloudflare Workers.
    client_ip = (req.headers.get("CF-Connecting-IP") or "").strip()
    if not client_ip:
        print(json.dumps({"level": "warn", "where": "auth_rate_limit", "error": "missing_cf_connecting_ip"}))
        return _too_many_requests(1)

    key = f"{route}:{client_ip}"
    now = int(time.time())

    state = _AUTH_RATE_LIMIT_STATE.get(key)
    if state is not None and now - state.get("window_start", 0) < window_seconds:
        # Still inside this caller's current window.
        count = state.get("count", 0)
        if count >= max_attempts:
            elapsed = now - state.get("window_start", now)
            return _too_many_requests(max(1, window_seconds - elapsed))
        state["count"] = count + 1
        return None

    # A new window starts. Expired entries are treated as absent by the check
    # above, so cleanup is only needed when a brand-new key would push the
    # table past its cap; doing it before the insert keeps the cap exact.
    if state is None and len(_AUTH_RATE_LIMIT_STATE) >= max_keys:
        _auth_rate_limit_make_room(now, window_seconds, max_keys)
    _AUTH_RATE_LIMIT_STATE[key] = {"window_start": now, "count": 1}
    return None
def _req(self, payload):
    """Build a POST /api/register request carrying a client-IP header."""
    ip_header = {"CF-Connecting-IP": "127.0.0.1"}
    return json_request("/api/register", payload, headers=ip_header)

def _rate_limited_env(self):
    """Return an env whose auth rate limit is 2 attempts per 60-second window."""
    limited_env = make_env(db=MockDB([make_stmt()]))
    overrides = (
        ("AUTH_RATE_LIMIT_WINDOW_SECONDS", 60),
        ("AUTH_RATE_LIMIT_MAX_ATTEMPTS", 2),
    )
    for setting, value in overrides:
        setattr(limited_env, setting, value)
    return limited_env

async def test_invalid_json_returns_400(self):
    """A non-JSON body is answered with 400."""
    request = MockRequest(
        method="POST",
        url="http://localhost/api/register",
        headers={"CF-Connecting-IP": "127.0.0.1"},
        body="not-json",
    )
    response = await worker.api_register(request, make_env())
    assert response.status == 400

async def test_register_is_rate_limited_per_ip(self):
    """With a limit of 2, the third registration from one IP gets a 429."""
    env = self._rate_limited_env()
    client_ip = "203.0.113.10"

    # Build all three requests up front, each tagged with the same client IP.
    requests = []
    for username in ("alice1", "alice2", "alice3"):
        request = self._req({
            "username": username,
            "email": f"{username}@example.com",
            "password": "password123",
        })
        request.headers["CF-Connecting-IP"] = client_ip
        requests.append(request)

    # Issue them strictly one after another so the counter sees them in order.
    responses = []
    for request in requests:
        responses.append(await worker.api_register(request, env))
    first, second, third = responses

    assert first.status == 200
    assert second.status == 200
    assert third.status == 429
    assert "Retry-After" in third.headers
    wait_seconds = third.headers["Retry-After"]
    assert wait_seconds.isdigit()
    assert int(wait_seconds) > 0
    assert _parse(third).get("error") == "Too many requests"
async def test_login_is_rate_limited_per_ip(self):
    """With a limit of 2, the third login from one IP gets a 429."""
    user_row = self._make_user_row()
    env = self._rate_limited_env()
    env.DB = MockDB([make_stmt(first=user_row) for _ in range(3)])

    # Three identical login attempts, all from the same client IP.
    requests = []
    for _ in range(3):
        request = self._req({"username": "alice", "password": "password123"})
        request.headers["CF-Connecting-IP"] = "203.0.113.10"
        requests.append(request)

    responses = []
    for request in requests:
        responses.append(await worker.api_login(request, env))
    first, second, third = responses

    assert first.status == 200
    assert second.status == 200
    assert third.status == 429
    assert "Retry-After" in third.headers
    wait_seconds = third.headers["Retry-After"]
    assert wait_seconds.isdigit()
    assert int(wait_seconds) > 0
    assert _parse(third).get("error") == "Too many requests"

async def test_login_rate_limit_resets_after_window(self, monkeypatch):
    """Once the window has passed, the same IP gets a fresh allowance of 2."""
    user_row = self._make_user_row()
    env = self._rate_limited_env()
    env.DB = MockDB([make_stmt(first=user_row) for _ in range(4)])

    def login_request():
        # Every attempt in this test comes from the same client IP.
        request = self._req({"username": "alice", "password": "password123"})
        request.headers["CF-Connecting-IP"] = "198.51.100.10"
        return request

    worker._AUTH_RATE_LIMIT_STATE.clear()
    monkeypatch.setattr(worker.time, "time", lambda: 1000)

    # Use up the whole allowance inside the first window.
    opening_attempts = [login_request(), login_request()]
    for request in opening_attempts:
        assert (await worker.api_login(request, env)).status == 200

    # Jump past the 60-second window: the counter must start over.
    monkeypatch.setattr(worker.time, "time", lambda: 1000 + 61)
    for _ in range(2):
        assert (await worker.api_login(login_request(), env)).status == 200

    # The third attempt in the new window is limited again.
    limited = await worker.api_login(login_request(), env)
    assert limited.status == 429
    assert "Retry-After" in limited.headers
    wait_seconds = limited.headers["Retry-After"]
    assert wait_seconds.isdigit()
    assert int(wait_seconds) > 0
    assert _parse(limited).get("error") == "Too many requests"