Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions src/nwp500/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,21 @@ class AuthTokens:
_expires_at: datetime = field(
default=datetime.now(), init=False, repr=False
)
_aws_expires_at: Optional[datetime] = field(
default=None, init=False, repr=False
)

def __post_init__(self) -> None:
"""Cache the expiration timestamp after initialization."""
# Pre-calculate and cache the expiration time
self._expires_at = self.issued_at + timedelta(
seconds=self.authentication_expires_in
)
# Calculate AWS credentials expiration if available
if self.authorization_expires_in:
self._aws_expires_at = self.issued_at + timedelta(
seconds=self.authorization_expires_in
)

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "AuthTokens":
Expand All @@ -106,6 +114,25 @@ def is_expired(self) -> bool:
# Consider expired if within 5 minutes of expiration
return datetime.now() >= (self._expires_at - timedelta(minutes=5))

@property
def are_aws_credentials_expired(self) -> bool:
"""Check if AWS credentials have expired.

AWS credentials have a separate expiration time from JWT tokens.
If AWS credentials are expired, a full re-authentication is needed
since the token refresh endpoint doesn't provide new AWS credentials.

Returns:
True if AWS credentials are expired, False if expiration time is
unknown or credentials are still valid
"""
if not self._aws_expires_at:
# If we don't know when AWS credentials expire, consider them valid
# This handles cases where authorization_expires_in wasn't provided
return False
# Consider expired if within 5 minutes of expiration
return datetime.now() >= (self._aws_expires_at - timedelta(minutes=5))

@property
def time_until_expiry(self) -> timedelta:
"""Get time remaining until token expiration.
Expand Down Expand Up @@ -423,6 +450,8 @@ async def refresh_token(self, refresh_token: str) -> AuthTokens:
new_tokens.authorization_expires_in = (
old_tokens.authorization_expires_in
)
# Also preserve the AWS expiration timestamp
new_tokens._aws_expires_at = old_tokens._aws_expires_at

# Update stored auth response if we have one
if self._auth_response:
Expand All @@ -446,23 +475,36 @@ async def ensure_valid_token(self) -> Optional[AuthTokens]:
"""
Ensure we have a valid access token, refreshing if necessary.

This method checks both JWT token and AWS credentials expiration.
If AWS credentials are expired, it triggers a full re-authentication
since the token refresh endpoint doesn't provide new AWS credentials.

Returns:
Valid AuthTokens or None if not authenticated

Raises:
TokenRefreshError: If token refresh fails
AuthenticationError: If re-authentication fails
"""
if not self._auth_response:
_logger.warning("No authentication response available")
return None

if self._auth_response.tokens.is_expired:
tokens = self._auth_response.tokens

# Check if AWS credentials have expired
if tokens.are_aws_credentials_expired:
_logger.info("AWS credentials expired, re-authenticating...")
# Re-authenticate to get fresh AWS credentials
await self.sign_in(self._user_id, self._password)
return self._auth_response.tokens if self._auth_response else None

# Check if JWT token has expired
if tokens.is_expired:
_logger.info("Token expired, refreshing...")
return await self.refresh_token(
self._auth_response.tokens.refresh_token
)
return await self.refresh_token(tokens.refresh_token)

return self._auth_response.tokens
return tokens

@property
def is_authenticated(self) -> bool:
Expand Down
Loading