Skip to content

Commit 2205c37

Browse files
committed
fix(auth): preserve user-agent through oauth flow
1 parent cf110e3 commit 2205c37

6 files changed

Lines changed: 134 additions & 37 deletions

File tree

src/mcp/client/auth/extensions/client_credentials.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import time
1111
import warnings
12-
from collections.abc import Awaitable, Callable
12+
from collections.abc import Awaitable, Callable, Mapping
1313
from typing import Any, Literal
1414
from uuid import uuid4
1515

@@ -82,20 +82,22 @@ async def _initialize(self) -> None:
8282
self.context.client_info = self._fixed_client_info
8383
self._initialized = True
8484

85-
async def _perform_authorization(self) -> httpx.Request:
85+
async def _perform_authorization(self, headers: Mapping[str, str] | None = None) -> httpx.Request:
8686
"""Perform client_credentials authorization."""
87-
return await self._exchange_token_client_credentials()
87+
return await self._exchange_token_client_credentials(headers=headers)
8888

89-
async def _exchange_token_client_credentials(self) -> httpx.Request:
89+
async def _exchange_token_client_credentials(self, headers: Mapping[str, str] | None = None) -> httpx.Request:
9090
"""Build token exchange request for client_credentials grant."""
9191
token_data: dict[str, Any] = {
9292
"grant_type": "client_credentials",
9393
}
9494

95-
headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
95+
request_headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
96+
if headers:
97+
request_headers.update(headers)
9698

9799
# Use standard auth methods (client_secret_basic, client_secret_post, none)
98-
token_data, headers = self.context.prepare_token_auth(token_data, headers)
100+
token_data, request_headers = self.context.prepare_token_auth(token_data, request_headers)
99101

100102
if self.context.should_include_resource_param(self.context.protocol_version):
101103
token_data["resource"] = self.context.get_resource_url()
@@ -104,7 +106,7 @@ async def _exchange_token_client_credentials(self) -> httpx.Request:
104106
token_data["scope"] = self.context.client_metadata.scope
105107

106108
token_url = self._get_token_endpoint()
107-
return httpx.Request("POST", token_url, data=token_data, headers=headers)
109+
return httpx.Request("POST", token_url, data=token_data, headers=request_headers)
108110

109111

110112
def static_assertion_provider(token: str) -> Callable[[str], Awaitable[str]]:
@@ -296,9 +298,9 @@ async def _initialize(self) -> None:
296298
self.context.client_info = self._fixed_client_info
297299
self._initialized = True
298300

299-
async def _perform_authorization(self) -> httpx.Request:
301+
async def _perform_authorization(self, headers: Mapping[str, str] | None = None) -> httpx.Request:
300302
"""Perform client_credentials authorization with private_key_jwt."""
301-
return await self._exchange_token_client_credentials()
303+
return await self._exchange_token_client_credentials(headers=headers)
302304

303305
async def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) -> None:
304306
"""Add JWT assertion for client authentication to token endpoint parameters."""
@@ -314,13 +316,15 @@ async def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) ->
314316
token_data["client_assertion"] = assertion
315317
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
316318

317-
async def _exchange_token_client_credentials(self) -> httpx.Request:
319+
async def _exchange_token_client_credentials(self, headers: Mapping[str, str] | None = None) -> httpx.Request:
318320
"""Build token exchange request for client_credentials grant with private_key_jwt."""
319321
token_data: dict[str, Any] = {
320322
"grant_type": "client_credentials",
321323
}
322324

323-
headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
325+
request_headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
326+
if headers:
327+
request_headers.update(headers)
324328

325329
# Add JWT client authentication (RFC 7523 Section 2.2)
326330
await self._add_client_authentication_jwt(token_data=token_data)
@@ -332,7 +336,7 @@ async def _exchange_token_client_credentials(self) -> httpx.Request:
332336
token_data["scope"] = self.context.client_metadata.scope
333337

334338
token_url = self._get_token_endpoint()
335-
return httpx.Request("POST", token_url, data=token_data, headers=headers)
339+
return httpx.Request("POST", token_url, data=token_data, headers=request_headers)
336340

337341

338342
class JWTParameters(BaseModel):
@@ -419,21 +423,33 @@ def __init__(
419423
self.jwt_parameters = jwt_parameters
420424

421425
async def _exchange_token_authorization_code(
422-
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
426+
self,
427+
auth_code: str,
428+
code_verifier: str,
429+
*,
430+
token_data: dict[str, Any] | None = None,
431+
headers: Mapping[str, str] | None = None,
423432
) -> httpx.Request: # pragma: no cover
424433
"""Build token exchange request for authorization_code flow."""
425434
token_data = token_data or {}
426435
if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
427436
self._add_client_authentication_jwt(token_data=token_data)
428-
return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)
437+
return await super()._exchange_token_authorization_code(
438+
auth_code,
439+
code_verifier,
440+
token_data=token_data,
441+
headers=headers,
442+
)
429443

430-
async def _perform_authorization(self) -> httpx.Request: # pragma: no cover
444+
async def _perform_authorization(
445+
self, headers: Mapping[str, str] | None = None
446+
) -> httpx.Request: # pragma: no cover
431447
"""Perform the authorization flow."""
432448
if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
433449
token_request = await self._exchange_token_jwt_bearer()
434450
return token_request
435451
else:
436-
return await super()._perform_authorization()
452+
return await super()._perform_authorization(headers=headers)
437453

438454
def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # pragma: no cover
439455
"""Add JWT assertion for client authentication to token endpoint parameters."""

src/mcp/client/auth/oauth2.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import secrets
1010
import string
1111
import time
12-
from collections.abc import AsyncGenerator, Awaitable, Callable
12+
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
1313
from dataclasses import dataclass, field
1414
from typing import Any, Protocol
1515
from urllib.parse import quote, urlencode, urljoin, urlparse
@@ -303,10 +303,10 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
303303
f"Protected Resource Metadata request failed: {response.status_code}"
304304
) # pragma: no cover
305305

306-
async def _perform_authorization(self) -> httpx.Request:
306+
async def _perform_authorization(self, headers: Mapping[str, str] | None = None) -> httpx.Request:
307307
"""Perform the authorization flow."""
308308
auth_code, code_verifier = await self._perform_authorization_code_grant()
309-
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
309+
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier, headers=headers)
310310
return token_request
311311

312312
async def _perform_authorization_code_grant(self) -> tuple[str, str]:
@@ -376,7 +376,12 @@ def _get_token_endpoint(self) -> str:
376376
return token_url
377377

378378
async def _exchange_token_authorization_code(
379-
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = {}
379+
self,
380+
auth_code: str,
381+
code_verifier: str,
382+
*,
383+
token_data: dict[str, Any] | None = {},
384+
headers: Mapping[str, str] | None = None,
380385
) -> httpx.Request:
381386
"""Build token exchange request for authorization_code flow."""
382387
if self.context.client_metadata.redirect_uris is None:
@@ -401,10 +406,12 @@ async def _exchange_token_authorization_code(
401406
token_data["resource"] = self.context.get_resource_url() # RFC 8707
402407

403408
# Prepare authentication based on preferred method
404-
headers = {"Content-Type": "application/x-www-form-urlencoded"}
405-
token_data, headers = self.context.prepare_token_auth(token_data, headers)
409+
request_headers = {"Content-Type": "application/x-www-form-urlencoded"}
410+
if headers:
411+
request_headers.update(headers)
412+
token_data, request_headers = self.context.prepare_token_auth(token_data, request_headers)
406413

407-
return httpx.Request("POST", token_url, data=token_data, headers=headers)
414+
return httpx.Request("POST", token_url, data=token_data, headers=request_headers)
408415

409416
async def _handle_token_response(self, response: httpx.Response) -> None:
410417
"""Handle token exchange response."""
@@ -421,7 +428,7 @@ async def _handle_token_response(self, response: httpx.Response) -> None:
421428
self.context.update_token_expiry(token_response)
422429
await self.context.storage.set_tokens(token_response)
423430

424-
async def _refresh_token(self) -> httpx.Request:
431+
async def _refresh_token(self, headers: Mapping[str, str] | None = None) -> httpx.Request:
425432
"""Build token refresh request."""
426433
if not self.context.current_tokens or not self.context.current_tokens.refresh_token:
427434
raise OAuthTokenError("No refresh token available") # pragma: no cover
@@ -446,10 +453,12 @@ async def _refresh_token(self) -> httpx.Request:
446453
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707
447454

448455
# Prepare authentication based on preferred method
449-
headers = {"Content-Type": "application/x-www-form-urlencoded"}
450-
refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)
456+
request_headers = {"Content-Type": "application/x-www-form-urlencoded"}
457+
if headers:
458+
request_headers.update(headers)
459+
refresh_data, request_headers = self.context.prepare_token_auth(refresh_data, request_headers)
451460

452-
return httpx.Request("POST", token_url, data=refresh_data, headers=headers)
461+
return httpx.Request("POST", token_url, data=refresh_data, headers=request_headers)
453462

454463
async def _handle_refresh_response(self, response: httpx.Response) -> bool:
455464
"""Handle token refresh response. Returns True if successful."""
@@ -483,6 +492,11 @@ def _add_auth_header(self, request: httpx.Request) -> None:
483492
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
484493
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
485494

495+
def _auth_flow_request_headers(self, request: httpx.Request) -> dict[str, str]:
496+
"""Headers that are safe to carry from the MCP request into OAuth requests."""
497+
user_agent = request.headers.get("User-Agent")
498+
return {"User-Agent": user_agent} if user_agent else {}
499+
486500
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
487501
content = await response.aread()
488502
metadata = OAuthMetadata.model_validate_json(content)
@@ -510,10 +524,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
510524

511525
# Capture protocol version from request headers
512526
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
527+
auth_flow_headers = self._auth_flow_request_headers(request)
513528

514529
if not self.context.is_token_valid() and self.context.can_refresh_token():
515530
# Try to refresh token
516-
refresh_request = await self._refresh_token()
531+
refresh_request = await self._refresh_token(headers=auth_flow_headers)
517532
refresh_response = yield refresh_request
518533

519534
if not await self._handle_refresh_response(refresh_response):
@@ -537,7 +552,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
537552
)
538553

539554
for url in prm_discovery_urls: # pragma: no branch
540-
discovery_request = create_oauth_metadata_request(url)
555+
discovery_request = create_oauth_metadata_request(url, headers=auth_flow_headers)
541556

542557
discovery_response = yield discovery_request # sending request
543558

@@ -563,7 +578,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
563578

564579
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
565580
for url in asm_discovery_urls: # pragma: no branch
566-
oauth_metadata_request = create_oauth_metadata_request(url)
581+
oauth_metadata_request = create_oauth_metadata_request(url, headers=auth_flow_headers)
567582
oauth_metadata_response = yield oauth_metadata_request
568583

569584
ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
@@ -602,14 +617,15 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
602617
self.context.oauth_metadata,
603618
self.context.client_metadata,
604619
self.context.get_authorization_base_url(self.context.server_url),
620+
headers=auth_flow_headers,
605621
)
606622
registration_response = yield registration_request
607623
client_information = await handle_registration_response(registration_response)
608624
self.context.client_info = client_information
609625
await self.context.storage.set_client_info(client_information)
610626

611627
# Step 5: Perform authorization and complete token exchange
612-
token_response = yield await self._perform_authorization()
628+
token_response = yield await self._perform_authorization(headers=auth_flow_headers)
613629
await self._handle_token_response(token_response)
614630
except Exception:
615631
logger.exception("OAuth flow error")
@@ -634,7 +650,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
634650
)
635651

636652
# Step 2b: Perform (re-)authorization and token exchange
637-
token_response = yield await self._perform_authorization()
653+
token_response = yield await self._perform_authorization(headers=auth_flow_headers)
638654
await self._handle_token_response(token_response)
639655
except Exception: # pragma: no cover
640656
logger.exception("OAuth flow error")

src/mcp/client/auth/utils.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
from collections.abc import Mapping
23
from urllib.parse import urljoin, urlparse
34

45
from httpx import Request, Response
@@ -211,12 +212,18 @@ async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuth
211212
return True, None
212213

213214

214-
def create_oauth_metadata_request(url: str) -> Request:
215-
return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
215+
def create_oauth_metadata_request(url: str, headers: Mapping[str, str] | None = None) -> Request:
216+
request_headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
217+
if headers:
218+
request_headers.update(headers)
219+
return Request("GET", url, headers=request_headers)
216220

217221

218222
def create_client_registration_request(
219-
auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str
223+
auth_server_metadata: OAuthMetadata | None,
224+
client_metadata: OAuthClientMetadata,
225+
auth_base_url: str,
226+
headers: Mapping[str, str] | None = None,
220227
) -> Request:
221228
"""Build a client registration request."""
222229

@@ -227,7 +234,11 @@ def create_client_registration_request(
227234

228235
registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
229236

230-
return Request("POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"})
237+
request_headers = {"Content-Type": "application/json"}
238+
if headers:
239+
request_headers.update(headers)
240+
241+
return Request("POST", registration_url, json=registration_data, headers=request_headers)
231242

232243

233244
async def handle_registration_response(response: Response) -> OAuthClientInformationFull:

tests/client/auth/extensions/test_client_credentials.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,27 @@ async def test_exchange_token_client_credentials(self, mock_storage: MockTokenSt
252252
assert "scope=read write" in content
253253
assert "resource=https://api.example.com/v1/mcp" in content
254254

255+
@pytest.mark.anyio
256+
async def test_exchange_token_preserves_user_agent_header(self, mock_storage: MockTokenStorage):
257+
"""Test that client_credentials token requests preserve a caller User-Agent."""
258+
provider = ClientCredentialsOAuthProvider(
259+
server_url="https://api.example.com/v1/mcp",
260+
storage=mock_storage,
261+
client_id="test-client-id",
262+
client_secret="test-client-secret",
263+
)
264+
await provider._initialize()
265+
provider.context.oauth_metadata = OAuthMetadata(
266+
issuer=AnyHttpUrl("https://api.example.com"),
267+
authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"),
268+
token_endpoint=AnyHttpUrl("https://api.example.com/token"),
269+
)
270+
271+
request = await provider._perform_authorization(headers={"User-Agent": "mcp-python-sdk/issue-1664"})
272+
273+
assert request.headers["User-Agent"] == "mcp-python-sdk/issue-1664"
274+
assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
275+
255276
@pytest.mark.anyio
256277
async def test_exchange_token_client_secret_post_includes_client_id(self, mock_storage: MockTokenStorage):
257278
"""Test that client_secret_post includes both client_id and client_secret in body (RFC 6749 §2.3.1)."""

tests/interaction/auth/_harness.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ async def connect_with_oauth(
394394
client_metadata_url: str | None = None,
395395
headless: HeadlessOAuth | None = None,
396396
auth: httpx.Auth | None = None,
397+
headers: Mapping[str, str] | None = None,
397398
verify_tokens: bool = True,
398399
app_shim: Callable[[ASGIApp], ASGIApp] | None = None,
399400
on_request: Callable[[httpx.Request], None] | None = None,
@@ -455,7 +456,11 @@ async def hook(request: httpx.Request) -> None:
455456
await stack.enter_async_context(server.session_manager.run())
456457
http_client = await stack.enter_async_context(
457458
httpx.AsyncClient(
458-
transport=StreamingASGITransport(app), base_url=BASE_URL, auth=oauth, event_hooks=event_hooks
459+
transport=StreamingASGITransport(app),
460+
base_url=BASE_URL,
461+
auth=oauth,
462+
headers=headers,
463+
event_hooks=event_hooks,
459464
)
460465
)
461466
headless.bind(http_client)

0 commit comments

Comments
 (0)