From 6e25be90b4172d81a2adb46bfa37c169f7bb4b7e Mon Sep 17 00:00:00 2001 From: Tu Pham Date: Tue, 26 May 2026 09:17:55 +0700 Subject: [PATCH 1/2] Extract security header middleware Refs #320 --- app/ledger/service.py | 37 ++++++------- app/main.py | 85 +++--------------------------- app/security_headers.py | 94 ++++++++++++++++++++++++++++++++++ tests/test_security_headers.py | 77 ++++++++++++++++++++++++++++ 4 files changed, 194 insertions(+), 99 deletions(-) create mode 100644 app/security_headers.py create mode 100644 tests/test_security_headers.py diff --git a/app/ledger/service.py b/app/ledger/service.py index 3812285..295ae46 100644 --- a/app/ledger/service.py +++ b/app/ledger/service.py @@ -6,11 +6,10 @@ import re from datetime import UTC, datetime from decimal import Decimal, InvalidOperation -from typing import Any, cast +from typing import Any from urllib.parse import urlparse from sqlalchemy import case, func, select, update -from sqlalchemy.engine import CursorResult from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -657,19 +656,16 @@ def pay_bounty( reserve_account = reserve_account_for_bounty(bounty.id) if get_balance(session, reserve_account) < bounty.reward_microunits: raise LedgerError("bounty reserve balance too low") - claimed = cast( - CursorResult[Any], - session.execute( - update(Bounty) - .where(Bounty.id == bounty.id, Bounty.awards_paid < Bounty.max_awards) - .values( - awards_paid=Bounty.awards_paid + 1, - status=case( - (Bounty.awards_paid + 1 >= Bounty.max_awards, "paid"), - else_="open", - ), - ) - ), + claimed = session.execute( + update(Bounty) + .where(Bounty.id == bounty.id, Bounty.awards_paid < Bounty.max_awards) + .values( + awards_paid=Bounty.awards_paid + 1, + status=case( + (Bounty.awards_paid + 1 >= Bounty.max_awards, "paid"), + else_="open", + ), + ) ) if claimed.rowcount != 1: raise LedgerError("bounty already paid") @@ -737,13 +733,10 @@ def close_bounty( raise LedgerError("bounty is not open") _clean_required_text(closed_by, "closed_by", 80) clean_reference = validate_public_url(reference or bounty.issue_url) - claimed = cast( - CursorResult[Any], - session.execute( - update(Bounty) - .where(Bounty.id == bounty.id, Bounty.status == "open") - .values(status="closed") - ), + claimed = session.execute( + update(Bounty) + .where(Bounty.id == bounty.id, Bounty.status == "open") + .values(status="closed") ) if claimed.rowcount != 1: raise LedgerError("bounty is not open") diff --git a/app/main.py b/app/main.py index 7e7c4f9..6e07fe9 100644 --- a/app/main.py +++ b/app/main.py @@ -9,11 +9,11 @@ from datetime import UTC, datetime, timedelta from pathlib import Path from typing import Annotated, Any -from urllib.parse import urlencode, urlsplit, urlunsplit +from urllib.parse import unquote, urlencode import httpx from fastapi import Depends, FastAPI, Form, HTTPException, Query, Request -from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse, Response +from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from sqlalchemy import func, or_, select, update @@ -58,6 +58,7 @@ Submission, Wallet, ) +from app.security_headers import register_security_headers_middleware from app.serializers import ( accepted_work_for_account, account_accepted_summary, @@ -79,66 +80,14 @@ templates = Jinja2Templates(directory=str(BASE_DIR / "templates")) templates.env.globals["safe_public_url"] = public_url_or_none -SECURITY_HEADERS = { - "Content-Security-Policy": ( - "default-src 'self'; " - "base-uri 'self'; " - "frame-ancestors 'none'; " - "form-action 'self'; " - "connect-src 'self'; " - "img-src 'self' data:; " - "object-src 'none'; " - "script-src 'self'; " - "style-src 'self'" - ), - "Referrer-Policy": "no-referrer", - "Strict-Transport-Security": "max-age=31536000; includeSubDomains", - "X-Content-Type-Options": "nosniff", - "X-Frame-Options": "DENY", -} GITHUB_LOGIN_RE = re.compile(r"^[a-z0-9](?:[a-z0-9-]{0,37}[a-z0-9])?$") HEX_HASH_RE = re.compile(r"^[0-9a-f]{64}$") -API_DOCS_CSP = ( - "default-src 'self'; " - "base-uri 'self'; " - "frame-ancestors 'none'; " - "form-action 'self'; " - "connect-src 'self'; " - "font-src 'self' data: https://fonts.gstatic.com; " - "img-src 'self' data: https://fastapi.tiangolo.com https://cdn.redoc.ly; " - "object-src 'none'; " - "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; " - "style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net https://fonts.googleapis.com; " - "worker-src 'self' blob:" -) -API_DOCS_PATHS = {"/api/docs", "/api/redoc"} SQLITE_INTEGER_MAX = 2**63 - 1 DEFAULT_ATTEMPT_TTL_SECONDS = 24 * 60 * 60 MIN_ATTEMPT_TTL_SECONDS = 60 MAX_ATTEMPT_TTL_SECONDS = 7 * 24 * 60 * 60 -def _request_was_forwarded_https(request: Request) -> bool: - forwarded_proto = request.headers.get("x-forwarded-proto", "") - if forwarded_proto: - return forwarded_proto.split(",", 1)[0].strip().lower() == "https" - return request.url.scheme == "https" - - -def _preserve_forwarded_https_redirect(request: Request, response: Response) -> None: - if response.status_code not in {307, 308} or not _request_was_forwarded_https(request): - return - location = response.headers.get("location") - if not location: - return - parsed = urlsplit(location) - if parsed.scheme != "http" or parsed.netloc != request.url.netloc: - return - response.headers["location"] = urlunsplit( - ("https", parsed.netloc, parsed.path, parsed.query, parsed.fragment) - ) - - def _issue_number_search_value(query: str) -> int | None: if not query.isdigit(): return None @@ -278,13 +227,17 @@ def _oauth_configured(settings: Settings) -> bool: def _safe_next_path(next_path: str | None) -> str: + decoded_next_path = unquote(next_path) if next_path else "" if ( not next_path or not next_path.startswith("/") or next_path.startswith("//") or len(next_path) > 2048 or "\\" in next_path + or decoded_next_path.startswith("//") + or "\\" in decoded_next_path or any(ord(char) < 32 or 127 <= ord(char) < 160 for char in next_path) + or any(ord(char) < 32 or 127 <= ord(char) < 160 for char in decoded_next_path) ): return "/me" return next_path @@ -486,29 +439,7 @@ def post_only_route() -> None: headers={"Allow": "POST"}, ) - @app.middleware("http") - async def add_security_headers(request: Request, call_next: Any) -> Any: - original_method = request.scope["method"] - if original_method == "HEAD": - request.scope["method"] = "GET" - try: - response = await call_next(request) - finally: - request.scope["method"] = original_method - if original_method == "HEAD": - headers = dict(response.headers) - headers["content-length"] = "0" - response = Response( - status_code=response.status_code, - headers=headers, - media_type=response.media_type, - ) - if request.url.path in API_DOCS_PATHS: - response.headers["Content-Security-Policy"] = API_DOCS_CSP - _preserve_forwarded_https_redirect(request, response) - for name, value in SECURITY_HEADERS.items(): - response.headers.setdefault(name, value) - return response + register_security_headers_middleware(app) static_dir = BASE_DIR / "static" if static_dir.exists(): diff --git a/app/security_headers.py b/app/security_headers.py new file mode 100644 index 0000000..674a655 --- /dev/null +++ b/app/security_headers.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from typing import Any +from urllib.parse import urlsplit, urlunsplit + +from fastapi import FastAPI, Request +from fastapi.responses import Response + +SECURITY_HEADERS = { + "Content-Security-Policy": ( + "default-src 'self'; " + "base-uri 'self'; " + "frame-ancestors 'none'; " + "form-action 'self'; " + "connect-src 'self'; " + "img-src 'self' data:; " + "object-src 'none'; " + "script-src 'self'; " + "style-src 'self'" + ), + "Referrer-Policy": "no-referrer", + "Strict-Transport-Security": "max-age=31536000; includeSubDomains", + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", +} +API_DOCS_CSP = ( + "default-src 'self'; " + "base-uri 'self'; " + "frame-ancestors 'none'; " + "form-action 'self'; " + "connect-src 'self'; " + "font-src 'self' data: https://fonts.gstatic.com; " + "img-src 'self' data: https://fastapi.tiangolo.com https://cdn.redoc.ly; " + "object-src 'none'; " + "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; " + "style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net https://fonts.googleapis.com; " + "worker-src 'self' blob:" +) +API_DOCS_PATHS = {"/api/docs", "/api/redoc"} + + +def request_was_forwarded_https(request: Request) -> bool: + forwarded_proto = request.headers.get("x-forwarded-proto", "") + if forwarded_proto: + return forwarded_proto.split(",", 1)[0].strip().lower() == "https" + return request.url.scheme == "https" + + +def preserve_forwarded_https_redirect(request: Request, response: Response) -> None: + if response.status_code not in {307, 308} or not request_was_forwarded_https(request): + return + location = response.headers.get("location") + if not location: + return + parsed = urlsplit(location) + if parsed.scheme != "http" or parsed.netloc != request.url.netloc: + return + response.headers["location"] = urlunsplit( + ("https", parsed.netloc, parsed.path, parsed.query, parsed.fragment) + ) + + +def apply_security_headers(request: Request, response: Response) -> Response: + if request.url.path in API_DOCS_PATHS: + response.headers["Content-Security-Policy"] = API_DOCS_CSP + preserve_forwarded_https_redirect(request, response) + for name, value in SECURITY_HEADERS.items(): + response.headers.setdefault(name, value) + return response + + +async def security_headers_middleware(request: Request, call_next: Any) -> Response: + original_method = request.scope["method"] + if original_method == "HEAD": + request.scope["method"] = "GET" + try: + response = await call_next(request) + finally: + request.scope["method"] = original_method + if original_method == "HEAD": + headers = dict(response.headers) + headers["content-length"] = "0" + response = Response( + status_code=response.status_code, + headers=headers, + media_type=response.media_type, + ) + return apply_security_headers(request, response) + + +def register_security_headers_middleware(app: FastAPI) -> None: + @app.middleware("http") + async def add_security_headers(request: Request, call_next: Any) -> Response: + return await security_headers_middleware(request, call_next) diff --git a/tests/test_security_headers.py b/tests/test_security_headers.py new file mode 100644 index 0000000..5b9097e --- /dev/null +++ b/tests/test_security_headers.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from fastapi.testclient import TestClient + +from app.db import create_schema, session_scope +from app.ledger.service import create_bounty, ensure_genesis +from app.main import create_app +from app.security_headers import API_DOCS_CSP, SECURITY_HEADERS + + +def test_security_header_defaults_are_applied_to_browser_routes(sqlite_url: str) -> None: + client = TestClient(create_app(database_url=sqlite_url, webhook_secret="secret")) + + response = client.get("/") + + for name, value in SECURITY_HEADERS.items(): + assert response.headers[name.lower()] == value + + +def test_security_header_middleware_preserves_head_as_get(sqlite_url: str) -> None: + create_schema(sqlite_url) + with session_scope(sqlite_url) as session: + ensure_genesis(session) + create_bounty( + session, + repo="ramimbo/mergework", + issue_number=320, + issue_url="https://github.com/ramimbo/mergework/issues/320", + title="Security header middleware", + reward_mrwk="25", + acceptance="HEAD requests should keep GET route semantics without a body.", + ) + client = TestClient(create_app(database_url=sqlite_url, webhook_secret="secret")) + + response = client.head("/api/v1/bounties") + + assert response.status_code == 200 + assert response.content == b"" + assert response.headers["content-length"] == "0" + assert response.headers["x-frame-options"] == "DENY" + + +def test_api_docs_use_relaxed_docs_csp(sqlite_url: str) -> None: + client = TestClient(create_app(database_url=sqlite_url, webhook_secret="secret")) + + response = client.get("/api/docs") + + assert response.status_code == 200 + assert response.headers["content-security-policy"] == API_DOCS_CSP + + +def test_forwarded_https_redirects_keep_https_scheme(sqlite_url: str) -> None: + create_schema(sqlite_url) + with session_scope(sqlite_url) as session: + ensure_genesis(session) + bounty = create_bounty( + session, + repo="ramimbo/mergework", + issue_number=321, + issue_url="https://github.com/ramimbo/mergework/issues/321", + title="Forwarded HTTPS redirect", + reward_mrwk="25", + acceptance="Trailing slash redirects should not downgrade public HTTPS requests.", + ) + client = TestClient( + create_app(database_url=sqlite_url, webhook_secret="secret"), + base_url="http://mrwk.ltclab.site", + ) + + response = client.get( + f"/bounties/{bounty.id}/", + headers={"x-forwarded-proto": "https"}, + follow_redirects=False, + ) + + assert response.status_code == 307 + assert response.headers["location"] == f"https://mrwk.ltclab.site/bounties/{bounty.id}" From 36b63821c0d79f4747c4144bbef33ad09f50bf58 Mon Sep 17 00:00:00 2001 From: Tu Pham Date: Tue, 26 May 2026 09:21:29 +0700 Subject: [PATCH 2/2] =?UTF-8?q?Kh=C3=B4i=20ph=E1=BB=A5c=20ki=E1=BB=83u=20C?= =?UTF-8?q?ursorResult=20cho=20security=20headers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/ledger/service.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/app/ledger/service.py b/app/ledger/service.py index 295ae46..3812285 100644 --- a/app/ledger/service.py +++ b/app/ledger/service.py @@ -6,10 +6,11 @@ import re from datetime import UTC, datetime from decimal import Decimal, InvalidOperation -from typing import Any +from typing import Any, cast from urllib.parse import urlparse from sqlalchemy import case, func, select, update +from sqlalchemy.engine import CursorResult from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -656,16 +657,19 @@ def pay_bounty( reserve_account = reserve_account_for_bounty(bounty.id) if get_balance(session, reserve_account) < bounty.reward_microunits: raise LedgerError("bounty reserve balance too low") - claimed = session.execute( - update(Bounty) - .where(Bounty.id == bounty.id, Bounty.awards_paid < Bounty.max_awards) - .values( - awards_paid=Bounty.awards_paid + 1, - status=case( - (Bounty.awards_paid + 1 >= Bounty.max_awards, "paid"), - else_="open", - ), - ) + claimed = cast( + CursorResult[Any], + session.execute( + update(Bounty) + .where(Bounty.id == bounty.id, Bounty.awards_paid < Bounty.max_awards) + .values( + awards_paid=Bounty.awards_paid + 1, + status=case( + (Bounty.awards_paid + 1 >= Bounty.max_awards, "paid"), + else_="open", + ), + ) + ), ) if claimed.rowcount != 1: raise LedgerError("bounty already paid") @@ -733,10 +737,13 @@ def close_bounty( raise LedgerError("bounty is not open") _clean_required_text(closed_by, "closed_by", 80) clean_reference = validate_public_url(reference or bounty.issue_url) - claimed = session.execute( - update(Bounty) - .where(Bounty.id == bounty.id, Bounty.status == "open") - .values(status="closed") + claimed = cast( + CursorResult[Any], + session.execute( + update(Bounty) + .where(Bounty.id == bounty.id, Bounty.status == "open") + .values(status="closed") + ), ) if claimed.rowcount != 1: raise LedgerError("bounty is not open")