diff --git a/src/nwp500/api_client.py b/src/nwp500/api_client.py index abb5fa2..ac90246 100644 --- a/src/nwp500/api_client.py +++ b/src/nwp500/api_client.py @@ -9,7 +9,11 @@ import aiohttp -from .auth import AuthenticationError, NavienAuthClient +from .auth import ( + AuthenticationError, + NavienAuthClient, + TokenRefreshError, +) from .config import API_BASE_URL from .models import Device, FirmwareInfo, TOUInfo @@ -114,6 +118,7 @@ async def _make_request( endpoint: str, json_data: Optional[dict[str, Any]] = None, params: Optional[dict[str, Any]] = None, + retry_on_auth_failure: bool = True, ) -> dict[str, Any]: """ Make an authenticated API request. @@ -123,6 +128,7 @@ async def _make_request( endpoint: API endpoint path json_data: JSON body data params: Query parameters + retry_on_auth_failure: Whether to retry once on 401 errors Returns: Response data dictionary @@ -158,6 +164,42 @@ async def _make_request( msg = response_data.get("msg", "") if code != 200 or not response.ok: + # If we get a 401 and haven't retried yet, try refreshing + # token + if code == 401 and retry_on_auth_failure: + _logger.warning( + "Received 401 Unauthorized. " + "Attempting to refresh token..." + ) + try: + # Try to refresh the token + tokens = self._auth_client.current_tokens + if tokens and tokens.refresh_token: + await self._auth_client.refresh_token( + tokens.refresh_token + ) + # Retry the request once with new token + return await self._make_request( + method, + endpoint, + json_data, + params, + retry_on_auth_failure=False, + ) + else: + _logger.error( + "Cannot refresh token: " + "refresh_token not available" + ) + except ( + TokenRefreshError, + AuthenticationError, + ) as refresh_error: + _logger.error( + f"Token refresh failed: {refresh_error}" + ) + # Fall through to raise original error + _logger.error(f"API error: {code} - {msg}") raise APIError( f"API request failed: {msg}", diff --git a/src/nwp500/auth.py b/src/nwp500/auth.py index ebb72fc..9febb09 100644 --- a/src/nwp500/auth.py +++ b/src/nwp500/auth.py @@ -400,6 +400,30 @@ async def refresh_token(self, refresh_token: str) -> AuthTokens: data = response_data.get("data", {}) new_tokens = AuthTokens.from_dict(data) + # Preserve AWS credentials from old tokens if not in refresh + # response + if self._auth_response and self._auth_response.tokens: + old_tokens = self._auth_response.tokens + if ( + not new_tokens.access_key_id + and old_tokens.access_key_id + ): + new_tokens.access_key_id = old_tokens.access_key_id + if not new_tokens.secret_key and old_tokens.secret_key: + new_tokens.secret_key = old_tokens.secret_key + if ( + not new_tokens.session_token + and old_tokens.session_token + ): + new_tokens.session_token = old_tokens.session_token + if ( + not new_tokens.authorization_expires_in + and old_tokens.authorization_expires_in + ): + new_tokens.authorization_expires_in = ( + old_tokens.authorization_expires_in + ) + # Update stored auth response if we have one if self._auth_response: self._auth_response.tokens = new_tokens diff --git a/src/nwp500/cli/__main__.py b/src/nwp500/cli/__main__.py index 394ea70..4fe9e8d 100644 --- a/src/nwp500/cli/__main__.py +++ b/src/nwp500/cli/__main__.py @@ -130,6 +130,10 @@ async def async_main(args: argparse.Namespace) -> int: _logger.info("Fetching device information...") device = await api_client.get_first_device() + # Save tokens if they were refreshed during API call + if auth_client.current_tokens and auth_client.user_email: + save_tokens(auth_client.current_tokens, auth_client.user_email) + if not device: _logger.error("No devices found for this account.") return 1