Skip to content

Commit 6bd4fef

Browse files
committed
fix(auth): preserve user-agent through oauth flow
1 parent cf110e3 commit 6bd4fef

4 files changed

Lines changed: 81 additions & 21 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import secrets
1010
import string
1111
import time
12-
from collections.abc import AsyncGenerator, Awaitable, Callable
12+
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
1313
from dataclasses import dataclass, field
1414
from typing import Any, Protocol
1515
from urllib.parse import quote, urlencode, urljoin, urlparse
@@ -303,10 +303,10 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
303303
f"Protected Resource Metadata request failed: {response.status_code}"
304304
) # pragma: no cover
305305

306-
async def _perform_authorization(self) -> httpx.Request:
306+
async def _perform_authorization(self, headers: Mapping[str, str] | None = None) -> httpx.Request:
307307
"""Perform the authorization flow."""
308308
auth_code, code_verifier = await self._perform_authorization_code_grant()
309-
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
309+
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier, headers=headers)
310310
return token_request
311311

312312
async def _perform_authorization_code_grant(self) -> tuple[str, str]:
@@ -376,7 +376,12 @@ def _get_token_endpoint(self) -> str:
376376
return token_url
377377

378378
async def _exchange_token_authorization_code(
379-
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = {}
379+
self,
380+
auth_code: str,
381+
code_verifier: str,
382+
*,
383+
token_data: dict[str, Any] | None = {},
384+
headers: Mapping[str, str] | None = None,
380385
) -> httpx.Request:
381386
"""Build token exchange request for authorization_code flow."""
382387
if self.context.client_metadata.redirect_uris is None:
@@ -401,10 +406,12 @@ async def _exchange_token_authorization_code(
401406
token_data["resource"] = self.context.get_resource_url() # RFC 8707
402407

403408
# Prepare authentication based on preferred method
404-
headers = {"Content-Type": "application/x-www-form-urlencoded"}
405-
token_data, headers = self.context.prepare_token_auth(token_data, headers)
409+
request_headers = {"Content-Type": "application/x-www-form-urlencoded"}
410+
if headers:
411+
request_headers.update(headers)
412+
token_data, request_headers = self.context.prepare_token_auth(token_data, request_headers)
406413

407-
return httpx.Request("POST", token_url, data=token_data, headers=headers)
414+
return httpx.Request("POST", token_url, data=token_data, headers=request_headers)
408415

409416
async def _handle_token_response(self, response: httpx.Response) -> None:
410417
"""Handle token exchange response."""
@@ -421,7 +428,7 @@ async def _handle_token_response(self, response: httpx.Response) -> None:
421428
self.context.update_token_expiry(token_response)
422429
await self.context.storage.set_tokens(token_response)
423430

424-
async def _refresh_token(self) -> httpx.Request:
431+
async def _refresh_token(self, headers: Mapping[str, str] | None = None) -> httpx.Request:
425432
"""Build token refresh request."""
426433
if not self.context.current_tokens or not self.context.current_tokens.refresh_token:
427434
raise OAuthTokenError("No refresh token available") # pragma: no cover
@@ -446,10 +453,12 @@ async def _refresh_token(self) -> httpx.Request:
446453
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707
447454

448455
# Prepare authentication based on preferred method
449-
headers = {"Content-Type": "application/x-www-form-urlencoded"}
450-
refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)
456+
request_headers = {"Content-Type": "application/x-www-form-urlencoded"}
457+
if headers:
458+
request_headers.update(headers)
459+
refresh_data, request_headers = self.context.prepare_token_auth(refresh_data, request_headers)
451460

452-
return httpx.Request("POST", token_url, data=refresh_data, headers=headers)
461+
return httpx.Request("POST", token_url, data=refresh_data, headers=request_headers)
453462

454463
async def _handle_refresh_response(self, response: httpx.Response) -> bool:
455464
"""Handle token refresh response. Returns True if successful."""
@@ -483,6 +492,11 @@ def _add_auth_header(self, request: httpx.Request) -> None:
483492
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
484493
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
485494

495+
def _auth_flow_request_headers(self, request: httpx.Request) -> dict[str, str]:
496+
"""Headers that are safe to carry from the MCP request into OAuth requests."""
497+
user_agent = request.headers.get("User-Agent")
498+
return {"User-Agent": user_agent} if user_agent else {}
499+
486500
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
487501
content = await response.aread()
488502
metadata = OAuthMetadata.model_validate_json(content)
@@ -510,10 +524,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
510524

511525
# Capture protocol version from request headers
512526
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
527+
auth_flow_headers = self._auth_flow_request_headers(request)
513528

514529
if not self.context.is_token_valid() and self.context.can_refresh_token():
515530
# Try to refresh token
516-
refresh_request = await self._refresh_token()
531+
refresh_request = await self._refresh_token(headers=auth_flow_headers)
517532
refresh_response = yield refresh_request
518533

519534
if not await self._handle_refresh_response(refresh_response):
@@ -537,7 +552,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
537552
)
538553

539554
for url in prm_discovery_urls: # pragma: no branch
540-
discovery_request = create_oauth_metadata_request(url)
555+
discovery_request = create_oauth_metadata_request(url, headers=auth_flow_headers)
541556

542557
discovery_response = yield discovery_request # sending request
543558

@@ -563,7 +578,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
563578

564579
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
565580
for url in asm_discovery_urls: # pragma: no branch
566-
oauth_metadata_request = create_oauth_metadata_request(url)
581+
oauth_metadata_request = create_oauth_metadata_request(url, headers=auth_flow_headers)
567582
oauth_metadata_response = yield oauth_metadata_request
568583

569584
ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
@@ -602,14 +617,15 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
602617
self.context.oauth_metadata,
603618
self.context.client_metadata,
604619
self.context.get_authorization_base_url(self.context.server_url),
620+
headers=auth_flow_headers,
605621
)
606622
registration_response = yield registration_request
607623
client_information = await handle_registration_response(registration_response)
608624
self.context.client_info = client_information
609625
await self.context.storage.set_client_info(client_information)
610626

611627
# Step 5: Perform authorization and complete token exchange
612-
token_response = yield await self._perform_authorization()
628+
token_response = yield await self._perform_authorization(headers=auth_flow_headers)
613629
await self._handle_token_response(token_response)
614630
except Exception:
615631
logger.exception("OAuth flow error")
@@ -634,7 +650,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
634650
)
635651

636652
# Step 2b: Perform (re-)authorization and token exchange
637-
token_response = yield await self._perform_authorization()
653+
token_response = yield await self._perform_authorization(headers=auth_flow_headers)
638654
await self._handle_token_response(token_response)
639655
except Exception: # pragma: no cover
640656
logger.exception("OAuth flow error")

src/mcp/client/auth/utils.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
from collections.abc import Mapping
23
from urllib.parse import urljoin, urlparse
34

45
from httpx import Request, Response
@@ -211,12 +212,18 @@ async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuth
211212
return True, None
212213

213214

214-
def create_oauth_metadata_request(url: str) -> Request:
215-
return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
215+
def create_oauth_metadata_request(url: str, headers: Mapping[str, str] | None = None) -> Request:
216+
request_headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
217+
if headers:
218+
request_headers.update(headers)
219+
return Request("GET", url, headers=request_headers)
216220

217221

218222
def create_client_registration_request(
219-
auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str
223+
auth_server_metadata: OAuthMetadata | None,
224+
client_metadata: OAuthClientMetadata,
225+
auth_base_url: str,
226+
headers: Mapping[str, str] | None = None,
220227
) -> Request:
221228
"""Build a client registration request."""
222229

@@ -227,7 +234,11 @@ def create_client_registration_request(
227234

228235
registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
229236

230-
return Request("POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"})
237+
request_headers = {"Content-Type": "application/json"}
238+
if headers:
239+
request_headers.update(headers)
240+
241+
return Request("POST", registration_url, json=registration_data, headers=request_headers)
231242

232243

233244
async def handle_registration_response(response: Response) -> OAuthClientInformationFull:

tests/interaction/auth/_harness.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ async def connect_with_oauth(
394394
client_metadata_url: str | None = None,
395395
headless: HeadlessOAuth | None = None,
396396
auth: httpx.Auth | None = None,
397+
headers: Mapping[str, str] | None = None,
397398
verify_tokens: bool = True,
398399
app_shim: Callable[[ASGIApp], ASGIApp] | None = None,
399400
on_request: Callable[[httpx.Request], None] | None = None,
@@ -455,7 +456,11 @@ async def hook(request: httpx.Request) -> None:
455456
await stack.enter_async_context(server.session_manager.run())
456457
http_client = await stack.enter_async_context(
457458
httpx.AsyncClient(
458-
transport=StreamingASGITransport(app), base_url=BASE_URL, auth=oauth, event_hooks=event_hooks
459+
transport=StreamingASGITransport(app),
460+
base_url=BASE_URL,
461+
auth=oauth,
462+
headers=headers,
463+
event_hooks=event_hooks,
459464
)
460465
)
461466
headless.bind(http_client)

tests/interaction/auth/test_flow.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,34 @@ async def test_an_unauthenticated_request_is_challenged_then_the_full_oauth_flow
118118
assert storage.tokens.access_token in provider.access_tokens
119119

120120

121+
@requirement("client-transport:http:custom-headers")
122+
async def test_oauth_flow_preserves_custom_user_agent_on_auth_requests() -> None:
123+
"""OAuth requests keep the caller's User-Agent from the Streamable HTTP client."""
124+
requests: list[httpx.Request] = []
125+
server = Server("guarded", on_list_tools=list_tools)
126+
user_agent = "mcp-python-sdk/issue-1664"
127+
128+
with anyio.fail_after(5):
129+
async with connect_with_oauth(
130+
server,
131+
provider=InMemoryAuthorizationServerProvider(),
132+
headers={"User-Agent": user_agent},
133+
on_request=requests.append,
134+
) as (client, _):
135+
await client.list_tools()
136+
137+
auth_paths = {
138+
"/.well-known/oauth-protected-resource/mcp",
139+
"/.well-known/oauth-authorization-server",
140+
"/register",
141+
"/token",
142+
}
143+
auth_requests = [request for request in requests if request.url.path in auth_paths]
144+
145+
assert {request.url.path for request in auth_requests} == auth_paths
146+
assert all(request.headers["user-agent"] == user_agent for request in auth_requests)
147+
148+
121149
@requirement("hosting:auth:authinfo-propagates")
122150
async def test_the_access_token_reaches_the_tool_handler_via_get_access_token() -> None:
123151
"""A tool handler reads the request's access token through `get_access_token()`."""

0 commit comments

Comments
 (0)