Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import secrets
import string
import time
from collections.abc import AsyncGenerator, Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
from dataclasses import dataclass, field
from typing import Any, Protocol
from urllib.parse import quote, urlencode, urljoin, urlparse
from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlparse, urlunparse

import anyio
import httpx
Expand Down Expand Up @@ -54,6 +54,16 @@
logger = logging.getLogger(__name__)


def build_authorization_url(auth_endpoint: str, auth_params: Mapping[str, str | None]) -> str:
"""Append OAuth authorization parameters to an endpoint that may already include query params."""
parsed_endpoint = urlparse(auth_endpoint)
query_params = [
*parse_qsl(parsed_endpoint.query, keep_blank_values=True),
*auth_params.items(),
]
return urlunparse(parsed_endpoint._replace(query=urlencode(query_params)))


class PKCEParameters(BaseModel):
"""PKCE (Proof Key for Code Exchange) parameters."""

Expand Down Expand Up @@ -352,7 +362,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
if "offline_access" in self.context.client_metadata.scope.split():
auth_params["prompt"] = "consent"

authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
authorization_url = build_authorization_url(auth_endpoint, auth_params)
await self.context.redirect_handler(authorization_url)

# Wait for callback
Expand Down
36 changes: 36 additions & 0 deletions tests/interaction/auth/test_authorize_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,42 @@ async def test_scope_is_selected_from_the_www_authenticate_challenge_over_prm_me
assert json.loads(register.content)["scope"] == "from-header"


@requirement("client-auth:resource-parameter")
async def test_authorization_endpoint_existing_query_params_are_preserved() -> None:
"""Authorization metadata endpoints may include provider-required query params."""
provider = InMemoryAuthorizationServerProvider()
server = Server("guarded", on_list_tools=list_tools)
override = OAuthMetadata(
issuer=AnyHttpUrl(f"{BASE_URL}/"),
authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize?prompt=select_account"),
token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"),
registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"),
scopes_supported=["mcp"],
grant_types_supported=["authorization_code", "refresh_token"],
code_challenge_methods_supported=["S256"],
)
serve = {ASM_PATH: override.model_dump_json(exclude_none=True).encode()}

with anyio.fail_after(5):
async with connect_with_oauth(
server,
provider=provider,
app_shim=lambda app: shimmed_app(app, serve=serve),
) as (client, headless):
await client.list_tools()

assert headless.authorize_url is not None
split_url = urlsplit(headless.authorize_url)
assert split_url.path == "/authorize"
assert split_url.query.count("?") == 0

params = authorize_params(headless.authorize_url)
assert params["prompt"] == "select_account"
assert params["response_type"] == "code"
assert params["client_id"] != ""
assert params["redirect_uri"] == REDIRECT_URI


@requirement("client-auth:pkce:refuse-if-unsupported")
async def test_pkce_is_still_sent_when_as_metadata_omits_code_challenge_methods_supported() -> None:
"""AS metadata without `code_challenge_methods_supported` does not stop the client sending PKCE.
Expand Down
Loading