Skip to content

Commit 8f1f658

Browse files
committed
fix: preserve OAuth endpoint query params
1 parent 6d0c160 commit 8f1f658

2 files changed

Lines changed: 65 additions & 8 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
import secrets
1010
import string
1111
import time
12-
from collections.abc import AsyncGenerator, Awaitable, Callable
12+
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
1313
from dataclasses import dataclass, field
1414
from typing import Any, Protocol
15-
from urllib.parse import quote, urlencode, urljoin, urlparse
15+
from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlparse, urlunparse
1616

1717
import anyio
1818
import httpx
@@ -53,6 +53,13 @@
5353
logger = logging.getLogger(__name__)
5454

5555

56+
def _append_url_query_params(url: str, params: Mapping[str, str]) -> str:
57+
parsed = urlparse(url)
58+
query_params = parse_qsl(parsed.query, keep_blank_values=True)
59+
query_params.extend(params.items())
60+
return urlunparse(parsed._replace(query=urlencode(query_params)))
61+
62+
5663
class PKCEParameters(BaseModel):
5764
"""PKCE (Proof Key for Code Exchange) parameters."""
5865

@@ -327,14 +334,17 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
327334

328335
if not self.context.client_info:
329336
raise OAuthFlowError("No client info available for authorization") # pragma: no cover
337+
client_id = self.context.client_info.client_id
338+
if not client_id:
339+
raise OAuthFlowError("No client ID available for authorization") # pragma: no cover
330340

331341
# Generate PKCE parameters
332342
pkce_params = PKCEParameters.generate()
333343
state = secrets.token_urlsafe(32)
334344

335-
auth_params = {
345+
auth_params: dict[str, str] = {
336346
"response_type": "code",
337-
"client_id": self.context.client_info.client_id,
347+
"client_id": client_id,
338348
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
339349
"state": state,
340350
"code_challenge": pkce_params.code_challenge,
@@ -345,15 +355,16 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
345355
if self.context.should_include_resource_param(self.context.protocol_version):
346356
auth_params["resource"] = self.context.get_resource_url() # RFC 8707
347357

348-
if self.context.client_metadata.scope: # pragma: no branch
349-
auth_params["scope"] = self.context.client_metadata.scope
358+
scope = self.context.client_metadata.scope
359+
if scope: # pragma: no branch
360+
auth_params["scope"] = scope
350361

351362
# OIDC requires prompt=consent when offline_access is requested
352363
# https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
353-
if "offline_access" in self.context.client_metadata.scope.split():
364+
if "offline_access" in scope.split():
354365
auth_params["prompt"] = "consent"
355366

356-
authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
367+
authorization_url = _append_url_query_params(auth_endpoint, auth_params)
357368
await self.context.redirect_handler(authorization_url)
358369

359370
# Wait for callback

tests/client/test_auth.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,52 @@ async def test_token_exchange_request_authorization_code(self, oauth_provider: O
606606
assert "client_id=test_client" in content
607607
assert "client_secret=test_secret" in content
608608

609+
@pytest.mark.anyio
610+
async def test_authorization_endpoint_preserves_existing_query_params(
611+
self, oauth_provider: OAuthClientProvider
612+
):
613+
"""Authorization endpoint query params should survive OAuth parameter injection."""
614+
captured_auth_url: str | None = None
615+
captured_state: str | None = None
616+
617+
async def redirect_handler(url: str) -> None:
618+
nonlocal captured_auth_url, captured_state
619+
captured_auth_url = url
620+
captured_state = parse_qs(urlparse(url).query)["state"][0]
621+
622+
async def callback_handler() -> tuple[str, str | None]:
623+
return "test_auth_code", captured_state
624+
625+
oauth_provider.context.redirect_handler = redirect_handler
626+
oauth_provider.context.callback_handler = callback_handler
627+
oauth_provider.context.oauth_metadata = OAuthMetadata(
628+
issuer=AnyHttpUrl("https://test.salesforce.com"),
629+
authorization_endpoint=AnyHttpUrl(
630+
"https://test.salesforce.com/services/oauth2/authorize?prompt=select_account"
631+
),
632+
token_endpoint=AnyHttpUrl("https://test.salesforce.com/services/oauth2/token"),
633+
)
634+
oauth_provider.context.client_info = OAuthClientInformationFull(
635+
client_id="test_client",
636+
client_secret="test_secret",
637+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
638+
)
639+
640+
auth_code, code_verifier = await oauth_provider._perform_authorization_code_grant()
641+
642+
assert auth_code == "test_auth_code"
643+
assert code_verifier
644+
assert captured_auth_url is not None
645+
parsed = urlparse(captured_auth_url)
646+
params = parse_qs(parsed.query)
647+
assert parsed.scheme == "https"
648+
assert parsed.netloc == "test.salesforce.com"
649+
assert parsed.path == "/services/oauth2/authorize"
650+
assert params["prompt"] == ["select_account"]
651+
assert params["response_type"] == ["code"]
652+
assert params["client_id"] == ["test_client"]
653+
assert params["redirect_uri"] == ["http://localhost:3030/callback"]
654+
609655
@pytest.mark.anyio
610656
async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken):
611657
"""Test refresh token request building."""

0 commit comments

Comments
 (0)