Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 32 additions & 16 deletions src/mcp/client/auth/extensions/client_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import time
import warnings
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Mapping
from typing import Any, Literal
from uuid import uuid4

Expand Down Expand Up @@ -82,20 +82,22 @@ async def _initialize(self) -> None:
self.context.client_info = self._fixed_client_info
self._initialized = True

async def _perform_authorization(self) -> httpx.Request:
async def _perform_authorization(self, headers: Mapping[str, str] | None = None) -> httpx.Request:
"""Perform client_credentials authorization."""
return await self._exchange_token_client_credentials()
return await self._exchange_token_client_credentials(headers=headers)

async def _exchange_token_client_credentials(self) -> httpx.Request:
async def _exchange_token_client_credentials(self, headers: Mapping[str, str] | None = None) -> httpx.Request:
"""Build token exchange request for client_credentials grant."""
token_data: dict[str, Any] = {
"grant_type": "client_credentials",
}

headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
request_headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
if headers:
request_headers.update(headers)

# Use standard auth methods (client_secret_basic, client_secret_post, none)
token_data, headers = self.context.prepare_token_auth(token_data, headers)
token_data, request_headers = self.context.prepare_token_auth(token_data, request_headers)

if self.context.should_include_resource_param(self.context.protocol_version):
token_data["resource"] = self.context.get_resource_url()
Expand All @@ -104,7 +106,7 @@ async def _exchange_token_client_credentials(self) -> httpx.Request:
token_data["scope"] = self.context.client_metadata.scope

token_url = self._get_token_endpoint()
return httpx.Request("POST", token_url, data=token_data, headers=headers)
return httpx.Request("POST", token_url, data=token_data, headers=request_headers)


def static_assertion_provider(token: str) -> Callable[[str], Awaitable[str]]:
Expand Down Expand Up @@ -296,9 +298,9 @@ async def _initialize(self) -> None:
self.context.client_info = self._fixed_client_info
self._initialized = True

async def _perform_authorization(self) -> httpx.Request:
async def _perform_authorization(self, headers: Mapping[str, str] | None = None) -> httpx.Request:
"""Perform client_credentials authorization with private_key_jwt."""
return await self._exchange_token_client_credentials()
return await self._exchange_token_client_credentials(headers=headers)

async def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) -> None:
"""Add JWT assertion for client authentication to token endpoint parameters."""
Expand All @@ -314,13 +316,15 @@ async def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) ->
token_data["client_assertion"] = assertion
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"

async def _exchange_token_client_credentials(self) -> httpx.Request:
async def _exchange_token_client_credentials(self, headers: Mapping[str, str] | None = None) -> httpx.Request:
"""Build token exchange request for client_credentials grant with private_key_jwt."""
token_data: dict[str, Any] = {
"grant_type": "client_credentials",
}

headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
request_headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
if headers:
request_headers.update(headers)

# Add JWT client authentication (RFC 7523 Section 2.2)
await self._add_client_authentication_jwt(token_data=token_data)
Expand All @@ -332,7 +336,7 @@ async def _exchange_token_client_credentials(self) -> httpx.Request:
token_data["scope"] = self.context.client_metadata.scope

token_url = self._get_token_endpoint()
return httpx.Request("POST", token_url, data=token_data, headers=headers)
return httpx.Request("POST", token_url, data=token_data, headers=request_headers)


class JWTParameters(BaseModel):
Expand Down Expand Up @@ -419,21 +423,33 @@ def __init__(
self.jwt_parameters = jwt_parameters

async def _exchange_token_authorization_code(
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
self,
auth_code: str,
code_verifier: str,
*,
token_data: dict[str, Any] | None = None,
headers: Mapping[str, str] | None = None,
) -> httpx.Request: # pragma: no cover
"""Build token exchange request for authorization_code flow."""
token_data = token_data or {}
if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
self._add_client_authentication_jwt(token_data=token_data)
return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)
return await super()._exchange_token_authorization_code(
auth_code,
code_verifier,
token_data=token_data,
headers=headers,
)

async def _perform_authorization(self) -> httpx.Request: # pragma: no cover
async def _perform_authorization(
self, headers: Mapping[str, str] | None = None
) -> httpx.Request: # pragma: no cover
"""Perform the authorization flow."""
if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
token_request = await self._exchange_token_jwt_bearer()
return token_request
else:
return await super()._perform_authorization()
return await super()._perform_authorization(headers=headers)

def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # pragma: no cover
"""Add JWT assertion for client authentication to token endpoint parameters."""
Expand Down
48 changes: 32 additions & 16 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import secrets
import string
import time
from collections.abc import AsyncGenerator, Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
from dataclasses import dataclass, field
from typing import Any, Protocol
from urllib.parse import quote, urlencode, urljoin, urlparse
Expand Down Expand Up @@ -303,10 +303,10 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
f"Protected Resource Metadata request failed: {response.status_code}"
) # pragma: no cover

async def _perform_authorization(self) -> httpx.Request:
async def _perform_authorization(self, headers: Mapping[str, str] | None = None) -> httpx.Request:
"""Perform the authorization flow."""
auth_code, code_verifier = await self._perform_authorization_code_grant()
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier, headers=headers)
return token_request

async def _perform_authorization_code_grant(self) -> tuple[str, str]:
Expand Down Expand Up @@ -376,7 +376,12 @@ def _get_token_endpoint(self) -> str:
return token_url

async def _exchange_token_authorization_code(
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = {}
self,
auth_code: str,
code_verifier: str,
*,
token_data: dict[str, Any] | None = {},
headers: Mapping[str, str] | None = None,
) -> httpx.Request:
"""Build token exchange request for authorization_code flow."""
if self.context.client_metadata.redirect_uris is None:
Expand All @@ -401,10 +406,12 @@ async def _exchange_token_authorization_code(
token_data["resource"] = self.context.get_resource_url() # RFC 8707

# Prepare authentication based on preferred method
headers = {"Content-Type": "application/x-www-form-urlencoded"}
token_data, headers = self.context.prepare_token_auth(token_data, headers)
request_headers = {"Content-Type": "application/x-www-form-urlencoded"}
if headers:
request_headers.update(headers)
token_data, request_headers = self.context.prepare_token_auth(token_data, request_headers)

return httpx.Request("POST", token_url, data=token_data, headers=headers)
return httpx.Request("POST", token_url, data=token_data, headers=request_headers)

async def _handle_token_response(self, response: httpx.Response) -> None:
"""Handle token exchange response."""
Expand All @@ -421,7 +428,7 @@ async def _handle_token_response(self, response: httpx.Response) -> None:
self.context.update_token_expiry(token_response)
await self.context.storage.set_tokens(token_response)

async def _refresh_token(self) -> httpx.Request:
async def _refresh_token(self, headers: Mapping[str, str] | None = None) -> httpx.Request:
"""Build token refresh request."""
if not self.context.current_tokens or not self.context.current_tokens.refresh_token:
raise OAuthTokenError("No refresh token available") # pragma: no cover
Expand All @@ -446,10 +453,12 @@ async def _refresh_token(self) -> httpx.Request:
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707

# Prepare authentication based on preferred method
headers = {"Content-Type": "application/x-www-form-urlencoded"}
refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)
request_headers = {"Content-Type": "application/x-www-form-urlencoded"}
if headers:
request_headers.update(headers)
refresh_data, request_headers = self.context.prepare_token_auth(refresh_data, request_headers)

return httpx.Request("POST", token_url, data=refresh_data, headers=headers)
return httpx.Request("POST", token_url, data=refresh_data, headers=request_headers)

async def _handle_refresh_response(self, response: httpx.Response) -> bool:
"""Handle token refresh response. Returns True if successful."""
Expand Down Expand Up @@ -483,6 +492,11 @@ def _add_auth_header(self, request: httpx.Request) -> None:
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"

def _auth_flow_request_headers(self, request: httpx.Request) -> dict[str, str]:
"""Headers that are safe to carry from the MCP request into OAuth requests."""
user_agent = request.headers.get("User-Agent")
return {"User-Agent": user_agent} if user_agent else {}

async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
content = await response.aread()
metadata = OAuthMetadata.model_validate_json(content)
Expand Down Expand Up @@ -510,10 +524,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.

# Capture protocol version from request headers
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
auth_flow_headers = self._auth_flow_request_headers(request)

if not self.context.is_token_valid() and self.context.can_refresh_token():
# Try to refresh token
refresh_request = await self._refresh_token()
refresh_request = await self._refresh_token(headers=auth_flow_headers)
refresh_response = yield refresh_request

if not await self._handle_refresh_response(refresh_response):
Expand All @@ -537,7 +552,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
)

for url in prm_discovery_urls: # pragma: no branch
discovery_request = create_oauth_metadata_request(url)
discovery_request = create_oauth_metadata_request(url, headers=auth_flow_headers)

discovery_response = yield discovery_request # sending request

Expand All @@ -563,7 +578,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.

# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
for url in asm_discovery_urls: # pragma: no branch
oauth_metadata_request = create_oauth_metadata_request(url)
oauth_metadata_request = create_oauth_metadata_request(url, headers=auth_flow_headers)
oauth_metadata_response = yield oauth_metadata_request

ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
Expand Down Expand Up @@ -602,14 +617,15 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
self.context.oauth_metadata,
self.context.client_metadata,
self.context.get_authorization_base_url(self.context.server_url),
headers=auth_flow_headers,
)
registration_response = yield registration_request
client_information = await handle_registration_response(registration_response)
self.context.client_info = client_information
await self.context.storage.set_client_info(client_information)

# Step 5: Perform authorization and complete token exchange
token_response = yield await self._perform_authorization()
token_response = yield await self._perform_authorization(headers=auth_flow_headers)
await self._handle_token_response(token_response)
except Exception:
logger.exception("OAuth flow error")
Expand All @@ -634,7 +650,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
)

# Step 2b: Perform (re-)authorization and token exchange
token_response = yield await self._perform_authorization()
token_response = yield await self._perform_authorization(headers=auth_flow_headers)
await self._handle_token_response(token_response)
except Exception: # pragma: no cover
logger.exception("OAuth flow error")
Expand Down
19 changes: 15 additions & 4 deletions src/mcp/client/auth/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from collections.abc import Mapping
from urllib.parse import urljoin, urlparse

from httpx import Request, Response
Expand Down Expand Up @@ -211,12 +212,18 @@ async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuth
return True, None


def create_oauth_metadata_request(url: str) -> Request:
return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
def create_oauth_metadata_request(url: str, headers: Mapping[str, str] | None = None) -> Request:
request_headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
if headers:
request_headers.update(headers)
return Request("GET", url, headers=request_headers)


def create_client_registration_request(
auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str
auth_server_metadata: OAuthMetadata | None,
client_metadata: OAuthClientMetadata,
auth_base_url: str,
headers: Mapping[str, str] | None = None,
) -> Request:
"""Build a client registration request."""

Expand All @@ -227,7 +234,11 @@ def create_client_registration_request(

registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)

return Request("POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"})
request_headers = {"Content-Type": "application/json"}
if headers:
request_headers.update(headers)

return Request("POST", registration_url, json=registration_data, headers=request_headers)


async def handle_registration_response(response: Response) -> OAuthClientInformationFull:
Expand Down
21 changes: 21 additions & 0 deletions tests/client/auth/extensions/test_client_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,27 @@ async def test_exchange_token_client_credentials(self, mock_storage: MockTokenSt
assert "scope=read write" in content
assert "resource=https://api.example.com/v1/mcp" in content

@pytest.mark.anyio
async def test_exchange_token_preserves_user_agent_header(self, mock_storage: MockTokenStorage):
"""Test that client_credentials token requests preserve a caller User-Agent."""
provider = ClientCredentialsOAuthProvider(
server_url="https://api.example.com/v1/mcp",
storage=mock_storage,
client_id="test-client-id",
client_secret="test-client-secret",
)
await provider._initialize()
provider.context.oauth_metadata = OAuthMetadata(
issuer=AnyHttpUrl("https://api.example.com"),
authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"),
token_endpoint=AnyHttpUrl("https://api.example.com/token"),
)

request = await provider._perform_authorization(headers={"User-Agent": "mcp-python-sdk/issue-1664"})

assert request.headers["User-Agent"] == "mcp-python-sdk/issue-1664"
assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"

@pytest.mark.anyio
async def test_exchange_token_client_secret_post_includes_client_id(self, mock_storage: MockTokenStorage):
"""Test that client_secret_post includes both client_id and client_secret in body (RFC 6749 §2.3.1)."""
Expand Down
7 changes: 6 additions & 1 deletion tests/interaction/auth/_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ async def connect_with_oauth(
client_metadata_url: str | None = None,
headless: HeadlessOAuth | None = None,
auth: httpx.Auth | None = None,
headers: Mapping[str, str] | None = None,
verify_tokens: bool = True,
app_shim: Callable[[ASGIApp], ASGIApp] | None = None,
on_request: Callable[[httpx.Request], None] | None = None,
Expand Down Expand Up @@ -455,7 +456,11 @@ async def hook(request: httpx.Request) -> None:
await stack.enter_async_context(server.session_manager.run())
http_client = await stack.enter_async_context(
httpx.AsyncClient(
transport=StreamingASGITransport(app), base_url=BASE_URL, auth=oauth, event_hooks=event_hooks
transport=StreamingASGITransport(app),
base_url=BASE_URL,
auth=oauth,
headers=headers,
event_hooks=event_hooks,
)
)
headless.bind(http_client)
Expand Down
Loading
Loading