From 878b19ed1ca6bc073ebacb54739d7927582bc58f Mon Sep 17 00:00:00 2001 From: Carl Patchett Date: Mon, 16 Feb 2026 12:06:23 +0000 Subject: [PATCH 01/12] refactor: simplify token acquisition by removing async calls and unused tests --- docs/api/exceptions.rst | 2 +- docs/authentication.rst | 12 +- examples/classify_single_example.py | 9 - examples/example.py | 9 - src/resolver_athena_client/client/channel.py | 150 ++++++++-------- tests/client/test_channel.py | 171 +++++++++---------- tests/functional/conftest.py | 11 +- 7 files changed, 171 insertions(+), 193 deletions(-) diff --git a/docs/api/exceptions.rst b/docs/api/exceptions.rst index 4c5720a..75419ee 100644 --- a/docs/api/exceptions.rst +++ b/docs/api/exceptions.rst @@ -129,7 +129,7 @@ OAuth Error Handling client_id=client_id, client_secret=client_secret ) - token = await credential_helper.get_token() + token = credential_helper.get_token() except OAuthError as e: logger.error(f"OAuth authentication failed: {e}") # Handle OAuth failure - check credentials diff --git a/docs/authentication.rst b/docs/authentication.rst index 066c587..49d147f 100644 --- a/docs/authentication.rst +++ b/docs/authentication.rst @@ -76,14 +76,6 @@ Set these environment variables for OAuth authentication: audience=os.getenv("OAUTH_AUDIENCE", "crisp-athena-live"), ) - # Test token acquisition - try: - token = await credential_helper.get_token() - print(f"Successfully acquired token (length: {len(token)})") - except Exception as e: - print(f"Failed to acquire OAuth token: {e}") - return - # Create authenticated channel channel = await create_channel_with_credentials( host=os.getenv("ATHENA_HOST"), @@ -261,7 +253,7 @@ Handle OAuth-specific errors gracefully: from resolver_athena_client.client.exceptions import AuthenticationError try: - token = await credential_helper.get_token() + token = credential_helper.get_token() except AuthenticationError as e: logger.error(f"OAuth authentication failed: {e}") # Handle authentication failure @@ -356,7 +348,7 @@ Test your authentication setup: client_secret=os.getenv("OAUTH_CLIENT_SECRET"), ) - token = await credential_helper.get_token() + token = credential_helper.get_token() print(f"✓ Authentication successful (token length: {len(token)})") return True diff --git a/examples/classify_single_example.py b/examples/classify_single_example.py index 1f1d305..484862a 100755 --- a/examples/classify_single_example.py +++ b/examples/classify_single_example.py @@ -213,15 +213,6 @@ async def main() -> int: audience=audience, ) - # Test token acquisition - try: - logger.info("Acquiring OAuth token...") - token = await credential_helper.get_token() - logger.info("Successfully acquired token (length: %d)", len(token)) - except Exception: - logger.exception("Failed to acquire OAuth token") - return 1 - # Configure client options options = AthenaOptions( host=host, diff --git a/examples/example.py b/examples/example.py index ce87722..a9a3e50 100755 --- a/examples/example.py +++ b/examples/example.py @@ -163,15 +163,6 @@ async def main() -> int: audience=audience, ) - # Test token acquisition - try: - logger.info("Acquiring OAuth token...") - token = await credential_helper.get_token() - logger.info("Successfully acquired token (length: %d)", len(token)) - except Exception: - logger.exception("Failed to acquire OAuth token") - return 1 - # Get available deployment channel = await create_channel_with_credentials(host, credential_helper) async with DeploymentSelector(channel) as deployment_selector: diff --git a/src/resolver_athena_client/client/channel.py b/src/resolver_athena_client/client/channel.py index e9d1510..e7ede64 100644 --- a/src/resolver_athena_client/client/channel.py +++ b/src/resolver_athena_client/client/channel.py @@ -1,7 +1,7 @@ """Channel creation utilities for the Athena client.""" -import asyncio import json +import threading import time from typing import override @@ -16,39 +16,6 @@ ) -class TokenMetadataPlugin(grpc.AuthMetadataPlugin): - """Plugin that adds authorization token to gRPC metadata.""" - - def __init__(self, token: str) -> None: - """Initialize the plugin with the auth token. - - Args: - ---- - token: The authorization token to add to requests - - """ - self._token: str = token - - @override - def __call__( - self, - _: grpc.AuthMetadataContext, - callback: grpc.AuthMetadataPluginCallback, - ) -> None: - """Pass authentication metadata to the provided callback. - - This method will be invoked asynchronously in a separate thread. - - Args: - ---- - callback: An AuthMetadataPluginCallback to be invoked either - synchronously or asynchronously. - - """ - metadata = (("authorization", f"Token {self._token}"),) - callback(metadata, None) - - class CredentialHelper: """OAuth credential helper for managing authentication tokens.""" @@ -82,13 +49,14 @@ def __init__( self._audience: str = audience self._token: str | None = None self._token_expires_at: float | None = None - self._lock: asyncio.Lock = asyncio.Lock() + self._lock: threading.Lock = threading.Lock() - async def get_token(self) -> str: + def get_token(self) -> str: """Get a valid authentication token. - This method will return a cached token if it's still valid, - or fetch a new token if needed. + Returns a cached token if still valid, or acquires a new one. + The entire check-and-refresh cycle runs under a single lock + acquisition to prevent TOCTOU races with ``invalidate_token``. Returns ------- @@ -97,21 +65,18 @@ async def get_token(self) -> str: Raises ------ OAuthError: If token acquisition fails - TokenExpiredError: If token has expired and refresh fails + RuntimeError: If token is unexpectedly None after refresh """ - async with self._lock: - if self._is_token_valid(): - if self._token is None: - msg = "Token should be valid but is None" - raise RuntimeError(msg) - return self._token - - await self._refresh_token() - if self._token is None: - msg = "Token refresh failed" + with self._lock: + if not self._is_token_valid(): + self._refresh_token() + + token = self._token + if token is None: + msg = "Token is unexpectedly None after validity check" raise RuntimeError(msg) - return self._token + return token def _is_token_valid(self) -> bool: """Check if the current token is valid and not expired. @@ -127,9 +92,12 @@ def _is_token_valid(self) -> bool: # Add 30 second buffer before expiration return time.time() < (self._token_expires_at - 30) - async def _refresh_token(self) -> None: + def _refresh_token(self) -> None: """Refresh the authentication token by making an OAuth request. + This is a synchronous call (suitable for the gRPC metadata-plugin + thread) and must be called while ``self._lock`` is held. + Raises ------ OAuthError: If the OAuth request fails @@ -145,8 +113,8 @@ async def _refresh_token(self) -> None: headers = {"content-type": "application/json"} try: - async with httpx.AsyncClient() as client: - response = await client.post( + with httpx.Client() as client: + response = client.post( self._auth_url, json=payload, headers=headers, @@ -154,12 +122,10 @@ async def _refresh_token(self) -> None: ) _ = response.raise_for_status() - token_data = response.json() - self._token = token_data["access_token"] - expires_in = token_data.get( - "expires_in", 3600 - ) # Default 1 hour - self._token_expires_at = time.time() + expires_in + token_data = response.json() + self._token = token_data["access_token"] + expires_in = token_data.get("expires_in", 3600) # Default 1 hour + self._token_expires_at = time.time() + expires_in except httpx.HTTPStatusError as e: error_detail = "" @@ -190,13 +156,59 @@ async def _refresh_token(self) -> None: msg = f"Unexpected error during OAuth: {e}" raise OAuthError(msg) from e - async def invalidate_token(self) -> None: + def invalidate_token(self) -> None: """Invalidate the current token to force a refresh on next use.""" - async with self._lock: + with self._lock: self._token = None self._token_expires_at = None +class _AutoRefreshTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin): + """gRPC auth plugin that fetches a fresh token for every RPC. + + The plugin delegates to ``CredentialHelper.get_token()`` which + handles caching, expiry checks, and thread-safe refresh internally. + This callback is invoked by gRPC on a *separate* thread, so the + underlying ``CredentialHelper`` must use ``threading.Lock`` (not + ``asyncio.Lock``). + """ + + def __init__(self, credential_helper: CredentialHelper) -> None: + """Initialize with a credential helper. + + Args: + ---- + credential_helper: The helper that manages token lifecycle + + """ + self._credential_helper: CredentialHelper = credential_helper + + @override + def __call__( + self, + _: grpc.AuthMetadataContext, + callback: grpc.AuthMetadataPluginCallback, + ) -> None: + """Supply authorization metadata for an RPC. + + Called by the gRPC runtime on a background thread before each + RPC. On success the token is forwarded as a Bearer token; on + failure the error is passed to the callback so gRPC can surface + it as an RPC error. + + Args: + ---- + callback: gRPC callback to receive metadata or an error + + """ + try: + token = self._credential_helper.get_token() + metadata = (("authorization", f"Bearer {token}"),) + callback(metadata, None) + except OAuthError as err: + callback(None, err) # pyright: ignore[reportArgumentType] + + async def create_channel_with_credentials( host: str, credential_helper: CredentialHelper, @@ -215,19 +227,23 @@ async def create_channel_with_credentials( Raises: ------ InvalidHostError: If host is empty - OAuthError: If OAuth authentication fails + + Note: + ---- + OAuth errors are no longer raised at channel-creation time. + Instead, they surface as RPC errors when the per-request auth + metadata plugin attempts to acquire a token. """ if not host: raise InvalidHostError(InvalidHostError.default_message) - # Get a valid token from the credential helper - token = await credential_helper.get_token() - - # Create credentials with token authentication + # Create credentials with per-RPC token refresh credentials = grpc.composite_channel_credentials( grpc.ssl_channel_credentials(), - grpc.access_token_call_credentials(token), + grpc.metadata_call_credentials( + _AutoRefreshTokenAuthMetadataPlugin(credential_helper) + ), ) # Configure gRPC options for persistent connections diff --git a/tests/client/test_channel.py b/tests/client/test_channel.py index 701bd00..fb00a4d 100644 --- a/tests/client/test_channel.py +++ b/tests/client/test_channel.py @@ -7,11 +7,10 @@ import httpx import pytest -from grpc.aio import Channel from resolver_athena_client.client.channel import ( CredentialHelper, - TokenMetadataPlugin, + _AutoRefreshTokenAuthMetadataPlugin, create_channel_with_credentials, ) from resolver_athena_client.client.exceptions import ( @@ -21,23 +20,6 @@ ) -def test_token_metadata_plugin() -> None: - """Test TokenMetadataPlugin functionality.""" - test_token = "test-token" - plugin = TokenMetadataPlugin(test_token) - - # Mock callback - mock_callback = mock.Mock() - mock_context = mock.Mock() - - # Call the plugin - plugin(mock_context, mock_callback) - - # Verify the callback was called with correct metadata - expected_metadata = (("authorization", f"Token {test_token}"),) - mock_callback.assert_called_once_with(expected_metadata, None) - - @pytest.mark.asyncio async def test_create_channel_with_credentials_validation() -> None: """Test channel creation with credentials validates input properly.""" @@ -50,16 +32,23 @@ async def test_create_channel_with_credentials_validation() -> None: @pytest.mark.asyncio -async def test_create_channel_with_credentials_oauth_failure() -> None: - """Test channel creation when OAuth token acquisition fails.""" +async def test_create_channel_does_not_eagerly_fetch_token() -> None: + """Channel creation must NOT call get_token() eagerly.""" test_host = "test-host:50051" mock_helper = mock.Mock(spec=CredentialHelper) - mock_helper.get_token.side_effect = OAuthError("Token acquisition failed") - with pytest.raises(OAuthError, match="Token acquisition failed"): + with ( + mock.patch("grpc.ssl_channel_credentials"), + mock.patch("grpc.metadata_call_credentials"), + mock.patch("grpc.composite_channel_credentials"), + mock.patch("grpc.aio.secure_channel"), + ): _ = await create_channel_with_credentials(test_host, mock_helper) + # Token should NOT be fetched at channel creation time + mock_helper.get_token.assert_not_called() + class TestCredentialHelper: """Test cases for CredentialHelper OAuth functionality.""" @@ -153,8 +142,7 @@ def test_is_token_valid_with_soon_expiring_token(self) -> None: assert not helper._is_token_valid() - @pytest.mark.asyncio - async def test_get_token_success(self) -> None: + def test_get_token_success(self) -> None: """Test successful token acquisition.""" helper = CredentialHelper( client_id="test_client_id", @@ -168,18 +156,17 @@ async def test_get_token_success(self) -> None: } mock_response.raise_for_status.return_value = None - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.return_value = mock_response - token = await helper.get_token() + token = helper.get_token() assert token == "new_access_token" assert helper._token == "new_access_token" assert helper._token_expires_at is not None - @pytest.mark.asyncio - async def test_get_token_cached(self) -> None: + def test_get_token_cached(self) -> None: """Test that cached token is returned when valid.""" helper = CredentialHelper( client_id="test_client_id", @@ -190,12 +177,11 @@ async def test_get_token_cached(self) -> None: helper._token = "cached_token" helper._token_expires_at = time.time() + 3600 - token = await helper.get_token() + token = helper.get_token() assert token == "cached_token" - @pytest.mark.asyncio - async def test_refresh_token_http_error(self) -> None: + def test_refresh_token_http_error(self) -> None: """Test token refresh with HTTP error.""" helper = CredentialHelper( client_id="test_client_id", @@ -215,17 +201,16 @@ async def test_refresh_token_http_error(self) -> None: response=mock_response, ) - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.side_effect = http_error with pytest.raises( OAuthError, match="OAuth request failed with status 401" ): - _ = await helper.get_token() + _ = helper.get_token() - @pytest.mark.asyncio - async def test_refresh_token_request_error(self) -> None: + def test_refresh_token_request_error(self) -> None: """Test token refresh with request error.""" helper = CredentialHelper( client_id="test_client_id", @@ -234,17 +219,16 @@ async def test_refresh_token_request_error(self) -> None: request_error = httpx.RequestError("Connection failed") - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.side_effect = request_error with pytest.raises( OAuthError, match="Failed to connect to OAuth server" ): - _ = await helper.get_token() + _ = helper.get_token() - @pytest.mark.asyncio - async def test_refresh_token_invalid_response(self) -> None: + def test_refresh_token_invalid_response(self) -> None: """Test token refresh with invalid response format.""" helper = CredentialHelper( client_id="test_client_id", @@ -257,17 +241,16 @@ async def test_refresh_token_invalid_response(self) -> None: } mock_response.raise_for_status.return_value = None - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.return_value = mock_response with pytest.raises( OAuthError, match="Invalid OAuth response format" ): - _ = await helper.get_token() + _ = helper.get_token() - @pytest.mark.asyncio - async def test_invalidate_token(self) -> None: + def test_invalidate_token(self) -> None: """Test token invalidation.""" helper = CredentialHelper( client_id="test_client_id", @@ -278,64 +261,78 @@ async def test_invalidate_token(self) -> None: helper._token = "valid_token" helper._token_expires_at = time.time() + 3600 - await helper.invalidate_token() + helper.invalidate_token() assert helper._token is None assert helper._token_expires_at is None + def test_get_token_refreshes_after_invalidation(self) -> None: + """Test that get_token refreshes after invalidation.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) -@pytest.mark.asyncio -async def test_create_channel_with_credentials_success() -> None: - """Test successful channel creation with credential helper.""" - test_host = "test-host:50051" + # Set up a valid token, then invalidate it + helper._token = "old_token" + helper._token_expires_at = time.time() + 3600 + helper.invalidate_token() - mock_helper = mock.Mock(spec=CredentialHelper) - mock_helper.get_token.return_value = "test_token" + mock_response = mock.Mock() + mock_response.json.return_value = { + "access_token": "refreshed_token", + "expires_in": 3600, + } + mock_response.raise_for_status.return_value = None - mock_credentials = mock.Mock() - mock_channel = mock.Mock(spec=Channel) + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value + mock_response_obj.post.return_value = mock_response + + token = helper.get_token() + + assert token == "refreshed_token" - with ( - mock.patch("grpc.ssl_channel_credentials") as mock_ssl_creds, - mock.patch("grpc.access_token_call_credentials") as mock_token_creds, - mock.patch( - "grpc.composite_channel_credentials" - ) as mock_composite_creds, - mock.patch("grpc.aio.secure_channel") as mock_secure_channel, - ): - # Set up mocks - mock_ssl_creds.return_value = mock.Mock() - mock_token_creds.return_value = mock.Mock() - mock_composite_creds.return_value = mock_credentials - mock_secure_channel.return_value = mock_channel - # Create channel - channel = await create_channel_with_credentials(test_host, mock_helper) +class TestAutoRefreshTokenAuthMetadataPlugin: + """Tests for the per-RPC auth metadata plugin.""" + + def test_plugin_passes_bearer_token_to_callback(self) -> None: + """Plugin fetches token and passes Bearer metadata.""" + mock_helper = mock.Mock(spec=CredentialHelper) + mock_helper.get_token.return_value = "test-bearer-token" + + plugin = _AutoRefreshTokenAuthMetadataPlugin(mock_helper) + mock_callback = mock.Mock() + mock_context = mock.Mock() + + plugin(mock_context, mock_callback) - # Verify channel creation - assert channel == mock_channel mock_helper.get_token.assert_called_once() - mock_token_creds.assert_called_once_with("test_token") + expected_metadata = (("authorization", "Bearer test-bearer-token"),) + mock_callback.assert_called_once_with(expected_metadata, None) + def test_plugin_passes_oauth_error_to_callback(self) -> None: + """Test that OAuthError is forwarded to the callback as an error.""" + mock_helper = mock.Mock(spec=CredentialHelper) + oauth_error = OAuthError("token acquisition failed") + mock_helper.get_token.side_effect = oauth_error -@pytest.mark.asyncio -async def test_create_channel_with_credentials_invalid_host() -> None: - """Test channel creation with credentials and invalid host raises error.""" - test_host = "" # Invalid host + plugin = _AutoRefreshTokenAuthMetadataPlugin(mock_helper) + mock_callback = mock.Mock() + mock_context = mock.Mock() - mock_helper = mock.Mock(spec=CredentialHelper) + plugin(mock_context, mock_callback) - with pytest.raises(InvalidHostError, match="host cannot be empty"): - _ = await create_channel_with_credentials(test_host, mock_helper) + mock_callback.assert_called_once_with(None, oauth_error) @pytest.mark.asyncio -async def test_create_channel_with_credentials_oauth_error() -> None: - """Test channel creation with credentials when OAuth fails.""" - test_host = "test-host:50051" +async def test_create_channel_with_credentials_invalid_host() -> None: + """Test channel creation with credentials and invalid host raises error.""" + test_host = "" # Invalid host mock_helper = mock.Mock(spec=CredentialHelper) - mock_helper.get_token.side_effect = OAuthError("OAuth failed") - with pytest.raises(OAuthError, match="OAuth failed"): + with pytest.raises(InvalidHostError, match="host cannot be empty"): _ = await create_channel_with_credentials(test_host, mock_helper) diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index 6401044..4d58a8c 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -71,22 +71,13 @@ async def credential_helper() -> CredentialHelper: audience = os.getenv("OAUTH_AUDIENCE", "crisp-athena-live") # Create credential helper - credential_helper = CredentialHelper( + return CredentialHelper( client_id=client_id, client_secret=client_secret, auth_url=auth_url, audience=audience, ) - # Test token acquisition - try: - _ = await credential_helper.get_token() - except Exception as e: - msg = "Failed to acquire OAuth token" - raise AssertionError(msg) from e - - return credential_helper - @pytest.fixture def athena_options() -> AthenaOptions: From 8ab55e1f944f2e04bc0a8ba52b1357ea528acc85 Mon Sep 17 00:00:00 2001 From: Carl Patchett Date: Tue, 17 Feb 2026 13:26:54 +0000 Subject: [PATCH 02/12] Introduce TokenData class --- src/resolver_athena_client/client/__init__.py | 2 + src/resolver_athena_client/client/channel.py | 92 ++++++---- tests/client/test_channel.py | 167 ++++++++++++++---- 3 files changed, 190 insertions(+), 71 deletions(-) diff --git a/src/resolver_athena_client/client/__init__.py b/src/resolver_athena_client/client/__init__.py index 94be685..4ad131c 100644 --- a/src/resolver_athena_client/client/__init__.py +++ b/src/resolver_athena_client/client/__init__.py @@ -5,6 +5,7 @@ from resolver_athena_client.client.channel import ( CredentialHelper, + TokenData, create_channel_with_credentials, ) from resolver_athena_client.client.exceptions import ( @@ -26,6 +27,7 @@ "CredentialError", "CredentialHelper", "OAuthError", + "TokenData", "TokenExpiredError", "create_channel_with_credentials", "get_output_error_summary", diff --git a/src/resolver_athena_client/client/channel.py b/src/resolver_athena_client/client/channel.py index e7ede64..9fe906d 100644 --- a/src/resolver_athena_client/client/channel.py +++ b/src/resolver_athena_client/client/channel.py @@ -3,7 +3,7 @@ import json import threading import time -from typing import override +from typing import NamedTuple, override import grpc import httpx @@ -16,6 +16,24 @@ ) +class TokenData(NamedTuple): + """Immutable snapshot of token state. + + Storing token, expiry, and scheme together as a single object + ensures that validity checks and token reads are always consistent, + eliminating TOCTOU races between ``get_token`` and + ``invalidate_token``. + """ + + access_token: str + expires_at: float + scheme: str + + def is_valid(self) -> bool: + """Check if this token is still valid (with a 30-second buffer).""" + return time.time() < (self.expires_at - 30) + + class CredentialHelper: """OAuth credential helper for managing authentication tokens.""" @@ -47,20 +65,19 @@ def __init__( self._client_secret: str = client_secret self._auth_url: str = auth_url self._audience: str = audience - self._token: str | None = None - self._token_expires_at: float | None = None + self._token_data: TokenData | None = None self._lock: threading.Lock = threading.Lock() - def get_token(self) -> str: - """Get a valid authentication token. + def get_token(self) -> TokenData: + """Get valid token data, refreshing if necessary. - Returns a cached token if still valid, or acquires a new one. - The entire check-and-refresh cycle runs under a single lock - acquisition to prevent TOCTOU races with ``invalidate_token``. + Uses double-checked locking: the happy path (token is valid) + avoids acquiring the lock entirely. The lock is only taken + when the token needs to be refreshed. Returns ------- - A valid authentication token + A valid ``TokenData`` containing access token, expiry, and scheme Raises ------ @@ -68,29 +85,22 @@ def get_token(self) -> str: RuntimeError: If token is unexpectedly None after refresh """ - with self._lock: - if not self._is_token_valid(): - self._refresh_token() - - token = self._token - if token is None: - msg = "Token is unexpectedly None after validity check" - raise RuntimeError(msg) - return token - - def _is_token_valid(self) -> bool: - """Check if the current token is valid and not expired. + token_data = self._token_data + if token_data is not None and token_data.is_valid(): + return token_data - Returns - ------- - True if token is valid, False otherwise + with self._lock: + token_data = self._token_data + if token_data is not None and token_data.is_valid(): + return token_data - """ - if not self._token or not self._token_expires_at: - return False + self._refresh_token() - # Add 30 second buffer before expiration - return time.time() < (self._token_expires_at - 30) + token_data = self._token_data + if token_data is None: + msg = "Token is unexpectedly None after refresh" + raise RuntimeError(msg) + return token_data def _refresh_token(self) -> None: """Refresh the authentication token by making an OAuth request. @@ -122,10 +132,15 @@ def _refresh_token(self) -> None: ) _ = response.raise_for_status() - token_data = response.json() - self._token = token_data["access_token"] - expires_in = token_data.get("expires_in", 3600) # Default 1 hour - self._token_expires_at = time.time() + expires_in + raw = response.json() + access_token: str = raw["access_token"] + expires_in: int = raw.get("expires_in", 3600) # Default 1 hour + scheme: str = raw.get("token_type", "Bearer").capitalize() + self._token_data = TokenData( + access_token=access_token, + expires_at=time.time() + expires_in, + scheme=scheme, + ) except httpx.HTTPStatusError as e: error_detail = "" @@ -159,8 +174,7 @@ def _refresh_token(self) -> None: def invalidate_token(self) -> None: """Invalidate the current token to force a refresh on next use.""" with self._lock: - self._token = None - self._token_expires_at = None + self._token_data = None class _AutoRefreshTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin): @@ -202,10 +216,12 @@ def __call__( """ try: - token = self._credential_helper.get_token() - metadata = (("authorization", f"Bearer {token}"),) + token_data = self._credential_helper.get_token() + scheme = token_data.scheme + token = token_data.access_token + metadata = (("authorization", f"{scheme} {token}"),) callback(metadata, None) - except OAuthError as err: + except Exception as err: # noqa: BLE001 callback(None, err) # pyright: ignore[reportArgumentType] diff --git a/tests/client/test_channel.py b/tests/client/test_channel.py index fb00a4d..aab33df 100644 --- a/tests/client/test_channel.py +++ b/tests/client/test_channel.py @@ -10,6 +10,7 @@ from resolver_athena_client.client.channel import ( CredentialHelper, + TokenData, _AutoRefreshTokenAuthMetadataPlugin, create_channel_with_credentials, ) @@ -64,8 +65,7 @@ def test_init_with_valid_params(self) -> None: assert helper._client_secret == "test_client_secret" assert helper._auth_url == "https://crispthinking.auth0.com/oauth/token" assert helper._audience == "crisp-athena-live" - assert helper._token is None - assert helper._token_expires_at is None + assert helper._token_data is None def test_init_with_custom_params(self) -> None: """Test CredentialHelper initialization with custom parameters.""" @@ -98,49 +98,58 @@ def test_init_with_empty_client_secret(self) -> None: ) def test_is_token_valid_with_no_token(self) -> None: - """Test _is_token_valid returns False when no token is set.""" + """Test token is not valid when no token data is set.""" helper = CredentialHelper( client_id="test_client_id", client_secret="test_client_secret", ) - assert not helper._is_token_valid() + assert helper._token_data is None def test_is_token_valid_with_expired_token(self) -> None: - """Test _is_token_valid returns False when token is expired.""" + """Test TokenData.is_valid returns False when token is expired.""" helper = CredentialHelper( client_id="test_client_id", client_secret="test_client_secret", ) - helper._token = "test_token" - helper._token_expires_at = time.time() - 100 # Expired + helper._token_data = TokenData( + access_token="test_token", + expires_at=time.time() - 100, + scheme="Bearer", + ) - assert not helper._is_token_valid() + assert not helper._token_data.is_valid() def test_is_token_valid_with_valid_token(self) -> None: - """Test _is_token_valid returns True when token is valid.""" + """Test TokenData.is_valid returns True when token is valid.""" helper = CredentialHelper( client_id="test_client_id", client_secret="test_client_secret", ) - helper._token = "test_token" - helper._token_expires_at = time.time() + 3600 # Valid for 1 hour + helper._token_data = TokenData( + access_token="test_token", + expires_at=time.time() + 3600, + scheme="Bearer", + ) - assert helper._is_token_valid() + assert helper._token_data.is_valid() def test_is_token_valid_with_soon_expiring_token(self) -> None: - """Test _is_token_valid returns False when token expires soon.""" + """Test is_valid returns False when token expires within 30s.""" helper = CredentialHelper( client_id="test_client_id", client_secret="test_client_secret", ) - helper._token = "test_token" - helper._token_expires_at = time.time() + 20 # Expires in 20 seconds + helper._token_data = TokenData( + access_token="test_token", + expires_at=time.time() + 20, + scheme="Bearer", + ) - assert not helper._is_token_valid() + assert not helper._token_data.is_valid() def test_get_token_success(self) -> None: """Test successful token acquisition.""" @@ -153,6 +162,7 @@ def test_get_token_success(self) -> None: mock_response.json.return_value = { "access_token": "new_access_token", "expires_in": 3600, + "token_type": "bearer", } mock_response.raise_for_status.return_value = None @@ -160,11 +170,57 @@ def test_get_token_success(self) -> None: mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.return_value = mock_response - token = helper.get_token() + token_data = helper.get_token() + + assert token_data.access_token == "new_access_token" + assert token_data.scheme == "Bearer" + assert helper._token_data is not None + assert helper._token_data.expires_at is not None + + def test_get_token_respects_token_type(self) -> None: + """Test that token_type from OAuth response is respected.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + mock_response = mock.Mock() + mock_response.json.return_value = { + "access_token": "some_token", + "expires_in": 3600, + "token_type": "DPoP", + } + mock_response.raise_for_status.return_value = None - assert token == "new_access_token" - assert helper._token == "new_access_token" - assert helper._token_expires_at is not None + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value + mock_response_obj.post.return_value = mock_response + + token_data = helper.get_token() + + assert token_data.scheme == "Dpop" + + def test_get_token_defaults_to_bearer(self) -> None: + """Test that scheme defaults to Bearer when token_type is absent.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + mock_response = mock.Mock() + mock_response.json.return_value = { + "access_token": "some_token", + "expires_in": 3600, + } + mock_response.raise_for_status.return_value = None + + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value + mock_response_obj.post.return_value = mock_response + + token_data = helper.get_token() + + assert token_data.scheme == "Bearer" def test_get_token_cached(self) -> None: """Test that cached token is returned when valid.""" @@ -174,12 +230,15 @@ def test_get_token_cached(self) -> None: ) # Set up a valid cached token - helper._token = "cached_token" - helper._token_expires_at = time.time() + 3600 + helper._token_data = TokenData( + access_token="cached_token", + expires_at=time.time() + 3600, + scheme="Bearer", + ) - token = helper.get_token() + token_data = helper.get_token() - assert token == "cached_token" + assert token_data.access_token == "cached_token" def test_refresh_token_http_error(self) -> None: """Test token refresh with HTTP error.""" @@ -258,13 +317,15 @@ def test_invalidate_token(self) -> None: ) # Set up a valid token - helper._token = "valid_token" - helper._token_expires_at = time.time() + 3600 + helper._token_data = TokenData( + access_token="valid_token", + expires_at=time.time() + 3600, + scheme="Bearer", + ) helper.invalidate_token() - assert helper._token is None - assert helper._token_expires_at is None + assert helper._token_data is None def test_get_token_refreshes_after_invalidation(self) -> None: """Test that get_token refreshes after invalidation.""" @@ -274,14 +335,18 @@ def test_get_token_refreshes_after_invalidation(self) -> None: ) # Set up a valid token, then invalidate it - helper._token = "old_token" - helper._token_expires_at = time.time() + 3600 + helper._token_data = TokenData( + access_token="old_token", + expires_at=time.time() + 3600, + scheme="Bearer", + ) helper.invalidate_token() mock_response = mock.Mock() mock_response.json.return_value = { "access_token": "refreshed_token", "expires_in": 3600, + "token_type": "bearer", } mock_response.raise_for_status.return_value = None @@ -289,9 +354,9 @@ def test_get_token_refreshes_after_invalidation(self) -> None: mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.return_value = mock_response - token = helper.get_token() + token_data = helper.get_token() - assert token == "refreshed_token" + assert token_data.access_token == "refreshed_token" class TestAutoRefreshTokenAuthMetadataPlugin: @@ -300,7 +365,11 @@ class TestAutoRefreshTokenAuthMetadataPlugin: def test_plugin_passes_bearer_token_to_callback(self) -> None: """Plugin fetches token and passes Bearer metadata.""" mock_helper = mock.Mock(spec=CredentialHelper) - mock_helper.get_token.return_value = "test-bearer-token" + mock_helper.get_token.return_value = TokenData( + access_token="test-bearer-token", + expires_at=time.time() + 3600, + scheme="Bearer", + ) plugin = _AutoRefreshTokenAuthMetadataPlugin(mock_helper) mock_callback = mock.Mock() @@ -312,6 +381,24 @@ def test_plugin_passes_bearer_token_to_callback(self) -> None: expected_metadata = (("authorization", "Bearer test-bearer-token"),) mock_callback.assert_called_once_with(expected_metadata, None) + def test_plugin_respects_token_scheme(self) -> None: + """Plugin uses the scheme from TokenData, not hardcoded Bearer.""" + mock_helper = mock.Mock(spec=CredentialHelper) + mock_helper.get_token.return_value = TokenData( + access_token="dpop-token", + expires_at=time.time() + 3600, + scheme="Dpop", + ) + + plugin = _AutoRefreshTokenAuthMetadataPlugin(mock_helper) + mock_callback = mock.Mock() + mock_context = mock.Mock() + + plugin(mock_context, mock_callback) + + expected_metadata = (("authorization", "Dpop dpop-token"),) + mock_callback.assert_called_once_with(expected_metadata, None) + def test_plugin_passes_oauth_error_to_callback(self) -> None: """Test that OAuthError is forwarded to the callback as an error.""" mock_helper = mock.Mock(spec=CredentialHelper) @@ -326,6 +413,20 @@ def test_plugin_passes_oauth_error_to_callback(self) -> None: mock_callback.assert_called_once_with(None, oauth_error) + def test_plugin_catches_unexpected_exceptions(self) -> None: + """Non-OAuthError exceptions are forwarded to callback.""" + mock_helper = mock.Mock(spec=CredentialHelper) + runtime_error = RuntimeError("unexpected failure") + mock_helper.get_token.side_effect = runtime_error + + plugin = _AutoRefreshTokenAuthMetadataPlugin(mock_helper) + mock_callback = mock.Mock() + mock_context = mock.Mock() + + plugin(mock_context, mock_callback) + + mock_callback.assert_called_once_with(None, runtime_error) + @pytest.mark.asyncio async def test_create_channel_with_credentials_invalid_host() -> None: From 9ab122c3bc02bbab3022861c1601032bc625001a Mon Sep 17 00:00:00 2001 From: Carl Patchett Date: Tue, 17 Feb 2026 14:54:03 +0000 Subject: [PATCH 03/12] Passing () to avoid type checking problems --- src/resolver_athena_client/client/channel.py | 2 +- tests/client/test_channel.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/resolver_athena_client/client/channel.py b/src/resolver_athena_client/client/channel.py index 9fe906d..c1183f9 100644 --- a/src/resolver_athena_client/client/channel.py +++ b/src/resolver_athena_client/client/channel.py @@ -222,7 +222,7 @@ def __call__( metadata = (("authorization", f"{scheme} {token}"),) callback(metadata, None) except Exception as err: # noqa: BLE001 - callback(None, err) # pyright: ignore[reportArgumentType] + callback((), err) async def create_channel_with_credentials( diff --git a/tests/client/test_channel.py b/tests/client/test_channel.py index aab33df..b5f64f4 100644 --- a/tests/client/test_channel.py +++ b/tests/client/test_channel.py @@ -411,7 +411,7 @@ def test_plugin_passes_oauth_error_to_callback(self) -> None: plugin(mock_context, mock_callback) - mock_callback.assert_called_once_with(None, oauth_error) + mock_callback.assert_called_once_with((), oauth_error) def test_plugin_catches_unexpected_exceptions(self) -> None: """Non-OAuthError exceptions are forwarded to callback.""" @@ -425,7 +425,7 @@ def test_plugin_catches_unexpected_exceptions(self) -> None: plugin(mock_context, mock_callback) - mock_callback.assert_called_once_with(None, runtime_error) + mock_callback.assert_called_once_with((), runtime_error) @pytest.mark.asyncio From 1b3e2a26d35b4ff1ab2543f11eb463a44094b608 Mon Sep 17 00:00:00 2001 From: Carl Patchett Date: Wed, 18 Feb 2026 15:27:52 +0000 Subject: [PATCH 04/12] Test fixups --- tests/functional/test_invalid_oauth.py | 29 +++++++------------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/tests/functional/test_invalid_oauth.py b/tests/functional/test_invalid_oauth.py index 5e044e7..bc61c4c 100644 --- a/tests/functional/test_invalid_oauth.py +++ b/tests/functional/test_invalid_oauth.py @@ -3,17 +3,12 @@ import pytest from dotenv import load_dotenv -from resolver_athena_client.client.athena_options import AthenaOptions -from resolver_athena_client.client.channel import ( - CredentialHelper, - create_channel_with_credentials, -) +from resolver_athena_client.client.channel import CredentialHelper from resolver_athena_client.client.exceptions import OAuthError -@pytest.mark.asyncio @pytest.mark.functional -async def test_invalid_secret(athena_options: AthenaOptions) -> None: +def test_invalid_secret() -> None: """Test that an invalid OAuth client secret is rejected.""" _ = load_dotenv() invalid_client_secret = "this_is_not_a_valid_secret" @@ -31,15 +26,12 @@ async def test_invalid_secret(athena_options: AthenaOptions) -> None: ) with pytest.raises(OAuthError): - _ = await create_channel_with_credentials( - athena_options.host, credential_helper=credential_helper - ) + credential_helper.get_token() -@pytest.mark.asyncio @pytest.mark.functional -async def test_invalid_clientid(athena_options: AthenaOptions) -> None: - """Test that an invalid OAuth client secret is rejected.""" +def test_invalid_clientid() -> None: + """Test that an invalid OAuth client ID is rejected.""" _ = load_dotenv() client_secret = os.environ["OAUTH_CLIENT_SECRET"] client_id = "this_is_not_a_valid_client_id" @@ -56,14 +48,11 @@ async def test_invalid_clientid(athena_options: AthenaOptions) -> None: ) with pytest.raises(OAuthError): - _ = await create_channel_with_credentials( - athena_options.host, credential_helper=credential_helper - ) + credential_helper.get_token() -@pytest.mark.asyncio @pytest.mark.functional -async def test_invalid_audience(athena_options: AthenaOptions) -> None: +def test_invalid_audience() -> None: """Test that an invalid OAuth audience is rejected.""" _ = load_dotenv() client_secret = os.environ["OAUTH_CLIENT_SECRET"] @@ -81,6 +70,4 @@ async def test_invalid_audience(athena_options: AthenaOptions) -> None: ) with pytest.raises(OAuthError): - _ = await create_channel_with_credentials( - athena_options.host, credential_helper=credential_helper - ) + credential_helper.get_token() From 60032d8e240b6a3eb73edba21df252ec92765cfc Mon Sep 17 00:00:00 2001 From: Carl Patchett Date: Wed, 18 Feb 2026 16:28:11 +0000 Subject: [PATCH 05/12] fix: update test cases to suppress unused return value warnings, now only scanning relevant files --- pyproject.toml | 3 ++- tests/functional/test_invalid_oauth.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7ba0386..10d618e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,8 @@ ignore = ["COM812", "D213", "D211", "D203", "S324", "ASYNC109"] "tests/**" = ["D", "S106", "S105", "S101", "SLF001"] [tool.pyright] -exclude = ["src/resolver_athena_client/generated/*", ".venv"] +include = ["src", "tests"] +exclude = ["src/resolver_athena_client/generated/*"] venvPath = "." venv = ".venv" stubPath = "stubs" diff --git a/tests/functional/test_invalid_oauth.py b/tests/functional/test_invalid_oauth.py index bc61c4c..a2cfbbb 100644 --- a/tests/functional/test_invalid_oauth.py +++ b/tests/functional/test_invalid_oauth.py @@ -26,7 +26,7 @@ def test_invalid_secret() -> None: ) with pytest.raises(OAuthError): - credential_helper.get_token() + _ = credential_helper.get_token() @pytest.mark.functional @@ -48,7 +48,7 @@ def test_invalid_clientid() -> None: ) with pytest.raises(OAuthError): - credential_helper.get_token() + _ = credential_helper.get_token() @pytest.mark.functional @@ -70,4 +70,4 @@ def test_invalid_audience() -> None: ) with pytest.raises(OAuthError): - credential_helper.get_token() + _ = credential_helper.get_token() From 4b7e73a6ae272ccab412047af37e1441b4653ddc Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Feb 2026 10:39:10 +0000 Subject: [PATCH 06/12] Address PR review feedback: preserve token scheme casing and fix documentation (#106) * Initial plan * fix: address PR review comments - token scheme, docs, and tests Co-authored-by: iwillspeak <1004401+iwillspeak@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: iwillspeak <1004401+iwillspeak@users.noreply.github.com> --- docs/api/exceptions.rst | 3 ++- docs/authentication.rst | 11 ++++++----- pyproject.toml | 2 +- src/resolver_athena_client/client/channel.py | 11 +++++++---- tests/client/test_channel.py | 11 ----------- 5 files changed, 16 insertions(+), 22 deletions(-) diff --git a/docs/api/exceptions.rst b/docs/api/exceptions.rst index 75419ee..8a4fc9f 100644 --- a/docs/api/exceptions.rst +++ b/docs/api/exceptions.rst @@ -129,7 +129,8 @@ OAuth Error Handling client_id=client_id, client_secret=client_secret ) - token = credential_helper.get_token() + token_data = credential_helper.get_token() + access_token = token_data.access_token except OAuthError as e: logger.error(f"OAuth authentication failed: {e}") # Handle OAuth failure - check credentials diff --git a/docs/authentication.rst b/docs/authentication.rst index 49d147f..0f1ab53 100644 --- a/docs/authentication.rst +++ b/docs/authentication.rst @@ -250,11 +250,12 @@ Handle OAuth-specific errors gracefully: .. code-block:: python - from resolver_athena_client.client.exceptions import AuthenticationError + from resolver_athena_client.client.exceptions import OAuthError try: - token = credential_helper.get_token() - except AuthenticationError as e: + token_data = credential_helper.get_token() + access_token = token_data.access_token + except OAuthError as e: logger.error(f"OAuth authentication failed: {e}") # Handle authentication failure except Exception as e: @@ -348,8 +349,8 @@ Test your authentication setup: client_secret=os.getenv("OAUTH_CLIENT_SECRET"), ) - token = credential_helper.get_token() - print(f"✓ Authentication successful (token length: {len(token)})") + token_data = credential_helper.get_token() + print(f"✓ Authentication successful (token length: {len(token_data.access_token)})") return True except Exception as e: diff --git a/pyproject.toml b/pyproject.toml index 10d618e..b1beb2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,7 @@ ignore = ["COM812", "D213", "D211", "D203", "S324", "ASYNC109"] "tests/**" = ["D", "S106", "S105", "S101", "SLF001"] [tool.pyright] -include = ["src", "tests"] +include = ["src", "tests", "examples", "common_utils"] exclude = ["src/resolver_athena_client/generated/*"] venvPath = "." venv = ".venv" diff --git a/src/resolver_athena_client/client/channel.py b/src/resolver_athena_client/client/channel.py index c1183f9..00131a0 100644 --- a/src/resolver_athena_client/client/channel.py +++ b/src/resolver_athena_client/client/channel.py @@ -135,7 +135,9 @@ def _refresh_token(self) -> None: raw = response.json() access_token: str = raw["access_token"] expires_in: int = raw.get("expires_in", 3600) # Default 1 hour - scheme: str = raw.get("token_type", "Bearer").capitalize() + token_type = raw.get("token_type", "Bearer") + # Preserve server-provided casing, only strip whitespace + scheme: str = token_type.strip() if token_type else "Bearer" self._token_data = TokenData( access_token=access_token, expires_at=time.time() + expires_in, @@ -206,9 +208,10 @@ def __call__( """Supply authorization metadata for an RPC. Called by the gRPC runtime on a background thread before each - RPC. On success the token is forwarded as a Bearer token; on - failure the error is passed to the callback so gRPC can surface - it as an RPC error. + RPC. On success the token is forwarded using the scheme from + the OAuth token response (typically ``Bearer``); on failure + the error is passed to the callback so gRPC can surface it as + an RPC error. Args: ---- diff --git a/tests/client/test_channel.py b/tests/client/test_channel.py index b5f64f4..221907d 100644 --- a/tests/client/test_channel.py +++ b/tests/client/test_channel.py @@ -426,14 +426,3 @@ def test_plugin_catches_unexpected_exceptions(self) -> None: plugin(mock_context, mock_callback) mock_callback.assert_called_once_with((), runtime_error) - - -@pytest.mark.asyncio -async def test_create_channel_with_credentials_invalid_host() -> None: - """Test channel creation with credentials and invalid host raises error.""" - test_host = "" # Invalid host - - mock_helper = mock.Mock(spec=CredentialHelper) - - with pytest.raises(InvalidHostError, match="host cannot be empty"): - _ = await create_channel_with_credentials(test_host, mock_helper) From a7a7286ceaa39aaabcf593dc1bcd160409d2da0f Mon Sep 17 00:00:00 2001 From: anna-singleton-resolver Date: Fri, 20 Feb 2026 14:24:47 +0000 Subject: [PATCH 07/12] chore: remove 'include' and revert removal of .venv from exclude list for pyright --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b1beb2a..7ba0386 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,8 +71,7 @@ ignore = ["COM812", "D213", "D211", "D203", "S324", "ASYNC109"] "tests/**" = ["D", "S106", "S105", "S101", "SLF001"] [tool.pyright] -include = ["src", "tests", "examples", "common_utils"] -exclude = ["src/resolver_athena_client/generated/*"] +exclude = ["src/resolver_athena_client/generated/*", ".venv"] venvPath = "." venv = ".venv" stubPath = "stubs" From 52e8b009227096b43a6abc75967577386463e6d1 Mon Sep 17 00:00:00 2001 From: anna-singleton-resolver Date: Fri, 20 Feb 2026 14:35:28 +0000 Subject: [PATCH 08/12] chore: remove some extraneous comments --- src/resolver_athena_client/client/channel.py | 24 ++------------------ 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/src/resolver_athena_client/client/channel.py b/src/resolver_athena_client/client/channel.py index 00131a0..56bd345 100644 --- a/src/resolver_athena_client/client/channel.py +++ b/src/resolver_athena_client/client/channel.py @@ -17,13 +17,7 @@ class TokenData(NamedTuple): - """Immutable snapshot of token state. - - Storing token, expiry, and scheme together as a single object - ensures that validity checks and token reads are always consistent, - eliminating TOCTOU races between ``get_token`` and - ``invalidate_token``. - """ + """Immutable snapshot of token state.""" access_token: str expires_at: float @@ -136,7 +130,6 @@ def _refresh_token(self) -> None: access_token: str = raw["access_token"] expires_in: int = raw.get("expires_in", 3600) # Default 1 hour token_type = raw.get("token_type", "Bearer") - # Preserve server-provided casing, only strip whitespace scheme: str = token_type.strip() if token_type else "Bearer" self._token_data = TokenData( access_token=access_token, @@ -180,14 +173,7 @@ def invalidate_token(self) -> None: class _AutoRefreshTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin): - """gRPC auth plugin that fetches a fresh token for every RPC. - - The plugin delegates to ``CredentialHelper.get_token()`` which - handles caching, expiry checks, and thread-safe refresh internally. - This callback is invoked by gRPC on a *separate* thread, so the - underlying ``CredentialHelper`` must use ``threading.Lock`` (not - ``asyncio.Lock``). - """ + """gRPC auth plugin that fetches a fresh token for every RPC.""" def __init__(self, credential_helper: CredentialHelper) -> None: """Initialize with a credential helper. @@ -247,12 +233,6 @@ async def create_channel_with_credentials( ------ InvalidHostError: If host is empty - Note: - ---- - OAuth errors are no longer raised at channel-creation time. - Instead, they surface as RPC errors when the per-request auth - metadata plugin attempts to acquire a token. - """ if not host: raise InvalidHostError(InvalidHostError.default_message) From 0c7c3b99d2e4e063b739e639becb8ccd64629d05 Mon Sep 17 00:00:00 2001 From: anna-singleton-resolver Date: Fri, 20 Feb 2026 14:40:33 +0000 Subject: [PATCH 09/12] test: consistent casing in assertions and provided mock data --- tests/client/test_channel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/client/test_channel.py b/tests/client/test_channel.py index 221907d..4d17ce5 100644 --- a/tests/client/test_channel.py +++ b/tests/client/test_channel.py @@ -162,7 +162,7 @@ def test_get_token_success(self) -> None: mock_response.json.return_value = { "access_token": "new_access_token", "expires_in": 3600, - "token_type": "bearer", + "token_type": "Bearer", } mock_response.raise_for_status.return_value = None @@ -198,7 +198,7 @@ def test_get_token_respects_token_type(self) -> None: token_data = helper.get_token() - assert token_data.scheme == "Dpop" + assert token_data.scheme == "DPoP" def test_get_token_defaults_to_bearer(self) -> None: """Test that scheme defaults to Bearer when token_type is absent.""" From f27f6a4ac05c37c826f54905991c88cd2b041a2c Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:44:48 +0000 Subject: [PATCH 10/12] Address review feedback: require issued_at field and prevent refresh stampede (#107) * Initial plan * Implement background token refresh for old tokens - Add issued_at field to TokenData to track token creation time - Add is_old() method to check if token should be refreshed (< 50% lifetime) - Trigger non-blocking background refresh when token is old but valid - Keep blocking refresh for expired tokens - Add comprehensive tests for background refresh behavior Co-authored-by: iwillspeak <1004401+iwillspeak@users.noreply.github.com> * fix: update tests to match preserved token scheme casing Tests were expecting capitalized schemes but implementation was changed in 4b7e73a to preserve server-provided casing. Co-authored-by: iwillspeak <1004401+iwillspeak@users.noreply.github.com> * Address code review feedback - Add logging for background refresh failures - Improve comments explaining magic numbers - Add test for server casing preservation - Add comment explaining test tolerance Co-authored-by: iwillspeak <1004401+iwillspeak@users.noreply.github.com> * Address PR review comments - Make issued_at a required field (remove default value) - Remove fallback logic for legacy tokens in is_old() - Add stampede prevention in _background_refresh - Remove obsolete tests for fallback logic - Add test for stampede prevention Co-authored-by: iwillspeak <1004401+iwillspeak@users.noreply.github.com> * Revert Bearer/DPoP test assertion changes from c50327d The target branch now properly fixes the test by providing "Bearer" in the mock data. Reverting my previous fix that changed the assertion to match the old mock data. Co-authored-by: anna-singleton-resolver <199753965+anna-singleton-resolver@users.noreply.github.com> * fix: start refresh thread checks for already refreshed token * chore: remove extraneous comment * feat: reduce eagerness of background refresh from at 50% remaining to 25% remaining * doc: update is_old docstring * fix: correct boolean logic in _start_background_refresh Changed from OR to AND logic: refresh should start if (refresh_not_active AND token_needs_refresh), not if any condition is true. This prevents unnecessary refresh attempts. Co-authored-by: iwillspeak <1004401+iwillspeak@users.noreply.github.com> * test: fix test assertions to be on 25% instead of 50% * feat: allow configurable proactive_refresh_threshold --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: iwillspeak <1004401+iwillspeak@users.noreply.github.com> Co-authored-by: anna-singleton-resolver <199753965+anna-singleton-resolver@users.noreply.github.com> Co-authored-by: anna-singleton-resolver --- src/resolver_athena_client/client/channel.py | 115 +++++++- tests/client/test_channel.py | 282 +++++++++++++++++++ 2 files changed, 392 insertions(+), 5 deletions(-) diff --git a/src/resolver_athena_client/client/channel.py b/src/resolver_athena_client/client/channel.py index 56bd345..9f51243 100644 --- a/src/resolver_athena_client/client/channel.py +++ b/src/resolver_athena_client/client/channel.py @@ -1,6 +1,7 @@ """Channel creation utilities for the Athena client.""" import json +import logging import threading import time from typing import NamedTuple, override @@ -15,6 +16,8 @@ OAuthError, ) +logger = logging.getLogger(__name__) + class TokenData(NamedTuple): """Immutable snapshot of token state.""" @@ -22,11 +25,34 @@ class TokenData(NamedTuple): access_token: str expires_at: float scheme: str + issued_at: float def is_valid(self) -> bool: """Check if this token is still valid (with a 30-second buffer).""" return time.time() < (self.expires_at - 30) + def is_old(self, proactive_refresh_threshold: float) -> bool: + """Check if this token should be proactively refreshed. + + A token is considered "old" if less than the + proactive_refresh_threshold of its lifetime remains. This allows + background refresh to happen before expiry while the token is still + usable. + + Args: + ---- + proactive_refresh_threshold: Fraction of token lifetime past which + to trigger proactive refresh (e.g. 0.25 for 25%) + + """ + if proactive_refresh_threshold <= 0 or proactive_refresh_threshold >= 1: + msg = "proactive_refresh_threshold must be between 0 and 1" + raise ValueError(msg) + current_time = time.time() + total_lifetime = self.expires_at - self.issued_at + time_remaining = self.expires_at - current_time + return time_remaining < (total_lifetime * proactive_refresh_threshold) + class CredentialHelper: """OAuth credential helper for managing authentication tokens.""" @@ -37,6 +63,7 @@ def __init__( client_secret: str, auth_url: str = "https://crispthinking.auth0.com/oauth/token", audience: str = "crisp-athena-live", + proactive_refresh_threshold: float = 0.25, ) -> None: """Initialize the credential helper. @@ -46,6 +73,8 @@ def __init__( client_secret: OAuth client secret auth_url: OAuth token endpoint URL audience: OAuth audience + proactive_refresh_threshold: Fraction of token lifetime to trigger + proactive refresh (default 0.25 for 25%) """ if not client_id: @@ -61,14 +90,17 @@ def __init__( self._audience: str = audience self._token_data: TokenData | None = None self._lock: threading.Lock = threading.Lock() + self._refresh_thread: threading.Thread | None = None + + if proactive_refresh_threshold <= 0 or proactive_refresh_threshold >= 1: + msg = "proactive_refresh_threshold must be a float between 0 and 1" + raise ValueError(msg) + + self._proactive_refresh_threshold: float = proactive_refresh_threshold def get_token(self) -> TokenData: """Get valid token data, refreshing if necessary. - Uses double-checked locking: the happy path (token is valid) - avoids acquiring the lock entirely. The lock is only taken - when the token needs to be refreshed. - Returns ------- A valid ``TokenData`` containing access token, expiry, and scheme @@ -80,9 +112,15 @@ def get_token(self) -> TokenData: """ token_data = self._token_data + + # Fast path: token is valid and fresh if token_data is not None and token_data.is_valid(): + # If token is old, trigger background refresh + if token_data.is_old(self._proactive_refresh_threshold): + self._start_background_refresh() return token_data + # Slow path: token is expired or missing, must block with self._lock: token_data = self._token_data if token_data is not None and token_data.is_valid(): @@ -96,6 +134,71 @@ def get_token(self) -> TokenData: raise RuntimeError(msg) return token_data + def _start_background_refresh(self) -> None: + """Start a background thread to refresh the token. + + Only starts a new thread if one isn't already running. + + This method is safe to call multiple times - it only starts a new + thread if no refresh is currently in progress. + """ + # Quick check without lock - if refresh thread exists and is + # alive, skip + if self._refresh_thread is not None and self._refresh_thread.is_alive(): + return + + # Try to acquire lock and start refresh + if self._lock.acquire(blocking=False): + try: + # Double-check: another thread might have started refresh, + # or the token may have been refreshed. + refresh_not_active = ( + self._refresh_thread is None + or not self._refresh_thread.is_alive() + ) + token_needs_refresh = ( + self._token_data is None + or self._token_data.is_old( + self._proactive_refresh_threshold + ) + ) + refresh_needed = refresh_not_active and token_needs_refresh + if refresh_needed: + self._refresh_thread = threading.Thread( + target=self._background_refresh, + daemon=True, + ) + self._refresh_thread.start() + finally: + self._lock.release() + + def _background_refresh(self) -> None: + """Background thread target for token refresh. + + Acquires the lock and refreshes the token. Errors are logged + but silently ignored since the next foreground request will + retry if needed. + """ + with self._lock: + # Check if token still needs refresh (prevent stampede) + token_data = self._token_data + if token_data is not None and not token_data.is_old( + self._proactive_refresh_threshold + ): + # Token was already refreshed by another thread + return + + try: + self._refresh_token() + except Exception as e: # noqa: BLE001 + # Log but don't raise - background refresh failures + # are recoverable (next get_token() will retry) + logger.debug( + "Background token refresh failed, " + "will retry on next request: %s", + e, + ) + def _refresh_token(self) -> None: """Refresh the authentication token by making an OAuth request. @@ -131,10 +234,12 @@ def _refresh_token(self) -> None: expires_in: int = raw.get("expires_in", 3600) # Default 1 hour token_type = raw.get("token_type", "Bearer") scheme: str = token_type.strip() if token_type else "Bearer" + current_time = time.time() self._token_data = TokenData( access_token=access_token, - expires_at=time.time() + expires_in, + expires_at=current_time + expires_in, scheme=scheme, + issued_at=current_time, ) except httpx.HTTPStatusError as e: diff --git a/tests/client/test_channel.py b/tests/client/test_channel.py index 4d17ce5..33f8d85 100644 --- a/tests/client/test_channel.py +++ b/tests/client/test_channel.py @@ -97,6 +97,23 @@ def test_init_with_empty_client_secret(self) -> None: client_secret="", ) + @pytest.mark.parametrize( + "invalid", + [-0.1, 1.1, -0.5, 2.0], + ) + def test_init_with_invalid_proactive_refresh_threshold( + self, invalid: float + ) -> None: + with pytest.raises( + ValueError, + match="proactive_refresh_threshold must be a float between 0 and 1", + ): + _ = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + proactive_refresh_threshold=invalid, + ) + def test_is_token_valid_with_no_token(self) -> None: """Test token is not valid when no token data is set.""" helper = CredentialHelper( @@ -117,6 +134,7 @@ def test_is_token_valid_with_expired_token(self) -> None: access_token="test_token", expires_at=time.time() - 100, scheme="Bearer", + issued_at=time.time() - 3700, ) assert not helper._token_data.is_valid() @@ -132,6 +150,7 @@ def test_is_token_valid_with_valid_token(self) -> None: access_token="test_token", expires_at=time.time() + 3600, scheme="Bearer", + issued_at=time.time(), ) assert helper._token_data.is_valid() @@ -147,6 +166,7 @@ def test_is_token_valid_with_soon_expiring_token(self) -> None: access_token="test_token", expires_at=time.time() + 20, scheme="Bearer", + issued_at=time.time() - 3580, ) assert not helper._token_data.is_valid() @@ -222,6 +242,44 @@ def test_get_token_defaults_to_bearer(self) -> None: assert token_data.scheme == "Bearer" + def test_get_token_preserves_server_casing(self) -> None: + """Test that server-provided token_type casing is preserved.""" + test_cases = [ + ("Bearer", "Bearer"), + ("bearer", "bearer"), + ("BEARER", "BEARER"), + ("DPoP", "DPoP"), + ("dpop", "dpop"), + (" Bearer ", "Bearer"), # Whitespace is stripped + ] + + for server_type, expected_scheme in test_cases: + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + mock_response = mock.Mock() + mock_response.json.return_value = { + "access_token": "test_token", + "expires_in": 3600, + "token_type": server_type, + } + mock_response.raise_for_status.return_value = None + + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = ( + mock_client.return_value.__enter__.return_value + ) + mock_response_obj.post.return_value = mock_response + + token_data = helper.get_token() + + assert token_data.scheme == expected_scheme, ( + f"Expected {expected_scheme} for {server_type}, " + f"got {token_data.scheme}" + ) + def test_get_token_cached(self) -> None: """Test that cached token is returned when valid.""" helper = CredentialHelper( @@ -234,6 +292,7 @@ def test_get_token_cached(self) -> None: access_token="cached_token", expires_at=time.time() + 3600, scheme="Bearer", + issued_at=time.time(), ) token_data = helper.get_token() @@ -321,6 +380,7 @@ def test_invalidate_token(self) -> None: access_token="valid_token", expires_at=time.time() + 3600, scheme="Bearer", + issued_at=time.time(), ) helper.invalidate_token() @@ -339,6 +399,7 @@ def test_get_token_refreshes_after_invalidation(self) -> None: access_token="old_token", expires_at=time.time() + 3600, scheme="Bearer", + issued_at=time.time(), ) helper.invalidate_token() @@ -369,6 +430,7 @@ def test_plugin_passes_bearer_token_to_callback(self) -> None: access_token="test-bearer-token", expires_at=time.time() + 3600, scheme="Bearer", + issued_at=time.time(), ) plugin = _AutoRefreshTokenAuthMetadataPlugin(mock_helper) @@ -388,6 +450,7 @@ def test_plugin_respects_token_scheme(self) -> None: access_token="dpop-token", expires_at=time.time() + 3600, scheme="Dpop", + issued_at=time.time(), ) plugin = _AutoRefreshTokenAuthMetadataPlugin(mock_helper) @@ -426,3 +489,222 @@ def test_plugin_catches_unexpected_exceptions(self) -> None: plugin(mock_context, mock_callback) mock_callback.assert_called_once_with((), runtime_error) + + +class TestBackgroundTokenRefresh: + """Tests for background token refresh functionality.""" + + def test_token_is_old_when_past_halfway_lifetime(self) -> None: + """Test that a token is considered old when past 25% of its lifetime.""" + current_time = time.time() + # Token with 1 hour lifetime, 20 minutes remaining (33%) + token = TokenData( + access_token="test_token", + expires_at=current_time + 600, # 10 minutes from now + scheme="Bearer", + issued_at=current_time - 3_000, # 50 minutes ago + ) + # Total lifetime = 3600s, remaining = 600s (1/6th), so it's old + assert token.is_old(0.25) + + def test_token_is_not_old_when_fresh(self) -> None: + """Test that a token is not old when more than 25% lifetime remains.""" + current_time = time.time() + # Token with 1 hour lifetime, 40 minutes remaining (67%) + token = TokenData( + access_token="test_token", + expires_at=current_time + 2400, # 40 minutes from now + scheme="Bearer", + issued_at=current_time - 1200, # 20 minutes ago + ) + # Total lifetime = 3600s, remaining = 2400s (67%), so it's fresh + assert not token.is_old(0.25) + + def test_get_token_triggers_background_refresh_for_old_token(self) -> None: + """Test that get_token triggers background refresh for old tokens.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + current_time = time.time() + # Set up an old but valid token + helper._token_data = TokenData( + access_token="old_token", + expires_at=current_time + 600, # 10 minutes from now + scheme="Bearer", + issued_at=current_time - 3_000, # 50 minutes ago + ) + + with mock.patch.object( + helper, "_start_background_refresh" + ) as mock_start: + token_data = helper.get_token() + + # Should return current token immediately + assert token_data.access_token == "old_token" + # Should have triggered background refresh + mock_start.assert_called_once() + + def test_get_token_does_not_trigger_refresh_for_fresh_token(self) -> None: + """Test that get_token does not trigger refresh for fresh tokens.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + current_time = time.time() + # Set up a fresh, valid token + helper._token_data = TokenData( + access_token="fresh_token", + expires_at=current_time + 2400, # 40 minutes remaining + scheme="Bearer", + issued_at=current_time - 1200, # 20 minutes ago, so it's fresh + ) + + with mock.patch.object( + helper, "_start_background_refresh" + ) as mock_start: + token_data = helper.get_token() + + # Should return current token + assert token_data.access_token == "fresh_token" + # Should NOT have triggered background refresh + mock_start.assert_not_called() + + def test_background_refresh_does_not_start_if_already_running(self) -> None: + """Test that background refresh doesn't start duplicate threads.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + # Mock a running refresh thread + mock_thread = mock.Mock() + mock_thread.is_alive.return_value = True + helper._refresh_thread = mock_thread + + with mock.patch("threading.Thread") as mock_thread_class: + helper._start_background_refresh() + + # Should not create a new thread + mock_thread_class.assert_not_called() + + def test_background_refresh_starts_new_thread_if_none_exists(self) -> None: + """Test that background refresh starts a thread when none exists.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + mock_thread = mock.Mock() + with mock.patch("threading.Thread", return_value=mock_thread): + helper._start_background_refresh() + + # Should have started the thread + mock_thread.start.assert_called_once() + + def test_background_refresh_silently_handles_errors(self) -> None: + """Test that background refresh silently ignores errors.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + # Mock refresh to raise an error + with mock.patch.object( + helper, "_refresh_token", side_effect=OAuthError("Test error") + ): + # Should not raise an exception + helper._background_refresh() + + def test_background_refresh_prevents_stampede(self) -> None: + """Test background refresh skips refresh if token is fresh.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + current_time = time.time() + # Set up a fresh token (already refreshed by another thread) + helper._token_data = TokenData( + access_token="fresh_token", + expires_at=current_time + 2400, # 40 minutes remaining + scheme="Bearer", + issued_at=current_time - 1200, # 20 minutes ago, so it's fresh + ) + + # Mock refresh to track if it's called + with mock.patch.object(helper, "_refresh_token") as mock_refresh: + helper._background_refresh() + + # Should NOT have called refresh since token is fresh + mock_refresh.assert_not_called() + + def test_get_token_blocks_for_expired_token(self) -> None: + """Test that get_token blocks and refreshes when token is expired.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + # Set up an expired token + helper._token_data = TokenData( + access_token="expired_token", + expires_at=time.time() - 100, # Expired + scheme="Bearer", + issued_at=time.time() - 3700, + ) + + mock_response = mock.Mock() + mock_response.json.return_value = { + "access_token": "new_token", + "expires_in": 3600, + "token_type": "bearer", + } + mock_response.raise_for_status.return_value = None + + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value + mock_response_obj.post.return_value = mock_response + + token_data = helper.get_token() + + # Should have refreshed and returned new token + assert token_data.access_token == "new_token" + # Should have called the OAuth endpoint + mock_response_obj.post.assert_called_once() + + def test_refresh_token_sets_issued_at(self) -> None: + """Test that _refresh_token sets the issued_at timestamp.""" + helper = CredentialHelper( + client_id="test_client_id", + client_secret="test_client_secret", + ) + + mock_response = mock.Mock() + mock_response.json.return_value = { + "access_token": "new_token", + "expires_in": 3600, + "token_type": "bearer", + } + mock_response.raise_for_status.return_value = None + + before_time = time.time() + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value + mock_response_obj.post.return_value = mock_response + + _ = helper.get_token() + + after_time = time.time() + + # Check that issued_at was set to a reasonable value + assert helper._token_data is not None + assert before_time <= helper._token_data.issued_at <= after_time + # Check that expires_at is approximately issued_at + 3600 + # Allow 1 second tolerance for test execution time + assert ( + helper._token_data.expires_at - helper._token_data.issued_at - 3600 + < 1 + ) From 05597ca9809fb1ac30e7ea6aaef80c2d765b149d Mon Sep 17 00:00:00 2001 From: anna-singleton-resolver Date: Mon, 23 Feb 2026 11:38:22 +0000 Subject: [PATCH 11/12] refactor: TokenData as frozen dataclass and param validity checks at init --- src/resolver_athena_client/client/channel.py | 37 ++++++++++++-------- tests/client/test_channel.py | 5 +-- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/resolver_athena_client/client/channel.py b/src/resolver_athena_client/client/channel.py index 9f51243..367bf78 100644 --- a/src/resolver_athena_client/client/channel.py +++ b/src/resolver_athena_client/client/channel.py @@ -4,7 +4,8 @@ import logging import threading import time -from typing import NamedTuple, override +from dataclasses import dataclass +from typing import override import grpc import httpx @@ -19,19 +20,30 @@ logger = logging.getLogger(__name__) -class TokenData(NamedTuple): +@dataclass(frozen=True) +class TokenData: """Immutable snapshot of token state.""" access_token: str expires_at: float scheme: str issued_at: float + proactive_refresh_threshold: float = 0.25 + + def __post_init__(self) -> None: + """Validate that proactive_refresh_threshold is between 0 and 1.""" + if ( + self.proactive_refresh_threshold <= 0 + or self.proactive_refresh_threshold >= 1 + ): + msg = "proactive_refresh_threshold must be between 0 and 1" + raise ValueError(msg) def is_valid(self) -> bool: """Check if this token is still valid (with a 30-second buffer).""" return time.time() < (self.expires_at - 30) - def is_old(self, proactive_refresh_threshold: float) -> bool: + def is_old(self) -> bool: """Check if this token should be proactively refreshed. A token is considered "old" if less than the @@ -45,13 +57,12 @@ def is_old(self, proactive_refresh_threshold: float) -> bool: to trigger proactive refresh (e.g. 0.25 for 25%) """ - if proactive_refresh_threshold <= 0 or proactive_refresh_threshold >= 1: - msg = "proactive_refresh_threshold must be between 0 and 1" - raise ValueError(msg) current_time = time.time() total_lifetime = self.expires_at - self.issued_at time_remaining = self.expires_at - current_time - return time_remaining < (total_lifetime * proactive_refresh_threshold) + return time_remaining < ( + total_lifetime * self.proactive_refresh_threshold + ) class CredentialHelper: @@ -116,7 +127,7 @@ def get_token(self) -> TokenData: # Fast path: token is valid and fresh if token_data is not None and token_data.is_valid(): # If token is old, trigger background refresh - if token_data.is_old(self._proactive_refresh_threshold): + if token_data.is_old(): self._start_background_refresh() return token_data @@ -157,10 +168,7 @@ def _start_background_refresh(self) -> None: or not self._refresh_thread.is_alive() ) token_needs_refresh = ( - self._token_data is None - or self._token_data.is_old( - self._proactive_refresh_threshold - ) + self._token_data is None or self._token_data.is_old() ) refresh_needed = refresh_not_active and token_needs_refresh if refresh_needed: @@ -182,9 +190,7 @@ def _background_refresh(self) -> None: with self._lock: # Check if token still needs refresh (prevent stampede) token_data = self._token_data - if token_data is not None and not token_data.is_old( - self._proactive_refresh_threshold - ): + if token_data is not None and not token_data.is_old(): # Token was already refreshed by another thread return @@ -240,6 +246,7 @@ def _refresh_token(self) -> None: expires_at=current_time + expires_in, scheme=scheme, issued_at=current_time, + proactive_refresh_threshold=self._proactive_refresh_threshold, ) except httpx.HTTPStatusError as e: diff --git a/tests/client/test_channel.py b/tests/client/test_channel.py index 33f8d85..0a538fa 100644 --- a/tests/client/test_channel.py +++ b/tests/client/test_channel.py @@ -135,6 +135,7 @@ def test_is_token_valid_with_expired_token(self) -> None: expires_at=time.time() - 100, scheme="Bearer", issued_at=time.time() - 3700, + proactive_refresh_threshold=0.25, ) assert not helper._token_data.is_valid() @@ -505,7 +506,7 @@ def test_token_is_old_when_past_halfway_lifetime(self) -> None: issued_at=current_time - 3_000, # 50 minutes ago ) # Total lifetime = 3600s, remaining = 600s (1/6th), so it's old - assert token.is_old(0.25) + assert token.is_old() def test_token_is_not_old_when_fresh(self) -> None: """Test that a token is not old when more than 25% lifetime remains.""" @@ -518,7 +519,7 @@ def test_token_is_not_old_when_fresh(self) -> None: issued_at=current_time - 1200, # 20 minutes ago ) # Total lifetime = 3600s, remaining = 2400s (67%), so it's fresh - assert not token.is_old(0.25) + assert not token.is_old() def test_get_token_triggers_background_refresh_for_old_token(self) -> None: """Test that get_token triggers background refresh for old tokens.""" From 3cff677c03bf66f6b8f547f4e507e357242c8786 Mon Sep 17 00:00:00 2001 From: anna-singleton-resolver Date: Mon, 23 Feb 2026 11:48:08 +0000 Subject: [PATCH 12/12] doc: remove old part of docstring --- src/resolver_athena_client/client/channel.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/resolver_athena_client/client/channel.py b/src/resolver_athena_client/client/channel.py index 367bf78..497341a 100644 --- a/src/resolver_athena_client/client/channel.py +++ b/src/resolver_athena_client/client/channel.py @@ -51,11 +51,6 @@ def is_old(self) -> bool: background refresh to happen before expiry while the token is still usable. - Args: - ---- - proactive_refresh_threshold: Fraction of token lifetime past which - to trigger proactive refresh (e.g. 0.25 for 25%) - """ current_time = time.time() total_lifetime = self.expires_at - self.issued_at