diff --git a/docs/api/exceptions.rst b/docs/api/exceptions.rst index 4c5720a..8a4fc9f 100644 --- a/docs/api/exceptions.rst +++ b/docs/api/exceptions.rst @@ -129,7 +129,8 @@ OAuth Error Handling client_id=client_id, client_secret=client_secret ) - token = await credential_helper.get_token() + token_data = credential_helper.get_token() + access_token = token_data.access_token except OAuthError as e: logger.error(f"OAuth authentication failed: {e}") # Handle OAuth failure - check credentials diff --git a/docs/authentication.rst b/docs/authentication.rst index 066c587..0f1ab53 100644 --- a/docs/authentication.rst +++ b/docs/authentication.rst @@ -76,14 +76,6 @@ Set these environment variables for OAuth authentication: audience=os.getenv("OAUTH_AUDIENCE", "crisp-athena-live"), ) - # Test token acquisition - try: - token = await credential_helper.get_token() - print(f"Successfully acquired token (length: {len(token)})") - except Exception as e: - print(f"Failed to acquire OAuth token: {e}") - return - # Create authenticated channel channel = await create_channel_with_credentials( host=os.getenv("ATHENA_HOST"), @@ -258,11 +250,12 @@ Handle OAuth-specific errors gracefully: .. code-block:: python - from resolver_athena_client.client.exceptions import AuthenticationError + from resolver_athena_client.client.exceptions import OAuthError try: - token = await credential_helper.get_token() - except AuthenticationError as e: + token_data = credential_helper.get_token() + access_token = token_data.access_token + except OAuthError as e: logger.error(f"OAuth authentication failed: {e}") # Handle authentication failure except Exception as e: @@ -356,8 +349,8 @@ Test your authentication setup: client_secret=os.getenv("OAUTH_CLIENT_SECRET"), ) - token = await credential_helper.get_token() - print(f"✓ Authentication successful (token length: {len(token)})") + token_data = credential_helper.get_token() + print(f"✓ Authentication successful (token length: {len(token_data.access_token)})") return True except Exception as e: diff --git a/examples/classify_single_example.py b/examples/classify_single_example.py index cfa74b8..7ed92e7 100755 --- a/examples/classify_single_example.py +++ b/examples/classify_single_example.py @@ -213,15 +213,6 @@ async def main() -> int: audience=audience, ) - # Test token acquisition - try: - logger.info("Acquiring OAuth token...") - token = await credential_helper.get_token() - logger.info("Successfully acquired token (length: %d)", len(token)) - except Exception: - logger.exception("Failed to acquire OAuth token") - return 1 - # Configure client options options = AthenaOptions( host=host, diff --git a/examples/example.py b/examples/example.py index ce87722..a9a3e50 100755 --- a/examples/example.py +++ b/examples/example.py @@ -163,15 +163,6 @@ async def main() -> int: audience=audience, ) - # Test token acquisition - try: - logger.info("Acquiring OAuth token...") - token = await credential_helper.get_token() - logger.info("Successfully acquired token (length: %d)", len(token)) - except Exception: - logger.exception("Failed to acquire OAuth token") - return 1 - # Get available deployment channel = await create_channel_with_credentials(host, credential_helper) async with DeploymentSelector(channel) as deployment_selector: diff --git a/src/resolver_athena_client/client/__init__.py b/src/resolver_athena_client/client/__init__.py index 94be685..4ad131c 100644 --- a/src/resolver_athena_client/client/__init__.py +++ b/src/resolver_athena_client/client/__init__.py @@ -5,6 +5,7 @@ from resolver_athena_client.client.channel import ( CredentialHelper, + TokenData, create_channel_with_credentials, ) from resolver_athena_client.client.exceptions import ( @@ -26,6 +27,7 @@ "CredentialError", "CredentialHelper", "OAuthError", + "TokenData", "TokenExpiredError", "create_channel_with_credentials", "get_output_error_summary", diff --git a/src/resolver_athena_client/client/channel.py b/src/resolver_athena_client/client/channel.py index e9d1510..497341a 100644 --- a/src/resolver_athena_client/client/channel.py +++ b/src/resolver_athena_client/client/channel.py @@ -1,8 +1,10 @@ """Channel creation utilities for the Athena client.""" -import asyncio import json +import logging +import threading import time +from dataclasses import dataclass from typing import override import grpc @@ -15,38 +17,47 @@ OAuthError, ) +logger = logging.getLogger(__name__) -class TokenMetadataPlugin(grpc.AuthMetadataPlugin): - """Plugin that adds authorization token to gRPC metadata.""" - def __init__(self, token: str) -> None: - """Initialize the plugin with the auth token. +@dataclass(frozen=True) +class TokenData: + """Immutable snapshot of token state.""" - Args: - ---- - token: The authorization token to add to requests + access_token: str + expires_at: float + scheme: str + issued_at: float + proactive_refresh_threshold: float = 0.25 - """ - self._token: str = token + def __post_init__(self) -> None: + """Validate that proactive_refresh_threshold is between 0 and 1.""" + if ( + self.proactive_refresh_threshold <= 0 + or self.proactive_refresh_threshold >= 1 + ): + msg = "proactive_refresh_threshold must be between 0 and 1" + raise ValueError(msg) - @override - def __call__( - self, - _: grpc.AuthMetadataContext, - callback: grpc.AuthMetadataPluginCallback, - ) -> None: - """Pass authentication metadata to the provided callback. + def is_valid(self) -> bool: + """Check if this token is still valid (with a 30-second buffer).""" + return time.time() < (self.expires_at - 30) - This method will be invoked asynchronously in a separate thread. + def is_old(self) -> bool: + """Check if this token should be proactively refreshed. - Args: - ---- - callback: An AuthMetadataPluginCallback to be invoked either - synchronously or asynchronously. + A token is considered "old" if less than the + proactive_refresh_threshold of its lifetime remains. This allows + background refresh to happen before expiry while the token is still + usable. """ - metadata = (("authorization", f"Token {self._token}"),) - callback(metadata, None) + current_time = time.time() + total_lifetime = self.expires_at - self.issued_at + time_remaining = self.expires_at - current_time + return time_remaining < ( + total_lifetime * self.proactive_refresh_threshold + ) class CredentialHelper: @@ -58,6 +69,7 @@ def __init__( client_secret: str, auth_url: str = "https://crispthinking.auth0.com/oauth/token", audience: str = "crisp-athena-live", + proactive_refresh_threshold: float = 0.25, ) -> None: """Initialize the credential helper. @@ -67,6 +79,8 @@ def __init__( client_secret: OAuth client secret auth_url: OAuth token endpoint URL audience: OAuth audience + proactive_refresh_threshold: Fraction of token lifetime to trigger + proactive refresh (default 0.25 for 25%) """ if not client_id: @@ -80,56 +94,118 @@ def __init__( self._client_secret: str = client_secret self._auth_url: str = auth_url self._audience: str = audience - self._token: str | None = None - self._token_expires_at: float | None = None - self._lock: asyncio.Lock = asyncio.Lock() + self._token_data: TokenData | None = None + self._lock: threading.Lock = threading.Lock() + self._refresh_thread: threading.Thread | None = None + + if proactive_refresh_threshold <= 0 or proactive_refresh_threshold >= 1: + msg = "proactive_refresh_threshold must be a float between 0 and 1" + raise ValueError(msg) - async def get_token(self) -> str: - """Get a valid authentication token. + self._proactive_refresh_threshold: float = proactive_refresh_threshold - This method will return a cached token if it's still valid, - or fetch a new token if needed. + def get_token(self) -> TokenData: + """Get valid token data, refreshing if necessary. Returns ------- - A valid authentication token + A valid ``TokenData`` containing access token, expiry, and scheme Raises ------ OAuthError: If token acquisition fails - TokenExpiredError: If token has expired and refresh fails + RuntimeError: If token is unexpectedly None after refresh """ - async with self._lock: - if self._is_token_valid(): - if self._token is None: - msg = "Token should be valid but is None" - raise RuntimeError(msg) - return self._token - - await self._refresh_token() - if self._token is None: - msg = "Token refresh failed" + token_data = self._token_data + + # Fast path: token is valid and fresh + if token_data is not None and token_data.is_valid(): + # If token is old, trigger background refresh + if token_data.is_old(): + self._start_background_refresh() + return token_data + + # Slow path: token is expired or missing, must block + with self._lock: + token_data = self._token_data + if token_data is not None and token_data.is_valid(): + return token_data + + self._refresh_token() + + token_data = self._token_data + if token_data is None: + msg = "Token is unexpectedly None after refresh" raise RuntimeError(msg) - return self._token + return token_data - def _is_token_valid(self) -> bool: - """Check if the current token is valid and not expired. + def _start_background_refresh(self) -> None: + """Start a background thread to refresh the token. - Returns - ------- - True if token is valid, False otherwise + Only starts a new thread if one isn't already running. + This method is safe to call multiple times - it only starts a new + thread if no refresh is currently in progress. """ - if not self._token or not self._token_expires_at: - return False + # Quick check without lock - if refresh thread exists and is + # alive, skip + if self._refresh_thread is not None and self._refresh_thread.is_alive(): + return - # Add 30 second buffer before expiration - return time.time() < (self._token_expires_at - 30) + # Try to acquire lock and start refresh + if self._lock.acquire(blocking=False): + try: + # Double-check: another thread might have started refresh, + # or the token may have been refreshed. + refresh_not_active = ( + self._refresh_thread is None + or not self._refresh_thread.is_alive() + ) + token_needs_refresh = ( + self._token_data is None or self._token_data.is_old() + ) + refresh_needed = refresh_not_active and token_needs_refresh + if refresh_needed: + self._refresh_thread = threading.Thread( + target=self._background_refresh, + daemon=True, + ) + self._refresh_thread.start() + finally: + self._lock.release() + + def _background_refresh(self) -> None: + """Background thread target for token refresh. + + Acquires the lock and refreshes the token. Errors are logged + but silently ignored since the next foreground request will + retry if needed. + """ + with self._lock: + # Check if token still needs refresh (prevent stampede) + token_data = self._token_data + if token_data is not None and not token_data.is_old(): + # Token was already refreshed by another thread + return - async def _refresh_token(self) -> None: + try: + self._refresh_token() + except Exception as e: # noqa: BLE001 + # Log but don't raise - background refresh failures + # are recoverable (next get_token() will retry) + logger.debug( + "Background token refresh failed, " + "will retry on next request: %s", + e, + ) + + def _refresh_token(self) -> None: """Refresh the authentication token by making an OAuth request. + This is a synchronous call (suitable for the gRPC metadata-plugin + thread) and must be called while ``self._lock`` is held. + Raises ------ OAuthError: If the OAuth request fails @@ -145,8 +221,8 @@ async def _refresh_token(self) -> None: headers = {"content-type": "application/json"} try: - async with httpx.AsyncClient() as client: - response = await client.post( + with httpx.Client() as client: + response = client.post( self._auth_url, json=payload, headers=headers, @@ -154,12 +230,19 @@ async def _refresh_token(self) -> None: ) _ = response.raise_for_status() - token_data = response.json() - self._token = token_data["access_token"] - expires_in = token_data.get( - "expires_in", 3600 - ) # Default 1 hour - self._token_expires_at = time.time() + expires_in + raw = response.json() + access_token: str = raw["access_token"] + expires_in: int = raw.get("expires_in", 3600) # Default 1 hour + token_type = raw.get("token_type", "Bearer") + scheme: str = token_type.strip() if token_type else "Bearer" + current_time = time.time() + self._token_data = TokenData( + access_token=access_token, + expires_at=current_time + expires_in, + scheme=scheme, + issued_at=current_time, + proactive_refresh_threshold=self._proactive_refresh_threshold, + ) except httpx.HTTPStatusError as e: error_detail = "" @@ -190,11 +273,52 @@ async def _refresh_token(self) -> None: msg = f"Unexpected error during OAuth: {e}" raise OAuthError(msg) from e - async def invalidate_token(self) -> None: + def invalidate_token(self) -> None: """Invalidate the current token to force a refresh on next use.""" - async with self._lock: - self._token = None - self._token_expires_at = None + with self._lock: + self._token_data = None + + +class _AutoRefreshTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin): + """gRPC auth plugin that fetches a fresh token for every RPC.""" + + def __init__(self, credential_helper: CredentialHelper) -> None: + """Initialize with a credential helper. + + Args: + ---- + credential_helper: The helper that manages token lifecycle + + """ + self._credential_helper: CredentialHelper = credential_helper + + @override + def __call__( + self, + _: grpc.AuthMetadataContext, + callback: grpc.AuthMetadataPluginCallback, + ) -> None: + """Supply authorization metadata for an RPC. + + Called by the gRPC runtime on a background thread before each + RPC. On success the token is forwarded using the scheme from + the OAuth token response (typically ``Bearer``); on failure + the error is passed to the callback so gRPC can surface it as + an RPC error. + + Args: + ---- + callback: gRPC callback to receive metadata or an error + + """ + try: + token_data = self._credential_helper.get_token() + scheme = token_data.scheme + token = token_data.access_token + metadata = (("authorization", f"{scheme} {token}"),) + callback(metadata, None) + except Exception as err: # noqa: BLE001 + callback((), err) async def create_channel_with_credentials( @@ -215,19 +339,17 @@ async def create_channel_with_credentials( Raises: ------ InvalidHostError: If host is empty - OAuthError: If OAuth authentication fails """ if not host: raise InvalidHostError(InvalidHostError.default_message) - # Get a valid token from the credential helper - token = await credential_helper.get_token() - - # Create credentials with token authentication + # Create credentials with per-RPC token refresh credentials = grpc.composite_channel_credentials( grpc.ssl_channel_credentials(), - grpc.access_token_call_credentials(token), + grpc.metadata_call_credentials( + _AutoRefreshTokenAuthMetadataPlugin(credential_helper) + ), ) # Configure gRPC options for persistent connections diff --git a/tests/client/test_channel.py b/tests/client/test_channel.py index 701bd00..0a538fa 100644 --- a/tests/client/test_channel.py +++ b/tests/client/test_channel.py @@ -7,11 +7,11 @@ import httpx import pytest -from grpc.aio import Channel from resolver_athena_client.client.channel import ( CredentialHelper, - TokenMetadataPlugin, + TokenData, + _AutoRefreshTokenAuthMetadataPlugin, create_channel_with_credentials, ) from resolver_athena_client.client.exceptions import ( @@ -21,23 +21,6 @@ ) -def test_token_metadata_plugin() -> None: - """Test TokenMetadataPlugin functionality.""" - test_token = "test-token" - plugin = TokenMetadataPlugin(test_token) - - # Mock callback - mock_callback = mock.Mock() - mock_context = mock.Mock() - - # Call the plugin - plugin(mock_context, mock_callback) - - # Verify the callback was called with correct metadata - expected_metadata = (("authorization", f"Token {test_token}"),) - mock_callback.assert_called_once_with(expected_metadata, None) - - @pytest.mark.asyncio async def test_create_channel_with_credentials_validation() -> None: """Test channel creation with credentials validates input properly.""" @@ -50,16 +33,23 @@ async def test_create_channel_with_credentials_validation() -> None: @pytest.mark.asyncio -async def test_create_channel_with_credentials_oauth_failure() -> None: - """Test channel creation when OAuth token acquisition fails.""" +async def test_create_channel_does_not_eagerly_fetch_token() -> None: + """Channel creation must NOT call get_token() eagerly.""" test_host = "test-host:50051" mock_helper = mock.Mock(spec=CredentialHelper) - mock_helper.get_token.side_effect = OAuthError("Token acquisition failed") - with pytest.raises(OAuthError, match="Token acquisition failed"): + with ( + mock.patch("grpc.ssl_channel_credentials"), + mock.patch("grpc.metadata_call_credentials"), + mock.patch("grpc.composite_channel_credentials"), + mock.patch("grpc.aio.secure_channel"), + ): _ = await create_channel_with_credentials(test_host, mock_helper) + # Token should NOT be fetched at channel creation time + mock_helper.get_token.assert_not_called() + class TestCredentialHelper: """Test cases for CredentialHelper OAuth functionality.""" @@ -75,8 +65,7 @@ def test_init_with_valid_params(self) -> None: assert helper._client_secret == "test_client_secret" assert helper._auth_url == "https://crispthinking.auth0.com/oauth/token" assert helper._audience == "crisp-athena-live" - assert helper._token is None - assert helper._token_expires_at is None + assert helper._token_data is None def test_init_with_custom_params(self) -> None: """Test CredentialHelper initialization with custom parameters.""" @@ -108,53 +97,82 @@ def test_init_with_empty_client_secret(self) -> None: client_secret="", ) + @pytest.mark.parametrize( + "invalid", + [-0.1, 1.1, -0.5, 2.0], + ) + def test_init_with_invalid_proactive_refresh_threshold( + self, invalid: float + ) -> None: + with pytest.raises( + ValueError, + match="proactive_refresh_threshold must be a float between 0 and 1", + ): + _ = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + proactive_refresh_threshold=invalid, + ) + def test_is_token_valid_with_no_token(self) -> None: - """Test _is_token_valid returns False when no token is set.""" + """Test token is not valid when no token data is set.""" helper = CredentialHelper( client_id="test_client_id", client_secret="test_client_secret", ) - assert not helper._is_token_valid() + assert helper._token_data is None def test_is_token_valid_with_expired_token(self) -> None: - """Test _is_token_valid returns False when token is expired.""" + """Test TokenData.is_valid returns False when token is expired.""" helper = CredentialHelper( client_id="test_client_id", client_secret="test_client_secret", ) - helper._token = "test_token" - helper._token_expires_at = time.time() - 100 # Expired + helper._token_data = TokenData( + access_token="test_token", + expires_at=time.time() - 100, + scheme="Bearer", + issued_at=time.time() - 3700, + proactive_refresh_threshold=0.25, + ) - assert not helper._is_token_valid() + assert not helper._token_data.is_valid() def test_is_token_valid_with_valid_token(self) -> None: - """Test _is_token_valid returns True when token is valid.""" + """Test TokenData.is_valid returns True when token is valid.""" helper = CredentialHelper( client_id="test_client_id", client_secret="test_client_secret", ) - helper._token = "test_token" - helper._token_expires_at = time.time() + 3600 # Valid for 1 hour + helper._token_data = TokenData( + access_token="test_token", + expires_at=time.time() + 3600, + scheme="Bearer", + issued_at=time.time(), + ) - assert helper._is_token_valid() + assert helper._token_data.is_valid() def test_is_token_valid_with_soon_expiring_token(self) -> None: - """Test _is_token_valid returns False when token expires soon.""" + """Test is_valid returns False when token expires within 30s.""" helper = CredentialHelper( client_id="test_client_id", client_secret="test_client_secret", ) - helper._token = "test_token" - helper._token_expires_at = time.time() + 20 # Expires in 20 seconds + helper._token_data = TokenData( + access_token="test_token", + expires_at=time.time() + 20, + scheme="Bearer", + issued_at=time.time() - 3580, + ) - assert not helper._is_token_valid() + assert not helper._token_data.is_valid() - @pytest.mark.asyncio - async def test_get_token_success(self) -> None: + def test_get_token_success(self) -> None: """Test successful token acquisition.""" helper = CredentialHelper( client_id="test_client_id", @@ -165,21 +183,105 @@ async def test_get_token_success(self) -> None: mock_response.json.return_value = { "access_token": "new_access_token", "expires_in": 3600, + "token_type": "Bearer", + } + mock_response.raise_for_status.return_value = None + + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value + mock_response_obj.post.return_value = mock_response + + token_data = helper.get_token() + + assert token_data.access_token == "new_access_token" + assert token_data.scheme == "Bearer" + assert helper._token_data is not None + assert helper._token_data.expires_at is not None + + def test_get_token_respects_token_type(self) -> None: + """Test that token_type from OAuth response is respected.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + mock_response = mock.Mock() + mock_response.json.return_value = { + "access_token": "some_token", + "expires_in": 3600, + "token_type": "DPoP", } mock_response.raise_for_status.return_value = None - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.return_value = mock_response - token = await helper.get_token() + token_data = helper.get_token() + + assert token_data.scheme == "DPoP" - assert token == "new_access_token" - assert helper._token == "new_access_token" - assert helper._token_expires_at is not None + def test_get_token_defaults_to_bearer(self) -> None: + """Test that scheme defaults to Bearer when token_type is absent.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + mock_response = mock.Mock() + mock_response.json.return_value = { + "access_token": "some_token", + "expires_in": 3600, + } + mock_response.raise_for_status.return_value = None + + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value + mock_response_obj.post.return_value = mock_response - @pytest.mark.asyncio - async def test_get_token_cached(self) -> None: + token_data = helper.get_token() + + assert token_data.scheme == "Bearer" + + def test_get_token_preserves_server_casing(self) -> None: + """Test that server-provided token_type casing is preserved.""" + test_cases = [ + ("Bearer", "Bearer"), + ("bearer", "bearer"), + ("BEARER", "BEARER"), + ("DPoP", "DPoP"), + ("dpop", "dpop"), + (" Bearer ", "Bearer"), # Whitespace is stripped + ] + + for server_type, expected_scheme in test_cases: + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + mock_response = mock.Mock() + mock_response.json.return_value = { + "access_token": "test_token", + "expires_in": 3600, + "token_type": server_type, + } + mock_response.raise_for_status.return_value = None + + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = ( + mock_client.return_value.__enter__.return_value + ) + mock_response_obj.post.return_value = mock_response + + token_data = helper.get_token() + + assert token_data.scheme == expected_scheme, ( + f"Expected {expected_scheme} for {server_type}, " + f"got {token_data.scheme}" + ) + + def test_get_token_cached(self) -> None: """Test that cached token is returned when valid.""" helper = CredentialHelper( client_id="test_client_id", @@ -187,15 +289,18 @@ async def test_get_token_cached(self) -> None: ) # Set up a valid cached token - helper._token = "cached_token" - helper._token_expires_at = time.time() + 3600 + helper._token_data = TokenData( + access_token="cached_token", + expires_at=time.time() + 3600, + scheme="Bearer", + issued_at=time.time(), + ) - token = await helper.get_token() + token_data = helper.get_token() - assert token == "cached_token" + assert token_data.access_token == "cached_token" - @pytest.mark.asyncio - async def test_refresh_token_http_error(self) -> None: + def test_refresh_token_http_error(self) -> None: """Test token refresh with HTTP error.""" helper = CredentialHelper( client_id="test_client_id", @@ -215,17 +320,16 @@ async def test_refresh_token_http_error(self) -> None: response=mock_response, ) - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.side_effect = http_error with pytest.raises( OAuthError, match="OAuth request failed with status 401" ): - _ = await helper.get_token() + _ = helper.get_token() - @pytest.mark.asyncio - async def test_refresh_token_request_error(self) -> None: + def test_refresh_token_request_error(self) -> None: """Test token refresh with request error.""" helper = CredentialHelper( client_id="test_client_id", @@ -234,17 +338,16 @@ async def test_refresh_token_request_error(self) -> None: request_error = httpx.RequestError("Connection failed") - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.side_effect = request_error with pytest.raises( OAuthError, match="Failed to connect to OAuth server" ): - _ = await helper.get_token() + _ = helper.get_token() - @pytest.mark.asyncio - async def test_refresh_token_invalid_response(self) -> None: + def test_refresh_token_invalid_response(self) -> None: """Test token refresh with invalid response format.""" helper = CredentialHelper( client_id="test_client_id", @@ -257,17 +360,16 @@ async def test_refresh_token_invalid_response(self) -> None: } mock_response.raise_for_status.return_value = None - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.return_value = mock_response with pytest.raises( OAuthError, match="Invalid OAuth response format" ): - _ = await helper.get_token() + _ = helper.get_token() - @pytest.mark.asyncio - async def test_invalidate_token(self) -> None: + def test_invalidate_token(self) -> None: """Test token invalidation.""" helper = CredentialHelper( client_id="test_client_id", @@ -275,67 +377,335 @@ async def test_invalidate_token(self) -> None: ) # Set up a valid token - helper._token = "valid_token" - helper._token_expires_at = time.time() + 3600 + helper._token_data = TokenData( + access_token="valid_token", + expires_at=time.time() + 3600, + scheme="Bearer", + issued_at=time.time(), + ) - await helper.invalidate_token() + helper.invalidate_token() - assert helper._token is None - assert helper._token_expires_at is None + assert helper._token_data is None + def test_get_token_refreshes_after_invalidation(self) -> None: + """Test that get_token refreshes after invalidation.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) -@pytest.mark.asyncio -async def test_create_channel_with_credentials_success() -> None: - """Test successful channel creation with credential helper.""" - test_host = "test-host:50051" + # Set up a valid token, then invalidate it + helper._token_data = TokenData( + access_token="old_token", + expires_at=time.time() + 3600, + scheme="Bearer", + issued_at=time.time(), + ) + helper.invalidate_token() - mock_helper = mock.Mock(spec=CredentialHelper) - mock_helper.get_token.return_value = "test_token" + mock_response = mock.Mock() + mock_response.json.return_value = { + "access_token": "refreshed_token", + "expires_in": 3600, + "token_type": "bearer", + } + mock_response.raise_for_status.return_value = None - mock_credentials = mock.Mock() - mock_channel = mock.Mock(spec=Channel) + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value + mock_response_obj.post.return_value = mock_response - with ( - mock.patch("grpc.ssl_channel_credentials") as mock_ssl_creds, - mock.patch("grpc.access_token_call_credentials") as mock_token_creds, - mock.patch( - "grpc.composite_channel_credentials" - ) as mock_composite_creds, - mock.patch("grpc.aio.secure_channel") as mock_secure_channel, - ): - # Set up mocks - mock_ssl_creds.return_value = mock.Mock() - mock_token_creds.return_value = mock.Mock() - mock_composite_creds.return_value = mock_credentials - mock_secure_channel.return_value = mock_channel + token_data = helper.get_token() + + assert token_data.access_token == "refreshed_token" + + +class TestAutoRefreshTokenAuthMetadataPlugin: + """Tests for the per-RPC auth metadata plugin.""" + + def test_plugin_passes_bearer_token_to_callback(self) -> None: + """Plugin fetches token and passes Bearer metadata.""" + mock_helper = mock.Mock(spec=CredentialHelper) + mock_helper.get_token.return_value = TokenData( + access_token="test-bearer-token", + expires_at=time.time() + 3600, + scheme="Bearer", + issued_at=time.time(), + ) - # Create channel - channel = await create_channel_with_credentials(test_host, mock_helper) + plugin = _AutoRefreshTokenAuthMetadataPlugin(mock_helper) + mock_callback = mock.Mock() + mock_context = mock.Mock() + + plugin(mock_context, mock_callback) - # Verify channel creation - assert channel == mock_channel mock_helper.get_token.assert_called_once() - mock_token_creds.assert_called_once_with("test_token") + expected_metadata = (("authorization", "Bearer test-bearer-token"),) + mock_callback.assert_called_once_with(expected_metadata, None) + + def test_plugin_respects_token_scheme(self) -> None: + """Plugin uses the scheme from TokenData, not hardcoded Bearer.""" + mock_helper = mock.Mock(spec=CredentialHelper) + mock_helper.get_token.return_value = TokenData( + access_token="dpop-token", + expires_at=time.time() + 3600, + scheme="Dpop", + issued_at=time.time(), + ) + plugin = _AutoRefreshTokenAuthMetadataPlugin(mock_helper) + mock_callback = mock.Mock() + mock_context = mock.Mock() -@pytest.mark.asyncio -async def test_create_channel_with_credentials_invalid_host() -> None: - """Test channel creation with credentials and invalid host raises error.""" - test_host = "" # Invalid host + plugin(mock_context, mock_callback) - mock_helper = mock.Mock(spec=CredentialHelper) + expected_metadata = (("authorization", "Dpop dpop-token"),) + mock_callback.assert_called_once_with(expected_metadata, None) - with pytest.raises(InvalidHostError, match="host cannot be empty"): - _ = await create_channel_with_credentials(test_host, mock_helper) + def test_plugin_passes_oauth_error_to_callback(self) -> None: + """Test that OAuthError is forwarded to the callback as an error.""" + mock_helper = mock.Mock(spec=CredentialHelper) + oauth_error = OAuthError("token acquisition failed") + mock_helper.get_token.side_effect = oauth_error + plugin = _AutoRefreshTokenAuthMetadataPlugin(mock_helper) + mock_callback = mock.Mock() + mock_context = mock.Mock() -@pytest.mark.asyncio -async def test_create_channel_with_credentials_oauth_error() -> None: - """Test channel creation with credentials when OAuth fails.""" - test_host = "test-host:50051" + plugin(mock_context, mock_callback) - mock_helper = mock.Mock(spec=CredentialHelper) - mock_helper.get_token.side_effect = OAuthError("OAuth failed") + mock_callback.assert_called_once_with((), oauth_error) - with pytest.raises(OAuthError, match="OAuth failed"): - _ = await create_channel_with_credentials(test_host, mock_helper) + def test_plugin_catches_unexpected_exceptions(self) -> None: + """Non-OAuthError exceptions are forwarded to callback.""" + mock_helper = mock.Mock(spec=CredentialHelper) + runtime_error = RuntimeError("unexpected failure") + mock_helper.get_token.side_effect = runtime_error + + plugin = _AutoRefreshTokenAuthMetadataPlugin(mock_helper) + mock_callback = mock.Mock() + mock_context = mock.Mock() + + plugin(mock_context, mock_callback) + + mock_callback.assert_called_once_with((), runtime_error) + + +class TestBackgroundTokenRefresh: + """Tests for background token refresh functionality.""" + + def test_token_is_old_when_past_halfway_lifetime(self) -> None: + """Test that a token is considered old when past 25% of its lifetime.""" + current_time = time.time() + # Token with 1 hour lifetime, 20 minutes remaining (33%) + token = TokenData( + access_token="test_token", + expires_at=current_time + 600, # 10 minutes from now + scheme="Bearer", + issued_at=current_time - 3_000, # 50 minutes ago + ) + # Total lifetime = 3600s, remaining = 600s (1/6th), so it's old + assert token.is_old() + + def test_token_is_not_old_when_fresh(self) -> None: + """Test that a token is not old when more than 25% lifetime remains.""" + current_time = time.time() + # Token with 1 hour lifetime, 40 minutes remaining (67%) + token = TokenData( + access_token="test_token", + expires_at=current_time + 2400, # 40 minutes from now + scheme="Bearer", + issued_at=current_time - 1200, # 20 minutes ago + ) + # Total lifetime = 3600s, remaining = 2400s (67%), so it's fresh + assert not token.is_old() + + def test_get_token_triggers_background_refresh_for_old_token(self) -> None: + """Test that get_token triggers background refresh for old tokens.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + current_time = time.time() + # Set up an old but valid token + helper._token_data = TokenData( + access_token="old_token", + expires_at=current_time + 600, # 10 minutes from now + scheme="Bearer", + issued_at=current_time - 3_000, # 50 minutes ago + ) + + with mock.patch.object( + helper, "_start_background_refresh" + ) as mock_start: + token_data = helper.get_token() + + # Should return current token immediately + assert token_data.access_token == "old_token" + # Should have triggered background refresh + mock_start.assert_called_once() + + def test_get_token_does_not_trigger_refresh_for_fresh_token(self) -> None: + """Test that get_token does not trigger refresh for fresh tokens.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + current_time = time.time() + # Set up a fresh, valid token + helper._token_data = TokenData( + access_token="fresh_token", + expires_at=current_time + 2400, # 40 minutes remaining + scheme="Bearer", + issued_at=current_time - 1200, # 20 minutes ago, so it's fresh + ) + + with mock.patch.object( + helper, "_start_background_refresh" + ) as mock_start: + token_data = helper.get_token() + + # Should return current token + assert token_data.access_token == "fresh_token" + # Should NOT have triggered background refresh + mock_start.assert_not_called() + + def test_background_refresh_does_not_start_if_already_running(self) -> None: + """Test that background refresh doesn't start duplicate threads.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + # Mock a running refresh thread + mock_thread = mock.Mock() + mock_thread.is_alive.return_value = True + helper._refresh_thread = mock_thread + + with mock.patch("threading.Thread") as mock_thread_class: + helper._start_background_refresh() + + # Should not create a new thread + mock_thread_class.assert_not_called() + + def test_background_refresh_starts_new_thread_if_none_exists(self) -> None: + """Test that background refresh starts a thread when none exists.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + mock_thread = mock.Mock() + with mock.patch("threading.Thread", return_value=mock_thread): + helper._start_background_refresh() + + # Should have started the thread + mock_thread.start.assert_called_once() + + def test_background_refresh_silently_handles_errors(self) -> None: + """Test that background refresh silently ignores errors.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + # Mock refresh to raise an error + with mock.patch.object( + helper, "_refresh_token", side_effect=OAuthError("Test error") + ): + # Should not raise an exception + helper._background_refresh() + + def test_background_refresh_prevents_stampede(self) -> None: + """Test background refresh skips refresh if token is fresh.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + current_time = time.time() + # Set up a fresh token (already refreshed by another thread) + helper._token_data = TokenData( + access_token="fresh_token", + expires_at=current_time + 2400, # 40 minutes remaining + scheme="Bearer", + issued_at=current_time - 1200, # 20 minutes ago, so it's fresh + ) + + # Mock refresh to track if it's called + with mock.patch.object(helper, "_refresh_token") as mock_refresh: + helper._background_refresh() + + # Should NOT have called refresh since token is fresh + mock_refresh.assert_not_called() + + def test_get_token_blocks_for_expired_token(self) -> None: + """Test that get_token blocks and refreshes when token is expired.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + # Set up an expired token + helper._token_data = TokenData( + access_token="expired_token", + expires_at=time.time() - 100, # Expired + scheme="Bearer", + issued_at=time.time() - 3700, + ) + + mock_response = mock.Mock() + mock_response.json.return_value = { + "access_token": "new_token", + "expires_in": 3600, + "token_type": "bearer", + } + mock_response.raise_for_status.return_value = None + + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value + mock_response_obj.post.return_value = mock_response + + token_data = helper.get_token() + + # Should have refreshed and returned new token + assert token_data.access_token == "new_token" + # Should have called the OAuth endpoint + mock_response_obj.post.assert_called_once() + + def test_refresh_token_sets_issued_at(self) -> None: + """Test that _refresh_token sets the issued_at timestamp.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + mock_response = mock.Mock() + mock_response.json.return_value = { + "access_token": "new_token", + "expires_in": 3600, + "token_type": "bearer", + } + mock_response.raise_for_status.return_value = None + + before_time = time.time() + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value + mock_response_obj.post.return_value = mock_response + + _ = helper.get_token() + + after_time = time.time() + + # Check that issued_at was set to a reasonable value + assert helper._token_data is not None + assert before_time <= helper._token_data.issued_at <= after_time + # Check that expires_at is approximately issued_at + 3600 + # Allow 1 second tolerance for test execution time + assert ( + helper._token_data.expires_at - helper._token_data.issued_at - 3600 + < 1 + ) diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index 6401044..4d58a8c 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -71,22 +71,13 @@ async def credential_helper() -> CredentialHelper: audience = os.getenv("OAUTH_AUDIENCE", "crisp-athena-live") # Create credential helper - credential_helper = CredentialHelper( + return CredentialHelper( client_id=client_id, client_secret=client_secret, auth_url=auth_url, audience=audience, ) - # Test token acquisition - try: - _ = await credential_helper.get_token() - except Exception as e: - msg = "Failed to acquire OAuth token" - raise AssertionError(msg) from e - - return credential_helper - @pytest.fixture def athena_options() -> AthenaOptions: diff --git a/tests/functional/test_invalid_oauth.py b/tests/functional/test_invalid_oauth.py index 5e044e7..a2cfbbb 100644 --- a/tests/functional/test_invalid_oauth.py +++ b/tests/functional/test_invalid_oauth.py @@ -3,17 +3,12 @@ import pytest from dotenv import load_dotenv -from resolver_athena_client.client.athena_options import AthenaOptions -from resolver_athena_client.client.channel import ( - CredentialHelper, - create_channel_with_credentials, -) +from resolver_athena_client.client.channel import CredentialHelper from resolver_athena_client.client.exceptions import OAuthError -@pytest.mark.asyncio @pytest.mark.functional -async def test_invalid_secret(athena_options: AthenaOptions) -> None: +def test_invalid_secret() -> None: """Test that an invalid OAuth client secret is rejected.""" _ = load_dotenv() invalid_client_secret = "this_is_not_a_valid_secret" @@ -31,15 +26,12 @@ async def test_invalid_secret(athena_options: AthenaOptions) -> None: ) with pytest.raises(OAuthError): - _ = await create_channel_with_credentials( - athena_options.host, credential_helper=credential_helper - ) + _ = credential_helper.get_token() -@pytest.mark.asyncio @pytest.mark.functional -async def test_invalid_clientid(athena_options: AthenaOptions) -> None: - """Test that an invalid OAuth client secret is rejected.""" +def test_invalid_clientid() -> None: + """Test that an invalid OAuth client ID is rejected.""" _ = load_dotenv() client_secret = os.environ["OAUTH_CLIENT_SECRET"] client_id = "this_is_not_a_valid_client_id" @@ -56,14 +48,11 @@ async def test_invalid_clientid(athena_options: AthenaOptions) -> None: ) with pytest.raises(OAuthError): - _ = await create_channel_with_credentials( - athena_options.host, credential_helper=credential_helper - ) + _ = credential_helper.get_token() -@pytest.mark.asyncio @pytest.mark.functional -async def test_invalid_audience(athena_options: AthenaOptions) -> None: +def test_invalid_audience() -> None: """Test that an invalid OAuth audience is rejected.""" _ = load_dotenv() client_secret = os.environ["OAUTH_CLIENT_SECRET"] @@ -81,6 +70,4 @@ async def test_invalid_audience(athena_options: AthenaOptions) -> None: ) with pytest.raises(OAuthError): - _ = await create_channel_with_credentials( - athena_options.host, credential_helper=credential_helper - ) + _ = credential_helper.get_token()