Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 110 additions & 5 deletions src/resolver_athena_client/client/channel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Channel creation utilities for the Athena client."""

import json
import logging
import threading
import time
from typing import NamedTuple, override
Expand All @@ -15,6 +16,8 @@
OAuthError,
)

logger = logging.getLogger(__name__)


class TokenData(NamedTuple):
"""Immutable snapshot of token state.
Expand All @@ -28,11 +31,34 @@ class TokenData(NamedTuple):
access_token: str
expires_at: float
scheme: str
issued_at: float

def is_valid(self) -> bool:
"""Check if this token is still valid (with a 30-second buffer)."""
return time.time() < (self.expires_at - 30)

def is_old(self, proactive_refresh_threshold: float) -> bool:
"""Check if this token should be proactively refreshed.

A token is considered "old" if less than the
proactive_refresh_threshold of its lifetime remains. This allows
background refresh to happen before expiry while the token is still
usable.

Args:
----
proactive_refresh_threshold: Fraction of token lifetime past which
to trigger proactive refresh (e.g. 0.25 for 25%)

"""
if proactive_refresh_threshold <= 0 or proactive_refresh_threshold >= 1:
msg = "proactive_refresh_threshold must be between 0 and 1"
raise ValueError(msg)
current_time = time.time()
total_lifetime = self.expires_at - self.issued_at
time_remaining = self.expires_at - current_time
return time_remaining < (total_lifetime * proactive_refresh_threshold)


class CredentialHelper:
"""OAuth credential helper for managing authentication tokens."""
Expand All @@ -43,6 +69,7 @@ def __init__(
client_secret: str,
auth_url: str = "https://crispthinking.auth0.com/oauth/token",
audience: str = "crisp-athena-live",
proactive_refresh_threshold: float = 0.25,
) -> None:
"""Initialize the credential helper.

Expand All @@ -52,6 +79,8 @@ def __init__(
client_secret: OAuth client secret
auth_url: OAuth token endpoint URL
audience: OAuth audience
proactive_refresh_threshold: Fraction of token lifetime to trigger
proactive refresh (default 0.25 for 25%)

"""
if not client_id:
Expand All @@ -67,14 +96,17 @@ def __init__(
self._audience: str = audience
self._token_data: TokenData | None = None
self._lock: threading.Lock = threading.Lock()
self._refresh_thread: threading.Thread | None = None

if proactive_refresh_threshold <= 0 or proactive_refresh_threshold >= 1:
msg = "proactive_refresh_threshold must be a float between 0 and 1"
raise ValueError(msg)

self._proactive_refresh_threshold: float = proactive_refresh_threshold

def get_token(self) -> TokenData:
"""Get valid token data, refreshing if necessary.

Uses double-checked locking: the happy path (token is valid)
avoids acquiring the lock entirely. The lock is only taken
when the token needs to be refreshed.

Returns
-------
A valid ``TokenData`` containing access token, expiry, and scheme
Expand All @@ -86,9 +118,15 @@ def get_token(self) -> TokenData:

"""
token_data = self._token_data

# Fast path: token is valid and fresh
if token_data is not None and token_data.is_valid():
# If token is old, trigger background refresh
if token_data.is_old(self._proactive_refresh_threshold):
self._start_background_refresh()
return token_data

# Slow path: token is expired or missing, must block
with self._lock:
token_data = self._token_data
if token_data is not None and token_data.is_valid():
Expand All @@ -102,6 +140,71 @@ def get_token(self) -> TokenData:
raise RuntimeError(msg)
return token_data

def _start_background_refresh(self) -> None:
"""Start a background thread to refresh the token.

Only starts a new thread if one isn't already running.

This method is safe to call multiple times - it only starts a new
thread if no refresh is currently in progress.
"""
# Quick check without lock - if refresh thread exists and is
# alive, skip
if self._refresh_thread is not None and self._refresh_thread.is_alive():
return

# Try to acquire lock and start refresh
if self._lock.acquire(blocking=False):
try:
# Double-check: another thread might have started refresh,
# or the token may have been refreshed.
refresh_not_active = (
self._refresh_thread is None
or not self._refresh_thread.is_alive()
)
token_needs_refresh = (
self._token_data is None
or self._token_data.is_old(
self._proactive_refresh_threshold
)
)
refresh_needed = refresh_not_active and token_needs_refresh
if refresh_needed:
self._refresh_thread = threading.Thread(
target=self._background_refresh,
daemon=True,
)
self._refresh_thread.start()
finally:
self._lock.release()

def _background_refresh(self) -> None:
"""Background thread target for token refresh.

Acquires the lock and refreshes the token. Errors are logged
but silently ignored since the next foreground request will
retry if needed.
"""
with self._lock:
# Check if token still needs refresh (prevent stampede)
token_data = self._token_data
if token_data is not None and not token_data.is_old(
self._proactive_refresh_threshold
):
# Token was already refreshed by another thread
return

try:
self._refresh_token()
except Exception as e: # noqa: BLE001
# Log but don't raise - background refresh failures
# are recoverable (next get_token() will retry)
logger.debug(
"Background token refresh failed, "
"will retry on next request: %s",
e,
)

def _refresh_token(self) -> None:
"""Refresh the authentication token by making an OAuth request.

Expand Down Expand Up @@ -138,10 +241,12 @@ def _refresh_token(self) -> None:
token_type = raw.get("token_type", "Bearer")
# Preserve server-provided casing, only strip whitespace
scheme: str = token_type.strip() if token_type else "Bearer"
current_time = time.time()
self._token_data = TokenData(
access_token=access_token,
expires_at=time.time() + expires_in,
expires_at=current_time + expires_in,
scheme=scheme,
issued_at=current_time,
)

except httpx.HTTPStatusError as e:
Expand Down
Loading