Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 12 additions & 87 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,10 @@
ensure_genesis,
format_mrwk,
get_balance,
link_wallet_to_github,
pay_bounty,
public_url_or_none,
register_wallet,
resolve_payout_account,
submit_github_claim,
submit_wallet_transfer,
validate_public_url,
)
Expand Down Expand Up @@ -73,12 +71,12 @@
bounty_awards_to_dict,
bounty_list_summary,
bounty_to_dict,
ledger_to_dict,
payout_reconciliation_to_dict,
wallet_to_dict,
wallet_transfer_to_dict,
)
from app.status import health_status, system_status
from app.wallet_api import register_wallet_api_routes
from app.webhooks.github import handle_github_webhook

BASE_DIR = Path(__file__).resolve().parent
Expand Down Expand Up @@ -606,90 +604,17 @@ def api_auth_me(request: Request) -> dict[str, Any]:
login = github_login_from_request(request)
return {"authenticated": login is not None, "github_login": login}

@app.post("/api/v1/wallets/register")
async def api_register_wallet(request: Request) -> dict[str, Any]:
data = await _json_object(request)
with session_scope(db_url) as session:
try:
wallet = register_wallet(
session,
public_key_hex=_required_str(data, "public_key_hex"),
label=_optional_str(data, "label") if data.get("label") is not None else None,
)
except LedgerError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return wallet_to_dict(session, wallet)

@app.get("/api/v1/wallets/register", include_in_schema=False)
def api_register_wallet_get() -> None:
post_only_route()

@app.get("/api/v1/wallets/link-github", include_in_schema=False)
def api_link_wallet_github_get() -> None:
post_only_route()

@app.get("/api/v1/wallets/{address}")
def api_wallet(address: str) -> dict[str, Any]:
address = normalized_wallet_address(address)
with session_scope(db_url) as session:
wallet = session.get(Wallet, address)
if wallet is None:
raise HTTPException(status_code=404, detail="wallet not found")
return wallet_to_dict(session, wallet)

@app.post("/api/v1/wallets/link-github")
async def api_link_wallet_github(
request: Request, github_login: str = Depends(require_github_login)
) -> dict[str, Any]:
data = await _json_object(request)
with session_scope(db_url) as session:
try:
wallet = link_wallet_to_github(
session,
address=_required_str(data, "address"),
github_login=github_login,
nonce=_required_int(data, "nonce"),
signature_hex=_required_str(data, "signature_hex"),
)
except LedgerError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return wallet_to_dict(session, wallet)

@app.post("/api/v1/github/claim")
async def api_github_claim(
request: Request, github_login: str = Depends(require_github_login)
) -> dict[str, Any]:
data = await _json_object(request)
with session_scope(db_url) as session:
try:
entry = submit_github_claim(
session,
address=_required_str(data, "address"),
github_login=github_login,
nonce=_required_int(data, "nonce"),
signature_hex=_required_str(data, "signature_hex"),
)
except LedgerError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return ledger_to_dict(entry)

@app.post("/api/v1/transfers")
async def api_submit_transfer(request: Request) -> dict[str, Any]:
data = await _json_object(request)
with session_scope(db_url) as session:
try:
transfer = submit_wallet_transfer(
session,
from_address=_required_str(data, "from_address"),
to_address=_required_str(data, "to_address"),
amount_mrwk=_required_str(data, "amount_mrwk"),
nonce=_required_int(data, "nonce"),
memo=_optional_str(data, "memo"),
signature_hex=_required_str(data, "signature_hex"),
)
except LedgerError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return wallet_transfer_to_dict(transfer)
register_wallet_api_routes(
app,
db_url=db_url,
require_github_login=require_github_login,
json_object=_json_object,
required_str=_required_str,
required_int=_required_int,
optional_str=_optional_str,
normalized_wallet_address=normalized_wallet_address,
post_only_route=post_only_route,
)

@app.get("/api/v1/ledger")
def api_ledger(limit: Annotated[int, Query(ge=1, le=200)] = 50) -> list[dict[str, Any]]:
Expand Down
123 changes: 123 additions & 0 deletions app/wallet_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from __future__ import annotations

from collections.abc import Awaitable, Callable
from typing import Any

from fastapi import Depends, FastAPI, HTTPException, Request

from app.db import session_scope
from app.ledger.service import (
LedgerError,
link_wallet_to_github,
register_wallet,
submit_github_claim,
submit_wallet_transfer,
)
from app.models import Wallet
from app.serializers import ledger_to_dict, wallet_to_dict, wallet_transfer_to_dict

JsonObjectLoader = Callable[[Request], Awaitable[dict[str, Any]]]
LoginDependency = Callable[[Request], str]
RequiredString = Callable[[dict[str, Any], str], str]
RequiredInteger = Callable[[dict[str, Any], str], int]
OptionalString = Callable[[dict[str, Any], str], str]
NormalizeWalletAddress = Callable[[str], str]
PostOnlyRoute = Callable[[], None]


def register_wallet_api_routes(
app: FastAPI,
*,
db_url: str,
require_github_login: LoginDependency,
json_object: JsonObjectLoader,
required_str: RequiredString,
required_int: RequiredInteger,
optional_str: OptionalString,
normalized_wallet_address: NormalizeWalletAddress,
post_only_route: PostOnlyRoute,
) -> None:
@app.post("/api/v1/wallets/register")
async def api_register_wallet(request: Request) -> dict[str, Any]:
data = await json_object(request)
with session_scope(db_url) as session:
try:
wallet = register_wallet(
session,
public_key_hex=required_str(data, "public_key_hex"),
label=optional_str(data, "label") if data.get("label") is not None else None,
)
except LedgerError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return wallet_to_dict(session, wallet)

@app.get("/api/v1/wallets/register", include_in_schema=False)
def api_register_wallet_get() -> None:
post_only_route()

@app.get("/api/v1/wallets/link-github", include_in_schema=False)
def api_link_wallet_github_get() -> None:
post_only_route()

@app.get("/api/v1/wallets/{address}")
def api_wallet(address: str) -> dict[str, Any]:
address = normalized_wallet_address(address)
with session_scope(db_url) as session:
wallet = session.get(Wallet, address)
if wallet is None:
raise HTTPException(status_code=404, detail="wallet not found")
return wallet_to_dict(session, wallet)

@app.post("/api/v1/wallets/link-github")
async def api_link_wallet_github(
request: Request, github_login: str = Depends(require_github_login)
) -> dict[str, Any]:
data = await json_object(request)
with session_scope(db_url) as session:
try:
wallet = link_wallet_to_github(
session,
address=required_str(data, "address"),
github_login=github_login,
nonce=required_int(data, "nonce"),
signature_hex=required_str(data, "signature_hex"),
)
except LedgerError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return wallet_to_dict(session, wallet)

@app.post("/api/v1/github/claim")
async def api_github_claim(
request: Request, github_login: str = Depends(require_github_login)
) -> dict[str, Any]:
data = await json_object(request)
with session_scope(db_url) as session:
try:
entry = submit_github_claim(
session,
address=required_str(data, "address"),
github_login=github_login,
nonce=required_int(data, "nonce"),
signature_hex=required_str(data, "signature_hex"),
)
except LedgerError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return ledger_to_dict(entry)

@app.post("/api/v1/transfers")
async def api_submit_transfer(request: Request) -> dict[str, Any]:
data = await json_object(request)
with session_scope(db_url) as session:
try:
transfer = submit_wallet_transfer(
session,
from_address=required_str(data, "from_address"),
to_address=required_str(data, "to_address"),
amount_mrwk=required_str(data, "amount_mrwk"),
nonce=required_int(data, "nonce"),
memo=optional_str(data, "memo"),
signature_hex=required_str(data, "signature_hex"),
)
except LedgerError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return wallet_transfer_to_dict(transfer)
42 changes: 38 additions & 4 deletions tests/test_wallet_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
submit_wallet_transfer,
wallet_claim_payload,
wallet_link_payload,
wallet_transfer_payload,
)
from app.main import _safe_next_path, _signed_value, _verified_value, create_app
from app.wallets import address_from_public_key_hex, canonical_wallet_json
Expand Down Expand Up @@ -100,6 +101,38 @@ def test_wallet_api_register_lookup_and_transfer(sqlite_url: str) -> None:
assert client.get(f"/api/v1/wallets/{receiver_address}").json()["balance_mrwk"] == "3"


def test_wallet_transfer_api_rejects_replayed_signed_body(sqlite_url: str) -> None:
create_schema(sqlite_url)
sender_key, sender_public, sender_address = _keypair()
_, receiver_public, receiver_address = _keypair()
client = TestClient(create_app(database_url=sqlite_url, webhook_secret="secret"))
_register_wallet(client, sender_public)
_register_wallet(client, receiver_public)
_fund_wallet(sqlite_url, sender_address)
payload = wallet_transfer_payload(
from_address=sender_address,
to_address=receiver_address,
amount_microunits=1_000_000,
nonce=1,
memo="api replay",
)
body = {
"from_address": sender_address,
"to_address": receiver_address,
"amount_mrwk": "1",
"nonce": 1,
"memo": "api replay",
"signature_hex": _sign(sender_key, payload),
}

first = client.post("/api/v1/transfers", json=body)
second = client.post("/api/v1/transfers", json=body)

assert first.status_code == 200
assert second.status_code == 400
assert second.json()["detail"] == "invalid nonce"


@pytest.mark.parametrize(
("body_overrides", "payload_overrides", "expected_detail"),
[
Expand Down Expand Up @@ -269,10 +302,11 @@ def test_wallet_method_boundary_routes_are_hidden_from_openapi(sqlite_url: str)

paths = client.get("/openapi.json").json()["paths"]

assert "post" in paths["/api/v1/wallets/register"]
assert "get" not in paths["/api/v1/wallets/register"]
assert "post" in paths["/api/v1/wallets/link-github"]
assert "get" not in paths["/api/v1/wallets/link-github"]
assert set(paths["/api/v1/wallets/register"]) == {"post"}
assert set(paths["/api/v1/wallets/link-github"]) == {"post"}
assert set(paths["/api/v1/wallets/{address}"]) == {"get"}
assert set(paths["/api/v1/github/claim"]) == {"post"}
assert set(paths["/api/v1/transfers"]) == {"post"}


def test_wallet_api_malformed_transfer_requests_return_4xx(sqlite_url: str) -> None:
Expand Down
Loading