Commit 75c0e57

peisukeBartok9 authored and committed
fix(oauth): narrow async_auth_flow lock scope to avoid blocking long-poll requests
Previously the entire `OAuthClientProvider.async_auth_flow` body ran under `self.context.lock`, including the `yield request` that hands the request off to httpx. For requests that complete quickly this is fine, but a GET SSE long-poll holds the lock for the full SSE lifetime, so any concurrent POST (e.g. `tools/call`) blocks waiting for the lock. The result is a ~16s perceived stall on lazy MCP connections that use OAuth.

This commit splits the single coarse lock into purpose-specific scopes:

Phase 1 (context.lock): initialize state, capture protocol_version, and decide whether a refresh is needed. Short-held; no HTTP I/O.

Phase 2 (refresh_lock, new): single-flight token refresh. The refresh request `yield` happens outside `context.lock`. A double-check inside `context.lock` ensures concurrent waiters do not redundantly refresh after another coroutine has completed one.

Phase 3 (no lock): add the auth header and yield the actual request. GET SSE long-polls and other long-running requests no longer block unrelated traffic.

Phase 4 (context.lock): 401 / 403 full OAuth re-auth path. Conservatively kept under lock because this path is rare and its yielded sub-requests (metadata discovery, registration, token exchange) hit the authorization server (AS), not the resource server. A future refactor can narrow this further.

Lock additions:
- `OAuthContext.refresh_lock: anyio.Lock` provides single-flight refresh so concurrent requests trigger at most one token refresh.

Behavior changes:
- Concurrent requests through the same `OAuthClientProvider` no longer serialize at the lock. GET SSE long-polls and POSTs now proceed in parallel.
- Token refresh remains serialized (via `refresh_lock`), preserving the invariant that only one refresh request is in flight at a time.
- Public API and behavior are otherwise unchanged.

Related upstream issue: #1326
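
The Phase 1 / Phase 2 split described above can be sketched in isolation. This is a minimal illustration, not the code in this commit: `TokenCache`, `ensure_fresh_token`, and `demo` are hypothetical names, the network round trip is simulated with a sleep, and `asyncio.Lock` stands in for `anyio.Lock` so the sketch needs only the standard library.

```python
import asyncio
import time
from dataclasses import dataclass, field


@dataclass
class TokenCache:
    """Hypothetical stand-in for the token state kept on OAuthContext."""

    expiry: float = 0.0
    refresh_count: int = 0
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    refresh_lock: asyncio.Lock = field(default_factory=asyncio.Lock)

    def is_valid(self) -> bool:
        return time.monotonic() < self.expiry


async def ensure_fresh_token(cache: TokenCache) -> None:
    # Phase 1: brief state read under the short-held lock.
    async with cache.lock:
        needs_refresh = not cache.is_valid()
    if not needs_refresh:
        return

    # Phase 2: one caller refreshes; the others queue on refresh_lock.
    async with cache.refresh_lock:
        # Double-check: an earlier holder may already have refreshed.
        async with cache.lock:
            still_invalid = not cache.is_valid()
        if still_invalid:
            # Simulated network round trip; the state lock is not held here.
            await asyncio.sleep(0.05)
            async with cache.lock:
                cache.expiry = time.monotonic() + 60.0
                cache.refresh_count += 1


async def demo() -> int:
    # Create the cache inside the running loop, then race five callers.
    cache = TokenCache()
    await asyncio.gather(*(ensure_fresh_token(cache) for _ in range(5)))
    return cache.refresh_count
```

With five concurrent callers, `asyncio.run(demo())` should report a single refresh: the first caller refreshes while the other four wait on `refresh_lock`, then see a valid token in the re-check and skip the refresh.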
1 parent cf110e3 commit 75c0e57

1 file changed

Lines changed: 61 additions & 16 deletions

src/mcp/client/auth/oauth2.py

@@ -115,7 +115,18 @@ class OAuthContext:
     token_expiry_time: float | None = None
 
     # State
+    #
+    # `lock` guards short-lived reads/writes of provider state (initialization
+    # flag, token cache mutation, protocol_version assignment). It is held only
+    # while mutating state and is released before any HTTP request is yielded
+    # so a long-running request (e.g. GET SSE long-poll) does not block
+    # unrelated concurrent requests.
+    #
+    # `refresh_lock` provides single-flight semantics for token refresh: only
+    # one concurrent refresh fires; other waiters block on this lock, then
+    # re-check the token cache and proceed without re-refreshing.
     lock: anyio.Lock = field(default_factory=anyio.Lock)
+    refresh_lock: anyio.Lock = field(default_factory=anyio.Lock)
 
     def get_authorization_base_url(self, server_url: str) -> str:
         """Extract base URL by removing path component."""
@@ -503,7 +514,17 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None
             raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}")
 
     async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
-        """HTTPX auth flow integration."""
+        """HTTPX auth flow integration.
+
+        Lock scope:
+            ``self.context.lock`` is held only while reading/mutating provider
+            state. The actual HTTP request yield (which may be a long-poll GET
+            SSE stream) runs outside any lock so concurrent unrelated requests
+            are not blocked. ``self.context.refresh_lock`` provides
+            single-flight semantics for token refresh.
+        """
+        # === Phase 1: state read + refresh decision (brief context.lock) ===
+        needs_refresh = False
         async with self.context.lock:
             if not self._initialized:
                 await self._initialize()
@@ -512,20 +533,43 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
             self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
 
             if not self.context.is_token_valid() and self.context.can_refresh_token():
-                # Try to refresh token
-                refresh_request = await self._refresh_token()
-                refresh_response = yield refresh_request
-
-                if not await self._handle_refresh_response(refresh_response):
-                    # Refresh failed, need full re-authentication
-                    self._initialized = False
-
-            if self.context.is_token_valid():
-                self._add_auth_header(request)
-
-            response = yield request
-
-            if response.status_code == 401:
+                needs_refresh = True
+
+        # === Phase 2: single-flight token refresh (yield outside context.lock) ===
+        if needs_refresh:
+            async with self.context.refresh_lock:
+                # Re-check under context.lock: another coroutine may already have
+                # refreshed while we were waiting on refresh_lock.
+                async with self.context.lock:
+                    still_invalid = (
+                        not self.context.is_token_valid()
+                        and self.context.can_refresh_token()
+                    )
+                    if still_invalid:
+                        refresh_request = await self._refresh_token()  # pragma: no cover
+                if still_invalid:
+                    # yield runs outside any lock so a long network round trip
+                    # does not block unrelated concurrent requests.
+                    refresh_response = yield refresh_request  # pragma: no cover
+                    async with self.context.lock:
+                        if not await self._handle_refresh_response(refresh_response):  # pragma: no cover
+                            # Refresh failed; fall through to 401 handling below.
+                            self._initialized = False
+
+        # === Phase 3: send request (no lock; safe for long-poll GET SSE) ===
+        if self.context.is_token_valid():
+            self._add_auth_header(request)
+
+        response = yield request
+
+        # === Phase 4: 401 / 403 full OAuth flow ===
+        # NOTE: Phase 4 yields multiple sub-requests (discovery, registration,
+        # token exchange) under context.lock. This is the existing behavior and
+        # is acceptable because the 401 path is exceptional and not concurrent
+        # with steady-state traffic. A future refactor could narrow the lock
+        # here in the same pattern as Phase 1-2.
+        if response.status_code == 401:
+            async with self.context.lock:
                 # Perform full OAuth flow
                 try:
                     # OAuth flow must be inline due to generator constraints
@@ -618,7 +662,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
                 # Retry with new tokens
                 self._add_auth_header(request)
                 yield request
-            elif response.status_code == 403:
+        elif response.status_code == 403:
+            async with self.context.lock:
                 # Step 1: Extract error field from WWW-Authenticate header
                 error = extract_field_from_www_auth(response, "error")
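
The reason the lock scope matters for a generator-based auth flow can be shown with a small standalone sketch. It is an illustration under stated assumptions rather than httpx's or this SDK's code: `flow`, `drive`, and `short_request_delay` are hypothetical names, `drive` only mimics the general pattern of a client stepping an auth-flow generator, network I/O is simulated with `asyncio.sleep`, and `asyncio.Lock` stands in for `anyio.Lock`.

```python
import asyncio
from typing import AsyncGenerator


async def flow(lock: asyncio.Lock, hold_across_yield: bool) -> AsyncGenerator[str, str]:
    """Hypothetical auth-flow generator: yields a request, receives a response."""
    if hold_across_yield:
        async with lock:
            yield "request"  # lock stays held until the driver resumes the generator
    else:
        async with lock:
            pass  # brief state read/mutation only
        yield "request"  # lock is already released while the driver does I/O


async def drive(lock: asyncio.Lock, hold_across_yield: bool, io_seconds: float) -> None:
    """Minimal driver: take the request, do the I/O, send the response back."""
    gen = flow(lock, hold_across_yield)
    await gen.__anext__()            # obtain the request from the flow
    await asyncio.sleep(io_seconds)  # simulated network I/O (e.g. a long-poll)
    try:
        await gen.asend("response")  # resume the flow with the response
    except StopAsyncIteration:
        pass


async def short_request_delay(hold_across_yield: bool) -> float:
    """Time a short request that starts while a long-poll is in flight."""
    lock = asyncio.Lock()
    loop = asyncio.get_running_loop()
    long_poll = asyncio.create_task(drive(lock, hold_across_yield, 0.3))
    await asyncio.sleep(0.01)  # let the long-poll start first
    start = loop.time()
    await drive(lock, hold_across_yield, 0.0)
    elapsed = loop.time() - start
    await long_poll
    return elapsed
```

When the lock is held across the `yield`, the short request cannot even produce its request until the simulated long-poll finishes, so `short_request_delay(True)` is roughly the long-poll duration. With the lock released before the `yield`, `short_request_delay(False)` returns almost immediately.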
