Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import Annotated, Any
from urllib.parse import urlencode, urlsplit, urlunsplit
from urllib.parse import unquote, urlencode, urlsplit, urlunsplit

import httpx
from fastapi import Depends, FastAPI, Form, HTTPException, Query, Request
Expand Down Expand Up @@ -278,13 +278,17 @@ def _oauth_configured(settings: Settings) -> bool:


def _safe_next_path(next_path: str | None) -> str:
decoded_next_path = unquote(next_path) if next_path else ""
if (
not next_path
or not next_path.startswith("/")
or next_path.startswith("//")
or len(next_path) > 2048
or "\\" in next_path
or decoded_next_path.startswith("//")
or "\\" in decoded_next_path
or any(ord(char) < 32 or 127 <= ord(char) < 160 for char in next_path)
or any(ord(char) < 32 or 127 <= ord(char) < 160 for char in decoded_next_path)
):
return "/me"
return next_path
Expand Down Expand Up @@ -1530,6 +1534,14 @@ def output_format_arg() -> str:
raise ValueError("format must be text or json")
return normalized

def optional_repo_selector_arg() -> str | None:
repo = optional_clean_str_arg("repo")
if repo is None:
return None
if len(repo) > 200:
raise ValueError("repo is too long (max 200 characters)")
return repo.lower()

def mcp_issue_number_search_value(query_text: str) -> int | None:
if not query_text.isdigit():
return None
Expand Down Expand Up @@ -1756,8 +1768,11 @@ def optional_bool_arg(field: str, default: bool = False) -> bool:
output_format = output_format_arg()
has_bounty_id = "bounty_id" in args and args.get("bounty_id") is not None
has_issue_number = "issue_number" in args and args.get("issue_number") is not None
repo_selector = optional_repo_selector_arg()
if has_bounty_id and has_issue_number:
raise ValueError("use bounty_id or issue_number, not both")
if repo_selector is not None and not has_issue_number:
raise ValueError("repo can only be used with issue_number")
if has_bounty_id:
bounty = session.get(Bounty, positive_int_arg("bounty_id"))
if bounty is None:
Expand All @@ -1768,12 +1783,19 @@ def optional_bool_arg(field: str, default: bool = False) -> bool:
else work_proof_guidance(bounty)
)
if has_issue_number:
bounties = session.scalars(
select(Bounty)
.where(Bounty.issue_number == positive_int_arg("issue_number"))
.order_by(Bounty.id.desc())
.limit(2)
).all()
issue_number = positive_int_arg("issue_number")
issue_query = select(Bounty).where(Bounty.issue_number == issue_number)
if repo_selector is not None:
issue_query = issue_query.where(Bounty.repo == repo_selector)
bounties = session.scalars(issue_query.order_by(Bounty.id.desc()).limit(2)).all()
if repo_selector is not None and not bounties:
legacy_issue_query = select(Bounty).where(
Bounty.issue_number == issue_number,
func.lower(Bounty.repo) == repo_selector,
)
bounties = session.scalars(
legacy_issue_query.order_by(Bounty.id.desc()).limit(2)
).all()
if not bounties:
return "bounty not found"
if len(bounties) > 1:
Expand Down
5 changes: 4 additions & 1 deletion app/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@
{"name": "get_proof", "description": "Get a public proof by hash"},
{
"name": "submit_work_proof",
"description": "Return submission instructions, optionally for a bounty_id or issue_number",
"description": (
"Return submission instructions for bounty_id or issue_number optionally "
"scoped by repo; supports text or json format"
),
},
]

Expand Down
107 changes: 105 additions & 2 deletions tests/test_api_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from app.db import create_schema, session_scope
from app.ledger.service import close_bounty, create_bounty, ensure_genesis, pay_bounty
from app.main import create_app
from app.models import BountyAttempt, Proof
from app.models import Bounty, BountyAttempt, Proof


def test_health_status_and_bounty_api(sqlite_url: str) -> None:
Expand Down Expand Up @@ -429,7 +429,7 @@ def test_mcp_tools_list_and_call(sqlite_url: str) -> None:
submit_tool = next(
tool for tool in tools["result"]["tools"] if tool["name"] == "submit_work_proof"
)
assert "bounty_id or issue_number" in submit_tool["description"]
assert "issue_number optionally scoped by repo" in submit_tool["description"]
bounty_tool = next(tool for tool in tools["result"]["tools"] if tool["name"] == "get_bounty")
assert "accepted awards" in bounty_tool["description"]
attempt_tool = next(
Expand Down Expand Up @@ -1542,6 +1542,105 @@ def test_mcp_submit_work_proof_reports_unknown_bounty(sqlite_url: str) -> None:
assert result["result"]["content"][0]["text"] == "bounty not found"


def test_mcp_submit_work_proof_scopes_issue_number_by_repo(sqlite_url: str) -> None:
create_schema(sqlite_url)
with session_scope(sqlite_url) as session:
ensure_genesis(session)
create_bounty(
session,
repo="ramimbo/mergework",
issue_number=284,
issue_url="https://github.com/ramimbo/mergework/issues/284",
title="First bounty",
reward_mrwk="100",
acceptance="First acceptance.",
)
target = create_bounty(
session,
repo="example/mergework",
issue_number=284,
issue_url="https://github.com/example/mergework/issues/284",
title="Second bounty",
reward_mrwk="250",
acceptance="Second acceptance.",
)
target_id = target.id

client = TestClient(create_app(database_url=sqlite_url, webhook_secret="secret"))

response = client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"id": 28,
"method": "tools/call",
"params": {
"name": "submit_work_proof",
"arguments": {
"issue_number": 284,
"repo": "Example/MergeWork",
"format": "json",
},
},
},
)

result = response.json()["result"]
structured = result["structuredContent"]
assert json.loads(result["content"][0]["text"]) == structured
assert structured["bounty_id"] == target_id
assert structured["repository"] == "example/mergework"
assert structured["title"] == "Second bounty"
assert structured["reward_mrwk"] == "250"
assert structured["acceptance"] == "Second acceptance."


def test_mcp_submit_work_proof_scopes_legacy_mixed_case_repo(sqlite_url: str) -> None:
create_schema(sqlite_url)
with session_scope(sqlite_url) as session:
ensure_genesis(session)
bounty = Bounty(
repo="Ramimbo/MergeWork",
issue_number=284,
issue_url="https://github.com/ramimbo/mergework/issues/284",
title="Legacy bounty",
reward_microunits=100_000_000,
reserved_microunits=100_000_000,
max_awards=1,
awards_paid=0,
status="open",
acceptance="Legacy rows should still match repo-scoped MCP guidance.",
)
session.add(bounty)
session.flush()
bounty_id = bounty.id

client = TestClient(create_app(database_url=sqlite_url, webhook_secret="secret"))

response = client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"id": 33,
"method": "tools/call",
"params": {
"name": "submit_work_proof",
"arguments": {
"issue_number": 284,
"repo": "ramimbo/mergework",
"format": "json",
},
},
},
)

result = response.json()["result"]
structured = result["structuredContent"]
assert structured["bounty_id"] == bounty_id
assert structured["repository"] == "Ramimbo/MergeWork"
assert structured["title"] == "Legacy bounty"


@pytest.mark.parametrize(
("arguments", "request_id"),
[
Expand All @@ -1552,6 +1651,10 @@ def test_mcp_submit_work_proof_reports_unknown_bounty(sqlite_url: str) -> None:
({"bounty_id": 1, "issue_number": 1}, 25),
({"format": "xml"}, 26),
({"format": 1}, 27),
({"repo": "ramimbo/mergework"}, 29),
({"bounty_id": 1, "repo": "ramimbo/mergework"}, 30),
({"issue_number": 1, "repo": 1}, 31),
({"issue_number": 1, "repo": "a" * 201}, 32),
],
)
def test_mcp_submit_work_proof_rejects_invalid_bounty_selectors(
Expand Down
32 changes: 31 additions & 1 deletion tests/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import hmac
import json
import re
from urllib.parse import parse_qs, urlparse

import pytest
from fastapi.testclient import TestClient
Expand All @@ -26,7 +27,7 @@
register_wallet,
validate_public_url,
)
from app.main import _safe_next_path, _signed_value, create_app
from app.main import _safe_next_path, _signed_value, _verified_value, create_app
from app.models import Bounty, LedgerEntry, Proof, Submission, WebhookEvent
from app.webhooks.github import handle_github_webhook

Expand Down Expand Up @@ -856,7 +857,10 @@ def test_admin_bounty_api_accepts_multi_award_count(
("https://evil.example/me", "/me"),
("//evil.example/me", "/me"),
("/\\evil.example/me", "/me"),
("/%2f%2fevil.example/me", "/me"),
("/%5cevil.example/me", "/me"),
("/me\nLocation: https://evil.example", "/me"),
("/me%0d%0aLocation:%20https://evil.example", "/me"),
("/me" + chr(0x85), "/me"),
("/me\x7f", "/me"),
("/" + ("a" * 2048), "/me"),
Expand All @@ -870,6 +874,32 @@ def test_oauth_next_path_rejects_external_or_headerlike_paths(
assert _safe_next_path(next_path) == expected


@pytest.mark.parametrize(
"next_path",
(
"/%255cevil.example/me",
"/me%250d%250aLocation:%20https://evil.example",
),
)
def test_github_login_stores_safe_default_for_encoded_next_route(
sqlite_url: str, monkeypatch: pytest.MonkeyPatch, next_path: str
) -> None:
monkeypatch.setenv("MERGEWORK_GITHUB_OAUTH_CLIENT_ID", "client-id")
monkeypatch.setenv("MERGEWORK_GITHUB_OAUTH_CLIENT_SECRET", "client-secret")
monkeypatch.setenv("MERGEWORK_COOKIE_SECRET", "test-cookie-secret")
monkeypatch.setenv("MERGEWORK_PUBLIC_BASE_URL", "https://mrwk.example.test")
client = TestClient(create_app(database_url=sqlite_url, webhook_secret="secret"))

response = client.get(f"/auth/github/login?next={next_path}", follow_redirects=False)

assert response.status_code == 302
query = parse_qs(urlparse(response.headers["location"]).query)
state_value = _verified_value(query["state"][0], "test-cookie-secret", 600)
assert state_value is not None
_nonce, stored_next_path = state_value.split(",", 1)
assert stored_next_path == "/me"


def test_amount_parser_rejects_non_finite_values() -> None:
for amount in ("NaN", "Infinity", "-Infinity"):
with pytest.raises(LedgerError, match="invalid MRWK amount"):
Expand Down
Loading