diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 01bcc8234..90335d43e 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -9,10 +9,10 @@ import secrets import string import time -from collections.abc import AsyncGenerator, Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping from dataclasses import dataclass, field from typing import Any, Protocol -from urllib.parse import quote, urlencode, urljoin, urlparse +from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlparse, urlunparse import anyio import httpx @@ -54,6 +54,16 @@ logger = logging.getLogger(__name__) +def build_authorization_url(auth_endpoint: str, auth_params: Mapping[str, str | None]) -> str: + """Append OAuth authorization parameters to an endpoint that may already include query params.""" + parsed_endpoint = urlparse(auth_endpoint) + query_params = [ + *parse_qsl(parsed_endpoint.query, keep_blank_values=True), + *auth_params.items(), + ] + return urlunparse(parsed_endpoint._replace(query=urlencode(query_params))) + + class PKCEParameters(BaseModel): """PKCE (Proof Key for Code Exchange) parameters.""" @@ -352,7 +362,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: if "offline_access" in self.context.client_metadata.scope.split(): auth_params["prompt"] = "consent" - authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" + authorization_url = build_authorization_url(auth_endpoint, auth_params) await self.context.redirect_handler(authorization_url) # Wait for callback diff --git a/tests/interaction/auth/test_authorize_token.py b/tests/interaction/auth/test_authorize_token.py index cb8524c09..fc45f7093 100644 --- a/tests/interaction/auth/test_authorize_token.py +++ b/tests/interaction/auth/test_authorize_token.py @@ -341,6 +341,42 @@ async def test_scope_is_selected_from_the_www_authenticate_challenge_over_prm_me assert json.loads(register.content)["scope"] == "from-header" +@requirement("client-auth:resource-parameter") +async def test_authorization_endpoint_existing_query_params_are_preserved() -> None: + """Authorization metadata endpoints may include provider-required query params.""" + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + override = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize?prompt=select_account"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + ) + serve = {ASM_PATH: override.model_dump_json(exclude_none=True).encode()} + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + app_shim=lambda app: shimmed_app(app, serve=serve), + ) as (client, headless): + await client.list_tools() + + assert headless.authorize_url is not None + split_url = urlsplit(headless.authorize_url) + assert split_url.path == "/authorize" + assert split_url.query.count("?") == 0 + + params = authorize_params(headless.authorize_url) + assert params["prompt"] == "select_account" + assert params["response_type"] == "code" + assert params["client_id"] != "" + assert params["redirect_uri"] == REDIRECT_URI + + @requirement("client-auth:pkce:refuse-if-unsupported") async def test_pkce_is_still_sent_when_as_metadata_omits_code_challenge_methods_supported() -> None: """AS metadata without `code_challenge_methods_supported` does not stop the client sending PKCE.