Skip to content

Commit b6da2c6

Browse files
committed
feat(auth): proactive OAuth token refresh with jitter to reduce concurrent refresh spikes
1 parent cf110e3 commit b6da2c6

4 files changed

Lines changed: 343 additions & 4 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
)
4747
from mcp.shared.auth_utils import (
4848
calculate_token_expiry,
49+
calculate_token_refresh_time,
4950
check_resource_allowed,
5051
resource_url_from_server_url,
5152
)
@@ -113,6 +114,9 @@ class OAuthContext:
113114
# Token management
114115
current_tokens: OAuthToken | None = None
115116
token_expiry_time: float | None = None
117+
# Jittered point (before hard expiry) at which to proactively refresh, so a fleet
118+
# of connectors does not all refresh in the same window. See should_refresh_token.
119+
token_refresh_time: float | None = None
116120

117121
# State
118122
lock: anyio.Lock = field(default_factory=anyio.Lock)
@@ -123,11 +127,12 @@ def get_authorization_base_url(self, server_url: str) -> str:
123127
return f"{parsed.scheme}://{parsed.netloc}"
124128

125129
def update_token_expiry(self, token: OAuthToken) -> None:
126-
"""Update token expiry time using shared util function."""
130+
"""Update token expiry and proactive-refresh times using shared util functions."""
127131
self.token_expiry_time = calculate_token_expiry(token.expires_in)
132+
self.token_refresh_time = calculate_token_refresh_time(token.expires_in)
128133

129134
def is_token_valid(self) -> bool:
130-
"""Check if current token is valid."""
135+
"""Check if current token is valid (i.e. usable, not past hard expiry)."""
131136
return bool(
132137
self.current_tokens
133138
and self.current_tokens.access_token
@@ -138,10 +143,28 @@ def can_refresh_token(self) -> bool:
138143
"""Check if token can be refreshed."""
139144
return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info)
140145

146+
def should_refresh_token(self) -> bool:
147+
"""Check if the token should be *proactively* refreshed.
148+
149+
Returns True when we hold refreshable tokens and have passed the jittered
150+
proactive-refresh point (``token_refresh_time``), even if the token is still
151+
technically valid. Refreshing slightly early -- and at a per-connector jittered
152+
moment -- spreads a fleet's refreshes out instead of bunching them into the
153+
same expiry window. Returns False when no refresh time is known (no expiry
154+
info) so behavior degrades to the existing reactive path.
155+
"""
156+
return bool(
157+
self.current_tokens
158+
and self.can_refresh_token()
159+
and self.token_refresh_time is not None
160+
and time.time() >= self.token_refresh_time
161+
)
162+
141163
def clear_tokens(self) -> None:
142164
"""Clear current tokens."""
143165
self.current_tokens = None
144166
self.token_expiry_time = None
167+
self.token_refresh_time = None
145168

146169
def get_resource_url(self) -> str:
147170
"""Get resource URL for RFC 8707.
@@ -511,7 +534,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
511534
# Capture protocol version from request headers
512535
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
513536

514-
if not self.context.is_token_valid() and self.context.can_refresh_token():
537+
if (
538+
not self.context.is_token_valid() or self.context.should_refresh_token()
539+
) and self.context.can_refresh_token():
540+
# Refresh either reactively (token already invalid) or proactively
541+
# (past the jittered refresh point, before hard expiry).
515542
# Try to refresh token
516543
refresh_request = await self._refresh_token()
517544
refresh_response = yield refresh_request

src/mcp/shared/auth_utils.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707) and PKCE (RFC 7636)."""
22

3+
import random
34
import time
45
from urllib.parse import urlparse, urlsplit, urlunsplit
56

@@ -78,3 +79,80 @@ def calculate_token_expiry(expires_in: int | str | None) -> float | None:
7879
return None # pragma: no cover
7980
# Defensive: handle servers that return expires_in as string
8081
return time.time() + int(expires_in)
82+
83+
84+
def calculate_token_refresh_time(
85+
expires_in: int | str | None,
86+
*,
87+
refresh_fraction: float = 0.8,
88+
max_jitter_seconds: float = 30.0,
89+
jitter: float | None = None,
90+
) -> float | None:
91+
"""Calculate when a token should be *proactively* refreshed.
92+
93+
Reactive refresh (waiting until a token has already expired) means that, for a
94+
fleet of OAuth-backed MCP connectors provisioned around the same time, every
95+
token tends to expire inside the same narrow window. When they do, all of those
96+
clients try to refresh simultaneously, producing a "thundering herd" of refresh
97+
requests against the authorization server -- contention, rate limiting, and
98+
spurious auth failures.
99+
100+
To avoid that, this returns a timestamp *before* hard expiry at which the token
101+
should be refreshed:
102+
103+
refresh_at = now + expires_in * refresh_fraction - jitter
104+
105+
The jitter is always *subtracted* so it pulls the refresh point earlier and can
106+
never push it past the hard-expiry boundary. Spreading each client's refresh
107+
point by a small random amount means a fleet naturally desynchronizes instead of
108+
refreshing in lockstep.
109+
110+
Args:
111+
expires_in: Seconds until token expiration (may be a string from some servers).
112+
refresh_fraction: Fraction of the token lifetime after which to refresh.
113+
Defaults to 0.8 (refresh once 80% of the lifetime has elapsed).
114+
max_jitter_seconds: Upper bound (in seconds) of the random jitter subtracted
115+
from the refresh point. Defaults to 30s.
116+
jitter: Optional explicit jitter value (seconds). When provided it is used
117+
directly instead of drawing a random value, which keeps the function
118+
deterministic and testable. When None, a value in
119+
``[0, max_jitter_seconds]`` is drawn at random.
120+
121+
Returns:
122+
Unix timestamp at which the token should be proactively refreshed, or None
123+
if ``expires_in`` is None (no expiry information -> nothing to schedule).
124+
The result is always in ``(now, hard_expiry]`` and never in the past.
125+
"""
126+
if expires_in is None:
127+
return None
128+
129+
expires_in_seconds = int(expires_in)
130+
now = time.time()
131+
hard_expiry = now + expires_in_seconds
132+
133+
# Base proactive point: refresh once `refresh_fraction` of the lifetime elapsed.
134+
refresh_at = now + expires_in_seconds * refresh_fraction
135+
136+
# Cap the jitter so it can never reach back before `now`, which matters for very
137+
# short TTLs (e.g. expires_in smaller than max_jitter_seconds). The window we are
138+
# allowed to pull earlier into is (refresh_at - now); never jitter more than that.
139+
available_window = refresh_at - now
140+
effective_max_jitter = min(max_jitter_seconds, max(available_window, 0.0))
141+
142+
if jitter is None:
143+
applied_jitter = random.uniform(0, effective_max_jitter)
144+
else:
145+
# Clamp an injected jitter into the valid range to preserve invariants.
146+
applied_jitter = min(max(jitter, 0.0), effective_max_jitter)
147+
148+
refresh_at -= applied_jitter
149+
150+
# Final guard: keep the result strictly within (now, hard_expiry]. For tiny or
151+
# zero TTLs this collapses gracefully toward `now` rather than going negative or
152+
# past the hard-expiry boundary.
153+
if refresh_at < now:
154+
refresh_at = now # pragma: no cover
155+
if refresh_at > hard_expiry:
156+
refresh_at = hard_expiry # pragma: no cover
157+
158+
return refresh_at

tests/client/test_auth.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,13 +251,50 @@ def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: O
251251
context = oauth_provider.context
252252
context.current_tokens = valid_tokens
253253
context.token_expiry_time = time.time() + 1800
254+
context.token_refresh_time = time.time() + 1440
254255

255256
# Clear tokens
256257
context.clear_tokens()
257258

258259
# Verify cleared
259260
assert context.current_tokens is None
260261
assert context.token_expiry_time is None
262+
assert context.token_refresh_time is None
263+
264+
@pytest.mark.anyio
265+
async def test_should_refresh_token(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken):
266+
"""Test should_refresh_token() proactive-refresh logic."""
267+
context = oauth_provider.context
268+
269+
# No tokens at all -> never proactively refresh.
270+
assert not context.should_refresh_token()
271+
272+
context.current_tokens = valid_tokens
273+
context.client_info = OAuthClientInformationFull(
274+
client_id="test_client_id",
275+
client_secret="test_client_secret",
276+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
277+
)
278+
279+
# Token still hard-valid AND before the jittered refresh point -> no refresh.
280+
context.token_expiry_time = time.time() + 1800
281+
context.token_refresh_time = time.time() + 600
282+
assert context.is_token_valid()
283+
assert not context.should_refresh_token()
284+
285+
# Token still hard-valid but we have passed the proactive refresh point -> refresh.
286+
context.token_refresh_time = time.time() - 1
287+
assert context.is_token_valid()
288+
assert context.should_refresh_token()
289+
290+
# No refresh time known (e.g. server gave no expiry) -> fall back to reactive only.
291+
context.token_refresh_time = None
292+
assert not context.should_refresh_token()
293+
294+
# Past the refresh point but no refresh token -> cannot proactively refresh.
295+
context.token_refresh_time = time.time() - 1
296+
context.current_tokens.refresh_token = None
297+
assert not context.should_refresh_token()
261298

262299

263300
class TestOAuthFlow:
@@ -506,6 +543,102 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl
506543
except StopAsyncIteration:
507544
pass # Expected - generator should complete
508545

546+
@pytest.mark.anyio
547+
async def test_async_auth_flow_proactively_refreshes_when_past_jitter_window(
548+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
549+
):
550+
"""async_auth_flow refreshes proactively past the jittered window.
551+
552+
The token is still hard-valid (is_token_valid() is True), but we are past the
553+
proactive refresh point, so the flow should yield a refresh request *before*
554+
sending the original request -- spreading fleet refreshes out instead of
555+
waiting for hard expiry.
556+
"""
557+
context = oauth_provider.context
558+
context.current_tokens = valid_tokens
559+
context.client_info = OAuthClientInformationFull(
560+
client_id="test_client_id",
561+
client_secret="test_client_secret",
562+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
563+
)
564+
oauth_provider._initialized = True
565+
566+
# Token is still valid for a while, but we are past the proactive refresh point.
567+
context.token_expiry_time = time.time() + 1800
568+
context.token_refresh_time = time.time() - 1
569+
assert context.is_token_valid()
570+
assert context.should_refresh_token()
571+
572+
test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
573+
auth_flow = oauth_provider.async_auth_flow(test_request)
574+
575+
# First yielded request must be a proactive refresh, not the original request.
576+
refresh_request = await auth_flow.__anext__()
577+
assert refresh_request.method == "POST"
578+
assert str(refresh_request.url) == "https://api.example.com/token"
579+
refresh_content = refresh_request.content.decode()
580+
assert "grant_type=refresh_token" in refresh_content
581+
assert "refresh_token=test_refresh_token" in refresh_content
582+
583+
# Provide a successful refresh response with fresh tokens.
584+
refresh_response = httpx.Response(
585+
200,
586+
content=(
587+
b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
588+
b'"refresh_token": "new_refresh_token"}'
589+
),
590+
request=refresh_request,
591+
)
592+
593+
# After a successful refresh, the original request is sent with the new token.
594+
actual_request = await auth_flow.asend(refresh_response)
595+
assert actual_request.headers["Authorization"] == "Bearer new_access_token"
596+
assert str(actual_request.url) == "https://api.example.com/v1/mcp"
597+
598+
# New proactive-refresh point should have been scheduled in the future.
599+
assert context.token_refresh_time is not None
600+
assert context.token_refresh_time > time.time()
601+
602+
# Close out the generator with a final success response.
603+
final_response = httpx.Response(200, request=actual_request)
604+
try:
605+
await auth_flow.asend(final_response)
606+
except StopAsyncIteration:
607+
pass # Expected - generator completes
608+
609+
@pytest.mark.anyio
610+
async def test_async_auth_flow_skips_refresh_before_jitter_window(
611+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
612+
):
613+
"""A fresh token (before the proactive window) is used directly, no refresh."""
614+
context = oauth_provider.context
615+
context.current_tokens = valid_tokens
616+
context.client_info = OAuthClientInformationFull(
617+
client_id="test_client_id",
618+
client_secret="test_client_secret",
619+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
620+
)
621+
oauth_provider._initialized = True
622+
623+
# Token valid and well before the proactive refresh point.
624+
context.token_expiry_time = time.time() + 1800
625+
context.token_refresh_time = time.time() + 600
626+
assert not context.should_refresh_token()
627+
628+
test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
629+
auth_flow = oauth_provider.async_auth_flow(test_request)
630+
631+
# First (and only auth-related) yielded request is the original request itself.
632+
actual_request = await auth_flow.__anext__()
633+
assert actual_request.headers["Authorization"] == "Bearer test_access_token"
634+
assert str(actual_request.url) == "https://api.example.com/v1/mcp"
635+
636+
final_response = httpx.Response(200, request=actual_request)
637+
try:
638+
await auth_flow.asend(final_response)
639+
except StopAsyncIteration:
640+
pass # Expected - generator completes
641+
509642
@pytest.mark.anyio
510643
async def test_handle_metadata_response_success(self, oauth_provider: OAuthClientProvider):
511644
"""Test successful metadata response handling."""

0 commit comments

Comments
 (0)