diff --git a/src/resolver_athena_client/client/channel.py b/src/resolver_athena_client/client/channel.py index 00131a0..e5aec66 100644 --- a/src/resolver_athena_client/client/channel.py +++ b/src/resolver_athena_client/client/channel.py @@ -1,6 +1,7 @@ """Channel creation utilities for the Athena client.""" import json +import logging import threading import time from typing import NamedTuple, override @@ -15,6 +16,8 @@ OAuthError, ) +logger = logging.getLogger(__name__) + class TokenData(NamedTuple): """Immutable snapshot of token state. @@ -28,11 +31,34 @@ class TokenData(NamedTuple): access_token: str expires_at: float scheme: str + issued_at: float def is_valid(self) -> bool: """Check if this token is still valid (with a 30-second buffer).""" return time.time() < (self.expires_at - 30) + def is_old(self, proactive_refresh_threshold: float) -> bool: + """Check if this token should be proactively refreshed. + + A token is considered "old" if less than the + proactive_refresh_threshold of its lifetime remains. This allows + background refresh to happen before expiry while the token is still + usable. + + Args: + ---- + proactive_refresh_threshold: Fraction of token lifetime past which + to trigger proactive refresh (e.g. 0.25 for 25%) + + """ + if proactive_refresh_threshold <= 0 or proactive_refresh_threshold >= 1: + msg = "proactive_refresh_threshold must be between 0 and 1" + raise ValueError(msg) + current_time = time.time() + total_lifetime = self.expires_at - self.issued_at + time_remaining = self.expires_at - current_time + return time_remaining < (total_lifetime * proactive_refresh_threshold) + class CredentialHelper: """OAuth credential helper for managing authentication tokens.""" @@ -43,6 +69,7 @@ def __init__( client_secret: str, auth_url: str = "https://crispthinking.auth0.com/oauth/token", audience: str = "crisp-athena-live", + proactive_refresh_threshold: float = 0.25, ) -> None: """Initialize the credential helper. @@ -52,6 +79,8 @@ def __init__( client_secret: OAuth client secret auth_url: OAuth token endpoint URL audience: OAuth audience + proactive_refresh_threshold: Fraction of token lifetime to trigger + proactive refresh (default 0.25 for 25%) """ if not client_id: @@ -67,14 +96,17 @@ def __init__( self._audience: str = audience self._token_data: TokenData | None = None self._lock: threading.Lock = threading.Lock() + self._refresh_thread: threading.Thread | None = None + + if proactive_refresh_threshold <= 0 or proactive_refresh_threshold >= 1: + msg = "proactive_refresh_threshold must be a float between 0 and 1" + raise ValueError(msg) + + self._proactive_refresh_threshold: float = proactive_refresh_threshold def get_token(self) -> TokenData: """Get valid token data, refreshing if necessary. - Uses double-checked locking: the happy path (token is valid) - avoids acquiring the lock entirely. The lock is only taken - when the token needs to be refreshed. - Returns ------- A valid ``TokenData`` containing access token, expiry, and scheme @@ -86,9 +118,15 @@ def get_token(self) -> TokenData: """ token_data = self._token_data + + # Fast path: token is valid and fresh if token_data is not None and token_data.is_valid(): + # If token is old, trigger background refresh + if token_data.is_old(self._proactive_refresh_threshold): + self._start_background_refresh() return token_data + # Slow path: token is expired or missing, must block with self._lock: token_data = self._token_data if token_data is not None and token_data.is_valid(): @@ -102,6 +140,71 @@ def get_token(self) -> TokenData: raise RuntimeError(msg) return token_data + def _start_background_refresh(self) -> None: + """Start a background thread to refresh the token. + + Only starts a new thread if one isn't already running. + + This method is safe to call multiple times - it only starts a new + thread if no refresh is currently in progress. + """ + # Quick check without lock - if refresh thread exists and is + # alive, skip + if self._refresh_thread is not None and self._refresh_thread.is_alive(): + return + + # Try to acquire lock and start refresh + if self._lock.acquire(blocking=False): + try: + # Double-check: another thread might have started refresh, + # or the token may have been refreshed. + refresh_not_active = ( + self._refresh_thread is None + or not self._refresh_thread.is_alive() + ) + token_needs_refresh = ( + self._token_data is None + or self._token_data.is_old( + self._proactive_refresh_threshold + ) + ) + refresh_needed = refresh_not_active and token_needs_refresh + if refresh_needed: + self._refresh_thread = threading.Thread( + target=self._background_refresh, + daemon=True, + ) + self._refresh_thread.start() + finally: + self._lock.release() + + def _background_refresh(self) -> None: + """Background thread target for token refresh. + + Acquires the lock and refreshes the token. Errors are logged + but silently ignored since the next foreground request will + retry if needed. + """ + with self._lock: + # Check if token still needs refresh (prevent stampede) + token_data = self._token_data + if token_data is not None and not token_data.is_old( + self._proactive_refresh_threshold + ): + # Token was already refreshed by another thread + return + + try: + self._refresh_token() + except Exception as e: # noqa: BLE001 + # Log but don't raise - background refresh failures + # are recoverable (next get_token() will retry) + logger.debug( + "Background token refresh failed, " + "will retry on next request: %s", + e, + ) + def _refresh_token(self) -> None: """Refresh the authentication token by making an OAuth request. @@ -138,10 +241,12 @@ def _refresh_token(self) -> None: token_type = raw.get("token_type", "Bearer") # Preserve server-provided casing, only strip whitespace scheme: str = token_type.strip() if token_type else "Bearer" + current_time = time.time() self._token_data = TokenData( access_token=access_token, - expires_at=time.time() + expires_in, + expires_at=current_time + expires_in, scheme=scheme, + issued_at=current_time, ) except httpx.HTTPStatusError as e: diff --git a/tests/client/test_channel.py b/tests/client/test_channel.py index 221907d..33f8d85 100644 --- a/tests/client/test_channel.py +++ b/tests/client/test_channel.py @@ -97,6 +97,23 @@ def test_init_with_empty_client_secret(self) -> None: client_secret="", ) + @pytest.mark.parametrize( + "invalid", + [-0.1, 1.1, -0.5, 2.0], + ) + def test_init_with_invalid_proactive_refresh_threshold( + self, invalid: float + ) -> None: + with pytest.raises( + ValueError, + match="proactive_refresh_threshold must be a float between 0 and 1", + ): + _ = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + proactive_refresh_threshold=invalid, + ) + def test_is_token_valid_with_no_token(self) -> None: """Test token is not valid when no token data is set.""" helper = CredentialHelper( @@ -117,6 +134,7 @@ def test_is_token_valid_with_expired_token(self) -> None: access_token="test_token", expires_at=time.time() - 100, scheme="Bearer", + issued_at=time.time() - 3700, ) assert not helper._token_data.is_valid() @@ -132,6 +150,7 @@ def test_is_token_valid_with_valid_token(self) -> None: access_token="test_token", expires_at=time.time() + 3600, scheme="Bearer", + issued_at=time.time(), ) assert helper._token_data.is_valid() @@ -147,6 +166,7 @@ def test_is_token_valid_with_soon_expiring_token(self) -> None: access_token="test_token", expires_at=time.time() + 20, scheme="Bearer", + issued_at=time.time() - 3580, ) assert not helper._token_data.is_valid() @@ -162,7 +182,7 @@ def test_get_token_success(self) -> None: mock_response.json.return_value = { "access_token": "new_access_token", "expires_in": 3600, - "token_type": "bearer", + "token_type": "Bearer", } mock_response.raise_for_status.return_value = None @@ -198,7 +218,7 @@ def test_get_token_respects_token_type(self) -> None: token_data = helper.get_token() - assert token_data.scheme == "Dpop" + assert token_data.scheme == "DPoP" def test_get_token_defaults_to_bearer(self) -> None: """Test that scheme defaults to Bearer when token_type is absent.""" @@ -222,6 +242,44 @@ def test_get_token_defaults_to_bearer(self) -> None: assert token_data.scheme == "Bearer" + def test_get_token_preserves_server_casing(self) -> None: + """Test that server-provided token_type casing is preserved.""" + test_cases = [ + ("Bearer", "Bearer"), + ("bearer", "bearer"), + ("BEARER", "BEARER"), + ("DPoP", "DPoP"), + ("dpop", "dpop"), + (" Bearer ", "Bearer"), # Whitespace is stripped + ] + + for server_type, expected_scheme in test_cases: + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + mock_response = mock.Mock() + mock_response.json.return_value = { + "access_token": "test_token", + "expires_in": 3600, + "token_type": server_type, + } + mock_response.raise_for_status.return_value = None + + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = ( + mock_client.return_value.__enter__.return_value + ) + mock_response_obj.post.return_value = mock_response + + token_data = helper.get_token() + + assert token_data.scheme == expected_scheme, ( + f"Expected {expected_scheme} for {server_type}, " + f"got {token_data.scheme}" + ) + def test_get_token_cached(self) -> None: """Test that cached token is returned when valid.""" helper = CredentialHelper( @@ -234,6 +292,7 @@ def test_get_token_cached(self) -> None: access_token="cached_token", expires_at=time.time() + 3600, scheme="Bearer", + issued_at=time.time(), ) token_data = helper.get_token() @@ -321,6 +380,7 @@ def test_invalidate_token(self) -> None: access_token="valid_token", expires_at=time.time() + 3600, scheme="Bearer", + issued_at=time.time(), ) helper.invalidate_token() @@ -339,6 +399,7 @@ def test_get_token_refreshes_after_invalidation(self) -> None: access_token="old_token", expires_at=time.time() + 3600, scheme="Bearer", + issued_at=time.time(), ) helper.invalidate_token() @@ -369,6 +430,7 @@ def test_plugin_passes_bearer_token_to_callback(self) -> None: access_token="test-bearer-token", expires_at=time.time() + 3600, scheme="Bearer", + issued_at=time.time(), ) plugin = _AutoRefreshTokenAuthMetadataPlugin(mock_helper) @@ -388,6 +450,7 @@ def test_plugin_respects_token_scheme(self) -> None: access_token="dpop-token", expires_at=time.time() + 3600, scheme="Dpop", + issued_at=time.time(), ) plugin = _AutoRefreshTokenAuthMetadataPlugin(mock_helper) @@ -426,3 +489,222 @@ def test_plugin_catches_unexpected_exceptions(self) -> None: plugin(mock_context, mock_callback) mock_callback.assert_called_once_with((), runtime_error) + + +class TestBackgroundTokenRefresh: + """Tests for background token refresh functionality.""" + + def test_token_is_old_when_past_halfway_lifetime(self) -> None: + """Test that a token is considered old when past 25% of its lifetime.""" + current_time = time.time() + # Token with 1 hour lifetime, 20 minutes remaining (33%) + token = TokenData( + access_token="test_token", + expires_at=current_time + 600, # 10 minutes from now + scheme="Bearer", + issued_at=current_time - 3_000, # 50 minutes ago + ) + # Total lifetime = 3600s, remaining = 600s (1/6th), so it's old + assert token.is_old(0.25) + + def test_token_is_not_old_when_fresh(self) -> None: + """Test that a token is not old when more than 25% lifetime remains.""" + current_time = time.time() + # Token with 1 hour lifetime, 40 minutes remaining (67%) + token = TokenData( + access_token="test_token", + expires_at=current_time + 2400, # 40 minutes from now + scheme="Bearer", + issued_at=current_time - 1200, # 20 minutes ago + ) + # Total lifetime = 3600s, remaining = 2400s (67%), so it's fresh + assert not token.is_old(0.25) + + def test_get_token_triggers_background_refresh_for_old_token(self) -> None: + """Test that get_token triggers background refresh for old tokens.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + current_time = time.time() + # Set up an old but valid token + helper._token_data = TokenData( + access_token="old_token", + expires_at=current_time + 600, # 10 minutes from now + scheme="Bearer", + issued_at=current_time - 3_000, # 50 minutes ago + ) + + with mock.patch.object( + helper, "_start_background_refresh" + ) as mock_start: + token_data = helper.get_token() + + # Should return current token immediately + assert token_data.access_token == "old_token" + # Should have triggered background refresh + mock_start.assert_called_once() + + def test_get_token_does_not_trigger_refresh_for_fresh_token(self) -> None: + """Test that get_token does not trigger refresh for fresh tokens.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + current_time = time.time() + # Set up a fresh, valid token + helper._token_data = TokenData( + access_token="fresh_token", + expires_at=current_time + 2400, # 40 minutes remaining + scheme="Bearer", + issued_at=current_time - 1200, # 20 minutes ago, so it's fresh + ) + + with mock.patch.object( + helper, "_start_background_refresh" + ) as mock_start: + token_data = helper.get_token() + + # Should return current token + assert token_data.access_token == "fresh_token" + # Should NOT have triggered background refresh + mock_start.assert_not_called() + + def test_background_refresh_does_not_start_if_already_running(self) -> None: + """Test that background refresh doesn't start duplicate threads.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + # Mock a running refresh thread + mock_thread = mock.Mock() + mock_thread.is_alive.return_value = True + helper._refresh_thread = mock_thread + + with mock.patch("threading.Thread") as mock_thread_class: + helper._start_background_refresh() + + # Should not create a new thread + mock_thread_class.assert_not_called() + + def test_background_refresh_starts_new_thread_if_none_exists(self) -> None: + """Test that background refresh starts a thread when none exists.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + mock_thread = mock.Mock() + with mock.patch("threading.Thread", return_value=mock_thread): + helper._start_background_refresh() + + # Should have started the thread + mock_thread.start.assert_called_once() + + def test_background_refresh_silently_handles_errors(self) -> None: + """Test that background refresh silently ignores errors.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + # Mock refresh to raise an error + with mock.patch.object( + helper, "_refresh_token", side_effect=OAuthError("Test error") + ): + # Should not raise an exception + helper._background_refresh() + + def test_background_refresh_prevents_stampede(self) -> None: + """Test background refresh skips refresh if token is fresh.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + current_time = time.time() + # Set up a fresh token (already refreshed by another thread) + helper._token_data = TokenData( + access_token="fresh_token", + expires_at=current_time + 2400, # 40 minutes remaining + scheme="Bearer", + issued_at=current_time - 1200, # 20 minutes ago, so it's fresh + ) + + # Mock refresh to track if it's called + with mock.patch.object(helper, "_refresh_token") as mock_refresh: + helper._background_refresh() + + # Should NOT have called refresh since token is fresh + mock_refresh.assert_not_called() + + def test_get_token_blocks_for_expired_token(self) -> None: + """Test that get_token blocks and refreshes when token is expired.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + # Set up an expired token + helper._token_data = TokenData( + access_token="expired_token", + expires_at=time.time() - 100, # Expired + scheme="Bearer", + issued_at=time.time() - 3700, + ) + + mock_response = mock.Mock() + mock_response.json.return_value = { + "access_token": "new_token", + "expires_in": 3600, + "token_type": "bearer", + } + mock_response.raise_for_status.return_value = None + + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value + mock_response_obj.post.return_value = mock_response + + token_data = helper.get_token() + + # Should have refreshed and returned new token + assert token_data.access_token == "new_token" + # Should have called the OAuth endpoint + mock_response_obj.post.assert_called_once() + + def test_refresh_token_sets_issued_at(self) -> None: + """Test that _refresh_token sets the issued_at timestamp.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + mock_response = mock.Mock() + mock_response.json.return_value = { + "access_token": "new_token", + "expires_in": 3600, + "token_type": "bearer", + } + mock_response.raise_for_status.return_value = None + + before_time = time.time() + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value + mock_response_obj.post.return_value = mock_response + + _ = helper.get_token() + + after_time = time.time() + + # Check that issued_at was set to a reasonable value + assert helper._token_data is not None + assert before_time <= helper._token_data.issued_at <= after_time + # Check that expires_at is approximately issued_at + 3600 + # Allow 1 second tolerance for test execution time + assert ( + helper._token_data.expires_at - helper._token_data.issued_at - 3600 + < 1 + )