diff --git a/app/main.py b/app/main.py index 7e7c4f9..6d28eca 100644 --- a/app/main.py +++ b/app/main.py @@ -58,6 +58,13 @@ Submission, Wallet, ) +from app.path_params import ( + SQLITE_INTEGER_MAX, + issue_number_search_value, + positive_bounty_id, + positive_ledger_sequence, + proof_hash_from_path, +) from app.serializers import ( accepted_work_for_account, account_accepted_summary, @@ -97,7 +104,6 @@ "X-Frame-Options": "DENY", } GITHUB_LOGIN_RE = re.compile(r"^[a-z0-9](?:[a-z0-9-]{0,37}[a-z0-9])?$") -HEX_HASH_RE = re.compile(r"^[0-9a-f]{64}$") API_DOCS_CSP = ( "default-src 'self'; " "base-uri 'self'; " @@ -112,7 +118,6 @@ "worker-src 'self' blob:" ) API_DOCS_PATHS = {"/api/docs", "/api/redoc"} -SQLITE_INTEGER_MAX = 2**63 - 1 DEFAULT_ATTEMPT_TTL_SECONDS = 24 * 60 * 60 MIN_ATTEMPT_TTL_SECONDS = 60 MAX_ATTEMPT_TTL_SECONDS = 7 * 24 * 60 * 60 @@ -139,16 +144,6 @@ def _preserve_forwarded_https_redirect(request: Request, response: Response) -> ) -def _issue_number_search_value(query: str) -> int | None: - if not query.isdigit(): - return None - try: - issue_number = int(query) - except ValueError: - return None - return issue_number if issue_number <= SQLITE_INTEGER_MAX else None - - def _utc_now() -> datetime: return datetime.now(UTC) @@ -336,22 +331,6 @@ def _github_login_from_account(account: str) -> str | None: return login -def _positive_bounty_id(bounty_id: int) -> int: - if bounty_id <= 0: - raise HTTPException(status_code=400, detail="bounty id must be positive") - if bounty_id > SQLITE_INTEGER_MAX: - raise HTTPException(status_code=400, detail="bounty id is too large") - return bounty_id - - -def _positive_ledger_sequence(sequence: int) -> int: - if sequence <= 0: - raise HTTPException(status_code=400, detail="ledger sequence must be positive") - if sequence > SQLITE_INTEGER_MAX: - raise HTTPException(status_code=400, detail="ledger sequence is too large") - return sequence - - def _normalized_wallet_address(address: str) -> str: try: return normalize_wallet_address(address) @@ -359,15 +338,6 @@ def _normalized_wallet_address(address: str) -> str: raise HTTPException(status_code=400, detail=str(exc)) from exc -def _proof_hash_from_path(proof_hash: str) -> str: - if proof_hash != proof_hash.strip(): - raise HTTPException(status_code=400, detail="proof hash must be 64 hex characters") - clean = proof_hash.lower() - if not HEX_HASH_RE.fullmatch(clean): - raise HTTPException(status_code=400, detail="proof hash must be 64 hex characters") - return clean - - def _signed_value(value: str, secret: str) -> str: timestamp = str(int(time.time())) body = f"{value}|{timestamp}" @@ -600,7 +570,7 @@ def list_bounties_by_status( .replace("_", "\\_") ) like_query = f"%{escaped_query}%" - issue_number = _issue_number_search_value(normalized_query) + issue_number = issue_number_search_value(normalized_query) text_filter = or_( func.lower(Bounty.repo).like(like_query, escape="\\"), func.lower(Bounty.title).like(like_query, escape="\\"), @@ -661,7 +631,7 @@ async def api_create_bounty( @app.get("/api/v1/bounties/{bounty_id}") def api_bounty(bounty_id: int) -> dict[str, Any]: - bounty_id = _positive_bounty_id(bounty_id) + bounty_id = positive_bounty_id(bounty_id) with session_scope(db_url) as session: bounty = session.get(Bounty, bounty_id) if bounty is None: @@ -672,7 +642,7 @@ def api_bounty(bounty_id: int) -> dict[str, Any]: @app.get("/api/v1/bounties/{bounty_id}/attempts") def api_bounty_attempts(bounty_id: int, include_expired: bool = Query(False)) -> dict[str, Any]: - bounty_id = _positive_bounty_id(bounty_id) + bounty_id = positive_bounty_id(bounty_id) now = _utc_now() with session_scope(db_url) as session: bounty = session.get(Bounty, bounty_id) @@ -696,7 +666,7 @@ async def api_create_bounty_attempt( request: Request, github_login: str = Depends(require_github_login), ) -> JSONResponse: - bounty_id = _positive_bounty_id(bounty_id) + bounty_id = positive_bounty_id(bounty_id) data = await _json_object(request) submitter_account = attempt_submitter_account(data, github_login) ttl_seconds = _optional_int(data, "ttl_seconds", DEFAULT_ATTEMPT_TTL_SECONDS) @@ -839,7 +809,7 @@ async def api_pay_bounty( request: Request, admin_login: str = Depends(require_admin_token), ) -> Any: - bounty_id = _positive_bounty_id(bounty_id) + bounty_id = positive_bounty_id(bounty_id) data = await _json_object(request) try: requested_account = _required_str(data, "to_account") @@ -909,7 +879,7 @@ async def api_close_bounty( request: Request, admin_login: str = Depends(require_admin_token), ) -> dict[str, Any]: - bounty_id = _positive_bounty_id(bounty_id) + bounty_id = positive_bounty_id(bounty_id) data = await _json_object(request) reference = _optional_str(data, "reference") if data.get("reference") is not None else None closed_by = _optional_str(data, "closed_by", admin_login) @@ -1069,7 +1039,7 @@ def api_ledger(limit: Annotated[int, Query(ge=1, le=200)] = 50) -> list[dict[str @app.get("/api/v1/ledger/{sequence}") def api_ledger_entry(sequence: int) -> dict[str, Any]: - sequence = _positive_ledger_sequence(sequence) + sequence = positive_ledger_sequence(sequence) with session_scope(db_url) as session: entry = session.get(LedgerEntry, sequence) if entry is None: @@ -1079,7 +1049,7 @@ def api_ledger_entry(sequence: int) -> dict[str, Any]: @app.get("/api/v1/proofs/{proof_hash}") def api_proof(proof_hash: str) -> dict[str, Any]: - proof_hash = _proof_hash_from_path(proof_hash) + proof_hash = proof_hash_from_path(proof_hash) with session_scope(db_url) as session: proof = session.get(Proof, proof_hash) if proof is None: @@ -1530,15 +1500,6 @@ def output_format_arg() -> str: raise ValueError("format must be text or json") return normalized - def mcp_issue_number_search_value(query_text: str) -> int | None: - if not query_text.isdigit(): - return None - try: - issue_number = int(query_text) - except ValueError: - return None - return issue_number if issue_number <= SQLITE_INTEGER_MAX else None - def list_limit_arg(default: int = 25) -> int: if "limit" not in args or args.get("limit") is None: return default @@ -1658,7 +1619,7 @@ def optional_bool_arg(field: str, default: bool = False) -> bool: query_text.lower().replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") ) like_query = f"%{escaped_query}%" - issue_number = mcp_issue_number_search_value(query_text) + issue_number = issue_number_search_value(query_text) text_filter = or_( func.lower(Bounty.repo).like(like_query, escape="\\"), func.lower(Bounty.title).like(like_query, escape="\\"), @@ -1735,7 +1696,7 @@ def optional_bool_arg(field: str, default: bool = False) -> bool: ) return json.dumps(ledger_to_dict(entry, proof.hash if proof else None)) if name == "get_proof": - proof = session.get(Proof, _proof_hash_from_path(str_arg("hash"))) + proof = session.get(Proof, proof_hash_from_path(str_arg("hash"))) if proof is None: return "proof not found" public_payload = json.loads(proof.public_json) diff --git a/app/path_params.py b/app/path_params.py new file mode 100644 index 0000000..2672d0c --- /dev/null +++ b/app/path_params.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import re + +from fastapi import HTTPException + +SQLITE_INTEGER_MAX = 2**63 - 1 +HEX_HASH_RE = re.compile(r"^[0-9a-f]{64}$") + + +def issue_number_search_value(query: str) -> int | None: + """Return a bounded GitHub issue number from a plain numeric search query.""" + if not query.isdigit(): + return None + try: + issue_number = int(query) + except ValueError: + return None + return issue_number if issue_number <= SQLITE_INTEGER_MAX else None + + +def positive_bounty_id(bounty_id: int) -> int: + if bounty_id <= 0: + raise HTTPException(status_code=400, detail="bounty id must be positive") + if bounty_id > SQLITE_INTEGER_MAX: + raise HTTPException(status_code=400, detail="bounty id is too large") + return bounty_id + + +def positive_ledger_sequence(sequence: int) -> int: + if sequence <= 0: + raise HTTPException(status_code=400, detail="ledger sequence must be positive") + if sequence > SQLITE_INTEGER_MAX: + raise HTTPException(status_code=400, detail="ledger sequence is too large") + return sequence + + +def proof_hash_from_path(proof_hash: str) -> str: + if proof_hash != proof_hash.strip(): + raise HTTPException(status_code=400, detail="proof hash must be 64 hex characters") + clean = proof_hash.lower() + if not HEX_HASH_RE.fullmatch(clean): + raise HTTPException(status_code=400, detail="proof hash must be 64 hex characters") + return clean diff --git a/tests/test_path_params.py b/tests/test_path_params.py new file mode 100644 index 0000000..52f9c9b --- /dev/null +++ b/tests/test_path_params.py @@ -0,0 +1,51 @@ +from fastapi import HTTPException + +from app.path_params import ( + SQLITE_INTEGER_MAX, + issue_number_search_value, + positive_bounty_id, + positive_ledger_sequence, + proof_hash_from_path, +) + + +def assert_bad_request(func, *args): + try: + func(*args) + except HTTPException as exc: + assert exc.status_code == 400 + else: # pragma: no cover - defensive test helper + raise AssertionError("expected HTTPException") + + +def test_issue_number_search_value_accepts_bounded_numeric_query(): + assert issue_number_search_value("340") == 340 + assert issue_number_search_value(str(SQLITE_INTEGER_MAX)) == SQLITE_INTEGER_MAX + + +def test_issue_number_search_value_rejects_non_numeric_or_overflow_query(): + assert issue_number_search_value("") is None + assert issue_number_search_value(" 340") is None + assert issue_number_search_value("340a") is None + assert issue_number_search_value(str(SQLITE_INTEGER_MAX + 1)) is None + + +def test_positive_bounty_id_and_ledger_sequence_validate_bounds(): + assert positive_bounty_id(1) == 1 + assert positive_ledger_sequence(SQLITE_INTEGER_MAX) == SQLITE_INTEGER_MAX + + assert_bad_request(positive_bounty_id, 0) + assert_bad_request(positive_bounty_id, SQLITE_INTEGER_MAX + 1) + assert_bad_request(positive_ledger_sequence, -1) + assert_bad_request(positive_ledger_sequence, SQLITE_INTEGER_MAX + 1) + + +def test_proof_hash_from_path_normalizes_hex_hash(): + raw_hash = "A" * 64 + assert proof_hash_from_path(raw_hash) == "a" * 64 + + +def test_proof_hash_from_path_rejects_whitespace_or_non_hex(): + assert_bad_request(proof_hash_from_path, " " + "a" * 64) + assert_bad_request(proof_hash_from_path, "g" * 64) + assert_bad_request(proof_hash_from_path, "a" * 63)