@@ -251,13 +251,50 @@ def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: O
251251 context = oauth_provider .context
252252 context .current_tokens = valid_tokens
253253 context .token_expiry_time = time .time () + 1800
254+ context .token_refresh_time = time .time () + 1440
254255
255256 # Clear tokens
256257 context .clear_tokens ()
257258
258259 # Verify cleared
259260 assert context .current_tokens is None
260261 assert context .token_expiry_time is None
262+ assert context .token_refresh_time is None
263+
264+ @pytest .mark .anyio
265+ async def test_should_refresh_token (self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken ):
266+ """Test should_refresh_token() proactive-refresh logic."""
267+ context = oauth_provider .context
268+
269+ # No tokens at all -> never proactively refresh.
270+ assert not context .should_refresh_token ()
271+
272+ context .current_tokens = valid_tokens
273+ context .client_info = OAuthClientInformationFull (
274+ client_id = "test_client_id" ,
275+ client_secret = "test_client_secret" ,
276+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
277+ )
278+
279+ # Token still hard-valid AND before the jittered refresh point -> no refresh.
280+ context .token_expiry_time = time .time () + 1800
281+ context .token_refresh_time = time .time () + 600
282+ assert context .is_token_valid ()
283+ assert not context .should_refresh_token ()
284+
285+ # Token still hard-valid but we have passed the proactive refresh point -> refresh.
286+ context .token_refresh_time = time .time () - 1
287+ assert context .is_token_valid ()
288+ assert context .should_refresh_token ()
289+
290+ # No refresh time known (e.g. server gave no expiry) -> fall back to reactive only.
291+ context .token_refresh_time = None
292+ assert not context .should_refresh_token ()
293+
294+ # Past the refresh point but no refresh token -> cannot proactively refresh.
295+ context .token_refresh_time = time .time () - 1
296+ context .current_tokens .refresh_token = None
297+ assert not context .should_refresh_token ()
261298
262299
263300class TestOAuthFlow :
@@ -506,6 +543,102 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl
506543 except StopAsyncIteration :
507544 pass # Expected - generator should complete
508545
546+ @pytest .mark .anyio
547+ async def test_async_auth_flow_proactively_refreshes_when_past_jitter_window (
548+ self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
549+ ):
550+ """async_auth_flow refreshes proactively past the jittered window.
551+
552+ The token is still hard-valid (is_token_valid() is True), but we are past the
553+ proactive refresh point, so the flow should yield a refresh request *before*
554+ sending the original request -- spreading fleet refreshes out instead of
555+ waiting for hard expiry.
556+ """
557+ context = oauth_provider .context
558+ context .current_tokens = valid_tokens
559+ context .client_info = OAuthClientInformationFull (
560+ client_id = "test_client_id" ,
561+ client_secret = "test_client_secret" ,
562+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
563+ )
564+ oauth_provider ._initialized = True
565+
566+ # Token is still valid for a while, but we are past the proactive refresh point.
567+ context .token_expiry_time = time .time () + 1800
568+ context .token_refresh_time = time .time () - 1
569+ assert context .is_token_valid ()
570+ assert context .should_refresh_token ()
571+
572+ test_request = httpx .Request ("GET" , "https://api.example.com/v1/mcp" )
573+ auth_flow = oauth_provider .async_auth_flow (test_request )
574+
575+ # First yielded request must be a proactive refresh, not the original request.
576+ refresh_request = await auth_flow .__anext__ ()
577+ assert refresh_request .method == "POST"
578+ assert str (refresh_request .url ) == "https://api.example.com/token"
579+ refresh_content = refresh_request .content .decode ()
580+ assert "grant_type=refresh_token" in refresh_content
581+ assert "refresh_token=test_refresh_token" in refresh_content
582+
583+ # Provide a successful refresh response with fresh tokens.
584+ refresh_response = httpx .Response (
585+ 200 ,
586+ content = (
587+ b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
588+ b'"refresh_token": "new_refresh_token"}'
589+ ),
590+ request = refresh_request ,
591+ )
592+
593+ # After a successful refresh, the original request is sent with the new token.
594+ actual_request = await auth_flow .asend (refresh_response )
595+ assert actual_request .headers ["Authorization" ] == "Bearer new_access_token"
596+ assert str (actual_request .url ) == "https://api.example.com/v1/mcp"
597+
598+ # New proactive-refresh point should have been scheduled in the future.
599+ assert context .token_refresh_time is not None
600+ assert context .token_refresh_time > time .time ()
601+
602+ # Close out the generator with a final success response.
603+ final_response = httpx .Response (200 , request = actual_request )
604+ try :
605+ await auth_flow .asend (final_response )
606+ except StopAsyncIteration :
607+ pass # Expected - generator completes
608+
609+ @pytest .mark .anyio
610+ async def test_async_auth_flow_skips_refresh_before_jitter_window (
611+ self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
612+ ):
613+ """A fresh token (before the proactive window) is used directly, no refresh."""
614+ context = oauth_provider .context
615+ context .current_tokens = valid_tokens
616+ context .client_info = OAuthClientInformationFull (
617+ client_id = "test_client_id" ,
618+ client_secret = "test_client_secret" ,
619+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
620+ )
621+ oauth_provider ._initialized = True
622+
623+ # Token valid and well before the proactive refresh point.
624+ context .token_expiry_time = time .time () + 1800
625+ context .token_refresh_time = time .time () + 600
626+ assert not context .should_refresh_token ()
627+
628+ test_request = httpx .Request ("GET" , "https://api.example.com/v1/mcp" )
629+ auth_flow = oauth_provider .async_auth_flow (test_request )
630+
631+ # First (and only auth-related) yielded request is the original request itself.
632+ actual_request = await auth_flow .__anext__ ()
633+ assert actual_request .headers ["Authorization" ] == "Bearer test_access_token"
634+ assert str (actual_request .url ) == "https://api.example.com/v1/mcp"
635+
636+ final_response = httpx .Response (200 , request = actual_request )
637+ try :
638+ await auth_flow .asend (final_response )
639+ except StopAsyncIteration :
640+ pass # Expected - generator completes
641+
509642 @pytest .mark .anyio
510643 async def test_handle_metadata_response_success (self , oauth_provider : OAuthClientProvider ):
511644 """Test successful metadata response handling."""
0 commit comments