From 2205c377dc42efd6808773d25dbfc22ceb88732c Mon Sep 17 00:00:00 2001 From: go165 <196723798+go165@users.noreply.github.com> Date: Mon, 15 Jun 2026 12:47:38 +0800 Subject: [PATCH] fix(auth): preserve user-agent through oauth flow --- .../auth/extensions/client_credentials.py | 48 ++++++++++++------- src/mcp/client/auth/oauth2.py | 48 ++++++++++++------- src/mcp/client/auth/utils.py | 19 ++++++-- .../extensions/test_client_credentials.py | 21 ++++++++ tests/interaction/auth/_harness.py | 7 ++- tests/interaction/auth/test_flow.py | 28 +++++++++++ 6 files changed, 134 insertions(+), 37 deletions(-) diff --git a/src/mcp/client/auth/extensions/client_credentials.py b/src/mcp/client/auth/extensions/client_credentials.py index cb6dafb407..b8d724e306 100644 --- a/src/mcp/client/auth/extensions/client_credentials.py +++ b/src/mcp/client/auth/extensions/client_credentials.py @@ -9,7 +9,7 @@ import time import warnings -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Mapping from typing import Any, Literal from uuid import uuid4 @@ -82,20 +82,22 @@ async def _initialize(self) -> None: self.context.client_info = self._fixed_client_info self._initialized = True - async def _perform_authorization(self) -> httpx.Request: + async def _perform_authorization(self, headers: Mapping[str, str] | None = None) -> httpx.Request: """Perform client_credentials authorization.""" - return await self._exchange_token_client_credentials() + return await self._exchange_token_client_credentials(headers=headers) - async def _exchange_token_client_credentials(self) -> httpx.Request: + async def _exchange_token_client_credentials(self, headers: Mapping[str, str] | None = None) -> httpx.Request: """Build token exchange request for client_credentials grant.""" token_data: dict[str, Any] = { "grant_type": "client_credentials", } - headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + request_headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + if headers: + request_headers.update(headers) # Use standard auth methods (client_secret_basic, client_secret_post, none) - token_data, headers = self.context.prepare_token_auth(token_data, headers) + token_data, request_headers = self.context.prepare_token_auth(token_data, request_headers) if self.context.should_include_resource_param(self.context.protocol_version): token_data["resource"] = self.context.get_resource_url() @@ -104,7 +106,7 @@ async def _exchange_token_client_credentials(self) -> httpx.Request: token_data["scope"] = self.context.client_metadata.scope token_url = self._get_token_endpoint() - return httpx.Request("POST", token_url, data=token_data, headers=headers) + return httpx.Request("POST", token_url, data=token_data, headers=request_headers) def static_assertion_provider(token: str) -> Callable[[str], Awaitable[str]]: @@ -296,9 +298,9 @@ async def _initialize(self) -> None: self.context.client_info = self._fixed_client_info self._initialized = True - async def _perform_authorization(self) -> httpx.Request: + async def _perform_authorization(self, headers: Mapping[str, str] | None = None) -> httpx.Request: """Perform client_credentials authorization with private_key_jwt.""" - return await self._exchange_token_client_credentials() + return await self._exchange_token_client_credentials(headers=headers) async def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) -> None: """Add JWT assertion for client authentication to token endpoint parameters.""" @@ -314,13 +316,15 @@ async def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) -> token_data["client_assertion"] = assertion token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" - async def _exchange_token_client_credentials(self) -> httpx.Request: + async def _exchange_token_client_credentials(self, headers: Mapping[str, str] | None = None) -> httpx.Request: """Build token exchange request for client_credentials grant with private_key_jwt.""" token_data: dict[str, Any] = { "grant_type": "client_credentials", } - headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + request_headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + if headers: + request_headers.update(headers) # Add JWT client authentication (RFC 7523 Section 2.2) await self._add_client_authentication_jwt(token_data=token_data) @@ -332,7 +336,7 @@ async def _exchange_token_client_credentials(self) -> httpx.Request: token_data["scope"] = self.context.client_metadata.scope token_url = self._get_token_endpoint() - return httpx.Request("POST", token_url, data=token_data, headers=headers) + return httpx.Request("POST", token_url, data=token_data, headers=request_headers) class JWTParameters(BaseModel): @@ -419,21 +423,33 @@ def __init__( self.jwt_parameters = jwt_parameters async def _exchange_token_authorization_code( - self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None + self, + auth_code: str, + code_verifier: str, + *, + token_data: dict[str, Any] | None = None, + headers: Mapping[str, str] | None = None, ) -> httpx.Request: # pragma: no cover """Build token exchange request for authorization_code flow.""" token_data = token_data or {} if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt": self._add_client_authentication_jwt(token_data=token_data) - return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data) + return await super()._exchange_token_authorization_code( + auth_code, + code_verifier, + token_data=token_data, + headers=headers, + ) - async def _perform_authorization(self) -> httpx.Request: # pragma: no cover + async def _perform_authorization( + self, headers: Mapping[str, str] | None = None + ) -> httpx.Request: # pragma: no cover """Perform the authorization flow.""" if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types: token_request = await self._exchange_token_jwt_bearer() return token_request else: - return await super()._perform_authorization() + return await super()._perform_authorization(headers=headers) def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # pragma: no cover """Add JWT assertion for client authentication to token endpoint parameters.""" diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 01bcc82347..d297ad6198 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -9,7 +9,7 @@ import secrets import string import time -from collections.abc import AsyncGenerator, Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping from dataclasses import dataclass, field from typing import Any, Protocol from urllib.parse import quote, urlencode, urljoin, urlparse @@ -303,10 +303,10 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> f"Protected Resource Metadata request failed: {response.status_code}" ) # pragma: no cover - async def _perform_authorization(self) -> httpx.Request: + async def _perform_authorization(self, headers: Mapping[str, str] | None = None) -> httpx.Request: """Perform the authorization flow.""" auth_code, code_verifier = await self._perform_authorization_code_grant() - token_request = await self._exchange_token_authorization_code(auth_code, code_verifier) + token_request = await self._exchange_token_authorization_code(auth_code, code_verifier, headers=headers) return token_request async def _perform_authorization_code_grant(self) -> tuple[str, str]: @@ -376,7 +376,12 @@ def _get_token_endpoint(self) -> str: return token_url async def _exchange_token_authorization_code( - self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = {} + self, + auth_code: str, + code_verifier: str, + *, + token_data: dict[str, Any] | None = {}, + headers: Mapping[str, str] | None = None, ) -> httpx.Request: """Build token exchange request for authorization_code flow.""" if self.context.client_metadata.redirect_uris is None: @@ -401,10 +406,12 @@ async def _exchange_token_authorization_code( token_data["resource"] = self.context.get_resource_url() # RFC 8707 # Prepare authentication based on preferred method - headers = {"Content-Type": "application/x-www-form-urlencoded"} - token_data, headers = self.context.prepare_token_auth(token_data, headers) + request_headers = {"Content-Type": "application/x-www-form-urlencoded"} + if headers: + request_headers.update(headers) + token_data, request_headers = self.context.prepare_token_auth(token_data, request_headers) - return httpx.Request("POST", token_url, data=token_data, headers=headers) + return httpx.Request("POST", token_url, data=token_data, headers=request_headers) async def _handle_token_response(self, response: httpx.Response) -> None: """Handle token exchange response.""" @@ -421,7 +428,7 @@ async def _handle_token_response(self, response: httpx.Response) -> None: self.context.update_token_expiry(token_response) await self.context.storage.set_tokens(token_response) - async def _refresh_token(self) -> httpx.Request: + async def _refresh_token(self, headers: Mapping[str, str] | None = None) -> httpx.Request: """Build token refresh request.""" if not self.context.current_tokens or not self.context.current_tokens.refresh_token: raise OAuthTokenError("No refresh token available") # pragma: no cover @@ -446,10 +453,12 @@ async def _refresh_token(self) -> httpx.Request: refresh_data["resource"] = self.context.get_resource_url() # RFC 8707 # Prepare authentication based on preferred method - headers = {"Content-Type": "application/x-www-form-urlencoded"} - refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers) + request_headers = {"Content-Type": "application/x-www-form-urlencoded"} + if headers: + request_headers.update(headers) + refresh_data, request_headers = self.context.prepare_token_auth(refresh_data, request_headers) - return httpx.Request("POST", token_url, data=refresh_data, headers=headers) + return httpx.Request("POST", token_url, data=refresh_data, headers=request_headers) async def _handle_refresh_response(self, response: httpx.Response) -> bool: """Handle token refresh response. Returns True if successful.""" @@ -483,6 +492,11 @@ def _add_auth_header(self, request: httpx.Request) -> None: if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" + def _auth_flow_request_headers(self, request: httpx.Request) -> dict[str, str]: + """Headers that are safe to carry from the MCP request into OAuth requests.""" + user_agent = request.headers.get("User-Agent") + return {"User-Agent": user_agent} if user_agent else {} + async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: content = await response.aread() metadata = OAuthMetadata.model_validate_json(content) @@ -510,10 +524,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Capture protocol version from request headers self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) + auth_flow_headers = self._auth_flow_request_headers(request) if not self.context.is_token_valid() and self.context.can_refresh_token(): # Try to refresh token - refresh_request = await self._refresh_token() + refresh_request = await self._refresh_token(headers=auth_flow_headers) refresh_response = yield refresh_request if not await self._handle_refresh_response(refresh_response): @@ -537,7 +552,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. ) for url in prm_discovery_urls: # pragma: no branch - discovery_request = create_oauth_metadata_request(url) + discovery_request = create_oauth_metadata_request(url, headers=auth_flow_headers) discovery_response = yield discovery_request # sending request @@ -563,7 +578,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers) for url in asm_discovery_urls: # pragma: no branch - oauth_metadata_request = create_oauth_metadata_request(url) + oauth_metadata_request = create_oauth_metadata_request(url, headers=auth_flow_headers) oauth_metadata_response = yield oauth_metadata_request ok, asm = await handle_auth_metadata_response(oauth_metadata_response) @@ -602,6 +617,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self.context.oauth_metadata, self.context.client_metadata, self.context.get_authorization_base_url(self.context.server_url), + headers=auth_flow_headers, ) registration_response = yield registration_request client_information = await handle_registration_response(registration_response) @@ -609,7 +625,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. await self.context.storage.set_client_info(client_information) # Step 5: Perform authorization and complete token exchange - token_response = yield await self._perform_authorization() + token_response = yield await self._perform_authorization(headers=auth_flow_headers) await self._handle_token_response(token_response) except Exception: logger.exception("OAuth flow error") @@ -634,7 +650,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. ) # Step 2b: Perform (re-)authorization and token exchange - token_response = yield await self._perform_authorization() + token_response = yield await self._perform_authorization(headers=auth_flow_headers) await self._handle_token_response(token_response) except Exception: # pragma: no cover logger.exception("OAuth flow error") diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index 780a24e859..7b78727b43 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -1,4 +1,5 @@ import re +from collections.abc import Mapping from urllib.parse import urljoin, urlparse from httpx import Request, Response @@ -211,12 +212,18 @@ async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuth return True, None -def create_oauth_metadata_request(url: str) -> Request: - return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) +def create_oauth_metadata_request(url: str, headers: Mapping[str, str] | None = None) -> Request: + request_headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} + if headers: + request_headers.update(headers) + return Request("GET", url, headers=request_headers) def create_client_registration_request( - auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str + auth_server_metadata: OAuthMetadata | None, + client_metadata: OAuthClientMetadata, + auth_base_url: str, + headers: Mapping[str, str] | None = None, ) -> Request: """Build a client registration request.""" @@ -227,7 +234,11 @@ def create_client_registration_request( registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) - return Request("POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}) + request_headers = {"Content-Type": "application/json"} + if headers: + request_headers.update(headers) + + return Request("POST", registration_url, json=registration_data, headers=request_headers) async def handle_registration_response(response: Response) -> OAuthClientInformationFull: diff --git a/tests/client/auth/extensions/test_client_credentials.py b/tests/client/auth/extensions/test_client_credentials.py index 09760f4530..66b37d8e19 100644 --- a/tests/client/auth/extensions/test_client_credentials.py +++ b/tests/client/auth/extensions/test_client_credentials.py @@ -252,6 +252,27 @@ async def test_exchange_token_client_credentials(self, mock_storage: MockTokenSt assert "scope=read write" in content assert "resource=https://api.example.com/v1/mcp" in content + @pytest.mark.anyio + async def test_exchange_token_preserves_user_agent_header(self, mock_storage: MockTokenStorage): + """Test that client_credentials token requests preserve a caller User-Agent.""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com/v1/mcp", + storage=mock_storage, + client_id="test-client-id", + client_secret="test-client-secret", + ) + await provider._initialize() + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://api.example.com"), + authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://api.example.com/token"), + ) + + request = await provider._perform_authorization(headers={"User-Agent": "mcp-python-sdk/issue-1664"}) + + assert request.headers["User-Agent"] == "mcp-python-sdk/issue-1664" + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" + @pytest.mark.anyio async def test_exchange_token_client_secret_post_includes_client_id(self, mock_storage: MockTokenStorage): """Test that client_secret_post includes both client_id and client_secret in body (RFC 6749 ยง2.3.1).""" diff --git a/tests/interaction/auth/_harness.py b/tests/interaction/auth/_harness.py index d013364f33..862cd8eb82 100644 --- a/tests/interaction/auth/_harness.py +++ b/tests/interaction/auth/_harness.py @@ -394,6 +394,7 @@ async def connect_with_oauth( client_metadata_url: str | None = None, headless: HeadlessOAuth | None = None, auth: httpx.Auth | None = None, + headers: Mapping[str, str] | None = None, verify_tokens: bool = True, app_shim: Callable[[ASGIApp], ASGIApp] | None = None, on_request: Callable[[httpx.Request], None] | None = None, @@ -455,7 +456,11 @@ async def hook(request: httpx.Request) -> None: await stack.enter_async_context(server.session_manager.run()) http_client = await stack.enter_async_context( httpx.AsyncClient( - transport=StreamingASGITransport(app), base_url=BASE_URL, auth=oauth, event_hooks=event_hooks + transport=StreamingASGITransport(app), + base_url=BASE_URL, + auth=oauth, + headers=headers, + event_hooks=event_hooks, ) ) headless.bind(http_client) diff --git a/tests/interaction/auth/test_flow.py b/tests/interaction/auth/test_flow.py index 968fc5f980..22c053aa50 100644 --- a/tests/interaction/auth/test_flow.py +++ b/tests/interaction/auth/test_flow.py @@ -118,6 +118,34 @@ async def test_an_unauthenticated_request_is_challenged_then_the_full_oauth_flow assert storage.tokens.access_token in provider.access_tokens +@requirement("client-transport:http:custom-headers") +async def test_oauth_flow_preserves_custom_user_agent_on_auth_requests() -> None: + """OAuth requests keep the caller's User-Agent from the Streamable HTTP client.""" + requests: list[httpx.Request] = [] + server = Server("guarded", on_list_tools=list_tools) + user_agent = "mcp-python-sdk/issue-1664" + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=InMemoryAuthorizationServerProvider(), + headers={"User-Agent": user_agent}, + on_request=requests.append, + ) as (client, _): + await client.list_tools() + + auth_paths = { + "/.well-known/oauth-protected-resource/mcp", + "/.well-known/oauth-authorization-server", + "/register", + "/token", + } + auth_requests = [request for request in requests if request.url.path in auth_paths] + + assert {request.url.path for request in auth_requests} == auth_paths + assert all(request.headers["user-agent"] == user_agent for request in auth_requests) + + @requirement("hosting:auth:authinfo-propagates") async def test_the_access_token_reaches_the_tool_handler_via_get_access_token() -> None: """A tool handler reads the request's access token through `get_access_token()`."""