Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/nwp500/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,23 @@ def has_stored_credentials(self) -> bool:
"""
return bool(self._user_id and self._password)

@property
def has_valid_tokens(self) -> bool:
"""Check if both JWT and AWS credentials are valid and not expired.

Returns True only if tokens exist AND neither JWT tokens nor AWS
credentials have expired. This is useful for pre-flight checks before
operations that require valid credentials (e.g., MQTT connection).

Returns:
True if tokens exist AND not expired (JWT + AWS creds), False
otherwise
"""
if not self._auth_response:
return False
tokens = self._auth_response.tokens
return not tokens.is_expired and not tokens.are_aws_credentials_expired

async def close(self) -> None:
"""Close the aiohttp session if we own it."""
if self._owned_session and self._session:
Expand Down
63 changes: 61 additions & 2 deletions src/nwp500/mqtt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,21 @@ def __init__(
config: Optional connection configuration

Raises:
ValueError: If auth client is not authenticated or AWS
credentials are not available
MqttCredentialsError: If auth client is not authenticated, tokens
are stale/expired, or AWS credentials are not available
"""
if not auth_client.is_authenticated:
raise MqttCredentialsError(
"Authentication client must be authenticated before "
"creating MQTT client. Call auth_client.sign_in() first."
)

if not auth_client.has_valid_tokens:
raise MqttCredentialsError(
"Tokens are stale/expired. "
"Call ensure_valid_token() or re_authenticate() first."
)

if not auth_client.current_tokens:
raise MqttCredentialsError("No tokens available from auth client")

Expand Down Expand Up @@ -528,6 +534,59 @@ async def connect(self) -> bool:
_logger.error(f"Failed to connect: {e}")
raise

async def recover_connection(self) -> bool:
"""Recover from authentication-related connection failures.

This method is useful when MQTT connection fails due to stale/expired
authentication tokens. It refreshes the tokens and attempts to reconnect
the MQTT client.

Returns:
True if recovery was successful and MQTT is reconnected, False
otherwise

Raises:
TokenRefreshError: If token refresh fails
AuthenticationError: If re-authentication fails

Example:
>>> mqtt_client = NavienMqttClient(auth_client)
>>> try:
... await mqtt_client.connect()
... except MqttConnectionError:
... # Connection may have failed due to stale tokens
... if await mqtt_client.recover_connection():
... print("Successfully recovered connection")
... else:
... print("Recovery failed, check logs")
"""
_logger.info(
"Attempting to recover MQTT connection by refreshing tokens"
)

try:
# Step 1: Refresh authentication tokens
await self._auth_client.ensure_valid_token()
_logger.debug("Authentication tokens refreshed")

# Step 2: Attempt to reconnect
if self._connected:
_logger.info("Already connected after token refresh")
return True

# If not connected, try to reconnect
success = await self.connect()
if success:
_logger.info("MQTT connection successfully recovered")
return True
else:
_logger.error("MQTT reconnection failed despite valid tokens")
return False

except (TokenRefreshError, AuthenticationError) as e:
_logger.error(f"Failed to recover connection: {e}")
raise

def _create_credentials_provider(self) -> Any:
"""Create AWS credentials provider from auth tokens."""
from awscrt.auth import AwsCredentialsProvider
Expand Down
Loading