99import secrets
1010import string
1111import time
12- from collections .abc import AsyncGenerator , Awaitable , Callable
12+ from collections .abc import AsyncGenerator , Awaitable , Callable , Mapping
1313from dataclasses import dataclass , field
1414from typing import Any , Protocol
1515from urllib .parse import quote , urlencode , urljoin , urlparse
@@ -303,10 +303,10 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
303303 f"Protected Resource Metadata request failed: { response .status_code } "
304304 ) # pragma: no cover
305305
306- async def _perform_authorization (self ) -> httpx .Request :
306+ async def _perform_authorization (self , headers : Mapping [ str , str ] | None = None ) -> httpx .Request :
307307 """Perform the authorization flow."""
308308 auth_code , code_verifier = await self ._perform_authorization_code_grant ()
309- token_request = await self ._exchange_token_authorization_code (auth_code , code_verifier )
309+ token_request = await self ._exchange_token_authorization_code (auth_code , code_verifier , headers = headers )
310310 return token_request
311311
312312 async def _perform_authorization_code_grant (self ) -> tuple [str , str ]:
@@ -376,7 +376,12 @@ def _get_token_endpoint(self) -> str:
376376 return token_url
377377
378378 async def _exchange_token_authorization_code (
379- self , auth_code : str , code_verifier : str , * , token_data : dict [str , Any ] | None = {}
379+ self ,
380+ auth_code : str ,
381+ code_verifier : str ,
382+ * ,
383+ token_data : dict [str , Any ] | None = {},
384+ headers : Mapping [str , str ] | None = None ,
380385 ) -> httpx .Request :
381386 """Build token exchange request for authorization_code flow."""
382387 if self .context .client_metadata .redirect_uris is None :
@@ -401,10 +406,12 @@ async def _exchange_token_authorization_code(
401406 token_data ["resource" ] = self .context .get_resource_url () # RFC 8707
402407
403408 # Prepare authentication based on preferred method
404- headers = {"Content-Type" : "application/x-www-form-urlencoded" }
405- token_data , headers = self .context .prepare_token_auth (token_data , headers )
409+ request_headers = {"Content-Type" : "application/x-www-form-urlencoded" }
410+ if headers :
411+ request_headers .update (headers )
412+ token_data , request_headers = self .context .prepare_token_auth (token_data , request_headers )
406413
407- return httpx .Request ("POST" , token_url , data = token_data , headers = headers )
414+ return httpx .Request ("POST" , token_url , data = token_data , headers = request_headers )
408415
409416 async def _handle_token_response (self , response : httpx .Response ) -> None :
410417 """Handle token exchange response."""
@@ -421,7 +428,7 @@ async def _handle_token_response(self, response: httpx.Response) -> None:
421428 self .context .update_token_expiry (token_response )
422429 await self .context .storage .set_tokens (token_response )
423430
424- async def _refresh_token (self ) -> httpx .Request :
431+ async def _refresh_token (self , headers : Mapping [ str , str ] | None = None ) -> httpx .Request :
425432 """Build token refresh request."""
426433 if not self .context .current_tokens or not self .context .current_tokens .refresh_token :
427434 raise OAuthTokenError ("No refresh token available" ) # pragma: no cover
@@ -446,10 +453,12 @@ async def _refresh_token(self) -> httpx.Request:
446453 refresh_data ["resource" ] = self .context .get_resource_url () # RFC 8707
447454
448455 # Prepare authentication based on preferred method
449- headers = {"Content-Type" : "application/x-www-form-urlencoded" }
450- refresh_data , headers = self .context .prepare_token_auth (refresh_data , headers )
456+ request_headers = {"Content-Type" : "application/x-www-form-urlencoded" }
457+ if headers :
458+ request_headers .update (headers )
459+ refresh_data , request_headers = self .context .prepare_token_auth (refresh_data , request_headers )
451460
452- return httpx .Request ("POST" , token_url , data = refresh_data , headers = headers )
461+ return httpx .Request ("POST" , token_url , data = refresh_data , headers = request_headers )
453462
454463 async def _handle_refresh_response (self , response : httpx .Response ) -> bool :
455464 """Handle token refresh response. Returns True if successful."""
@@ -483,6 +492,11 @@ def _add_auth_header(self, request: httpx.Request) -> None:
483492 if self .context .current_tokens and self .context .current_tokens .access_token : # pragma: no branch
484493 request .headers ["Authorization" ] = f"Bearer { self .context .current_tokens .access_token } "
485494
495+ def _auth_flow_request_headers (self , request : httpx .Request ) -> dict [str , str ]:
496+ """Headers that are safe to carry from the MCP request into OAuth requests."""
497+ user_agent = request .headers .get ("User-Agent" )
498+ return {"User-Agent" : user_agent } if user_agent else {}
499+
486500 async def _handle_oauth_metadata_response (self , response : httpx .Response ) -> None :
487501 content = await response .aread ()
488502 metadata = OAuthMetadata .model_validate_json (content )
@@ -510,10 +524,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
510524
511525 # Capture protocol version from request headers
512526 self .context .protocol_version = request .headers .get (MCP_PROTOCOL_VERSION )
527+ auth_flow_headers = self ._auth_flow_request_headers (request )
513528
514529 if not self .context .is_token_valid () and self .context .can_refresh_token ():
515530 # Try to refresh token
516- refresh_request = await self ._refresh_token ()
531+ refresh_request = await self ._refresh_token (headers = auth_flow_headers )
517532 refresh_response = yield refresh_request
518533
519534 if not await self ._handle_refresh_response (refresh_response ):
@@ -537,7 +552,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
537552 )
538553
539554 for url in prm_discovery_urls : # pragma: no branch
540- discovery_request = create_oauth_metadata_request (url )
555+ discovery_request = create_oauth_metadata_request (url , headers = auth_flow_headers )
541556
542557 discovery_response = yield discovery_request # sending request
543558
@@ -563,7 +578,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
563578
564579 # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
565580 for url in asm_discovery_urls : # pragma: no branch
566- oauth_metadata_request = create_oauth_metadata_request (url )
581+ oauth_metadata_request = create_oauth_metadata_request (url , headers = auth_flow_headers )
567582 oauth_metadata_response = yield oauth_metadata_request
568583
569584 ok , asm = await handle_auth_metadata_response (oauth_metadata_response )
@@ -602,14 +617,15 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
602617 self .context .oauth_metadata ,
603618 self .context .client_metadata ,
604619 self .context .get_authorization_base_url (self .context .server_url ),
620+ headers = auth_flow_headers ,
605621 )
606622 registration_response = yield registration_request
607623 client_information = await handle_registration_response (registration_response )
608624 self .context .client_info = client_information
609625 await self .context .storage .set_client_info (client_information )
610626
611627 # Step 5: Perform authorization and complete token exchange
612- token_response = yield await self ._perform_authorization ()
628+ token_response = yield await self ._perform_authorization (headers = auth_flow_headers )
613629 await self ._handle_token_response (token_response )
614630 except Exception :
615631 logger .exception ("OAuth flow error" )
@@ -634,7 +650,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
634650 )
635651
636652 # Step 2b: Perform (re-)authorization and token exchange
637- token_response = yield await self ._perform_authorization ()
653+ token_response = yield await self ._perform_authorization (headers = auth_flow_headers )
638654 await self ._handle_token_response (token_response )
639655 except Exception : # pragma: no cover
640656 logger .exception ("OAuth flow error" )
0 commit comments