Skip to content

Commit b731b2d

Browse files
peisukeBartok9
authored andcommitted
test: add regression coverage for OAuthClientProvider concurrent requests
Two tests in a new TestConcurrentRequestsDoNotDeadlock class exercise the behavior the previous commit fixes: 1. ``test_concurrent_request_not_blocked_by_pending_long_running_request`` drives one async_auth_flow generator to its yield (= simulating a GET SSE long-poll suspended waiting for the next event) and then opens a second concurrent flow on the same provider. The second flow must reach its own yield within a short timeout — i.e., the lock release between Phase 1 and Phase 3 lets it through. Pre-fix, the second generator would block on context.lock indefinitely. 2. ``test_concurrent_token_refresh_is_single_flight`` exercises the refresh_lock single-flight path. A first flow performs the refresh yield; a second flow started after the refresh completes observes the freshly-updated token in Phase 1 and proceeds directly to its own request yield without issuing a second refresh. Also: tighten the refresh_request unbound-after-conditional-write pattern in async_auth_flow so pyright recognizes it as definitely assigned at the yield site (was: derived from a boolean predicate; now: typed as ``httpx.Request | None`` and checked explicitly).
1 parent 75c0e57 commit b731b2d

2 files changed

Lines changed: 118 additions & 6 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -540,14 +540,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
540540
async with self.context.refresh_lock:
541541
# Re-check under context.lock: another coroutine may already have
542542
# refreshed while we were waiting on refresh_lock.
543+
refresh_request: httpx.Request | None = None
543544
async with self.context.lock:
544-
still_invalid = (
545-
not self.context.is_token_valid()
546-
and self.context.can_refresh_token()
547-
)
548-
if still_invalid:
545+
if not self.context.is_token_valid() and self.context.can_refresh_token():
549546
refresh_request = await self._refresh_token() # pragma: no cover
550-
if still_invalid:
547+
if refresh_request is not None:
551548
# yield runs outside any lock so a long network round trip
552549
# does not block unrelated concurrent requests.
553550
refresh_response = yield refresh_request # pragma: no cover

tests/client/test_auth.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""Tests for refactored OAuth client authentication implementation."""
22

33
import base64
4+
import contextlib
45
import time
56
from unittest import mock
67
from urllib.parse import parse_qs, quote, unquote, urlparse
78

9+
import anyio
810
import httpx
911
import pytest
1012
from inline_snapshot import Is, snapshot
@@ -2636,3 +2638,116 @@ async def callback_handler() -> tuple[str, str | None]:
26362638
await auth_flow.asend(final_response)
26372639
except StopAsyncIteration:
26382640
pass
2641+
2642+
2643+
class TestConcurrentRequestsDoNotDeadlock:
2644+
"""Regression tests for #1326.
2645+
2646+
Ensures that ``OAuthClientProvider.async_auth_flow`` does not serialize
2647+
concurrent unrelated requests behind a long-running one (e.g. GET SSE
2648+
long-poll). The fix narrows ``context.lock`` to state mutation only; the
2649+
actual ``yield request`` runs outside any lock.
2650+
"""
2651+
2652+
@pytest.mark.anyio
2653+
async def test_concurrent_request_not_blocked_by_pending_long_running_request(
2654+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
2655+
):
2656+
"""A second request must reach its yield while the first is still
2657+
suspended at its yield (= simulating a server-side long-poll).
2658+
2659+
Before this fix, ``async_auth_flow`` held ``context.lock`` across
2660+
``yield request``. A GET SSE long-poll would therefore hold the lock
2661+
for the entire SSE lifetime, blocking any concurrent request waiting
2662+
on the same provider's lock and producing a multi-second stall.
2663+
"""
2664+
# Set up valid tokens so neither refresh (Phase 2) nor full OAuth
2665+
# flow (Phase 4) is triggered — we want to exercise the steady-state
2666+
# Phase 3 yield path that previously held the lock.
2667+
oauth_provider.context.current_tokens = valid_tokens
2668+
oauth_provider.context.token_expiry_time = time.time() + 1800
2669+
oauth_provider.context.client_info = OAuthClientInformationFull(
2670+
client_id="test_client_id",
2671+
client_secret="test_client_secret",
2672+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2673+
)
2674+
oauth_provider._initialized = True
2675+
2676+
# Flow 1: simulate a slow request. Drive it to its yield, then
2677+
# deliberately do not send a response — it stays suspended at the
2678+
# yield, just like a GET SSE long-poll waiting for the next event.
2679+
slow_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
2680+
slow_flow = oauth_provider.async_auth_flow(slow_request)
2681+
yielded_slow = await slow_flow.__anext__()
2682+
assert yielded_slow.headers.get("Authorization") == "Bearer test_access_token"
2683+
2684+
# Flow 2: a concurrent request on the same provider. With the fix,
2685+
# context.lock is not held during Flow 1's yield, so Flow 2 reaches
2686+
# its yield almost immediately. Without the fix, this would block
2687+
# until Flow 1 receives a response — i.e., it would hit the timeout.
2688+
fast_request = httpx.Request("POST", "https://api.example.com/v1/mcp")
2689+
fast_flow = oauth_provider.async_auth_flow(fast_request)
2690+
with anyio.fail_after(1.0):
2691+
yielded_fast = await fast_flow.__anext__()
2692+
assert yielded_fast.headers.get("Authorization") == "Bearer test_access_token"
2693+
2694+
# Clean up both generators in deterministic order.
2695+
with contextlib.suppress(StopAsyncIteration):
2696+
await fast_flow.asend(httpx.Response(200, request=yielded_fast))
2697+
with contextlib.suppress(StopAsyncIteration):
2698+
await slow_flow.asend(httpx.Response(200, request=yielded_slow))
2699+
2700+
@pytest.mark.anyio
2701+
async def test_concurrent_token_refresh_is_single_flight(
2702+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
2703+
):
2704+
"""When concurrent requests both observe an expired token, only one
2705+
refresh request is sent: ``refresh_lock`` provides single-flight
2706+
semantics so the second waiter re-checks state and proceeds without
2707+
re-triggering refresh.
2708+
"""
2709+
# Mark the token as expired so the next auth_flow run enters Phase 2.
2710+
oauth_provider.context.current_tokens = valid_tokens
2711+
oauth_provider.context.token_expiry_time = time.time() - 100 # expired
2712+
oauth_provider.context.client_info = OAuthClientInformationFull(
2713+
client_id="test_client_id",
2714+
client_secret="test_client_secret",
2715+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2716+
)
2717+
oauth_provider._initialized = True
2718+
2719+
# Flow A: drive it to the refresh yield and suspend there.
2720+
request_a = httpx.Request("GET", "https://api.example.com/v1/mcp")
2721+
flow_a = oauth_provider.async_auth_flow(request_a)
2722+
refresh_a = await flow_a.__anext__()
2723+
assert "grant_type=refresh_token" in refresh_a.read().decode()
2724+
2725+
# Complete Flow A's refresh with a fresh token.
2726+
refresh_response = httpx.Response(
2727+
200,
2728+
content=(
2729+
b'{"access_token": "new_access_token", "token_type": "Bearer", '
2730+
b'"expires_in": 3600, "refresh_token": "new_refresh_token"}'
2731+
),
2732+
request=refresh_a,
2733+
)
2734+
request_a_post = await flow_a.asend(refresh_response)
2735+
assert request_a_post.headers.get("Authorization") == "Bearer new_access_token"
2736+
2737+
# Flow B starts after Flow A's refresh has completed. Because token
2738+
# state was updated under context.lock, Flow B observes the fresh
2739+
# token in Phase 1, skips Phase 2 entirely, and reaches its yield
2740+
# directly. No second refresh request is sent.
2741+
request_b = httpx.Request("POST", "https://api.example.com/v1/mcp")
2742+
flow_b = oauth_provider.async_auth_flow(request_b)
2743+
with anyio.fail_after(1.0):
2744+
request_b_yielded = await flow_b.__anext__()
2745+
assert request_b_yielded.headers.get("Authorization") == "Bearer new_access_token"
2746+
# Confirm Flow B yielded the original POST, not a refresh request.
2747+
assert request_b_yielded.method == "POST"
2748+
2749+
# Clean up.
2750+
with contextlib.suppress(StopAsyncIteration):
2751+
await flow_b.asend(httpx.Response(200, request=request_b_yielded))
2752+
with contextlib.suppress(StopAsyncIteration):
2753+
await flow_a.asend(httpx.Response(200, request=request_a_post))

0 commit comments

Comments
 (0)