Skip to content

Commit 9940178

Browse files
authored
Merge pull request #25 from eman/fix-token-refresh-401
Fix 401 authentication errors with automatic token refresh
2 parents 0966ad8 + f87ab65 commit 9940178

3 files changed

Lines changed: 71 additions & 1 deletion

File tree

src/nwp500/api_client.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99

1010
import aiohttp
1111

12-
from .auth import AuthenticationError, NavienAuthClient
12+
from .auth import (
13+
AuthenticationError,
14+
NavienAuthClient,
15+
TokenRefreshError,
16+
)
1317
from .config import API_BASE_URL
1418
from .models import Device, FirmwareInfo, TOUInfo
1519

@@ -114,6 +118,7 @@ async def _make_request(
114118
endpoint: str,
115119
json_data: Optional[dict[str, Any]] = None,
116120
params: Optional[dict[str, Any]] = None,
121+
retry_on_auth_failure: bool = True,
117122
) -> dict[str, Any]:
118123
"""
119124
Make an authenticated API request.
@@ -123,6 +128,7 @@ async def _make_request(
123128
endpoint: API endpoint path
124129
json_data: JSON body data
125130
params: Query parameters
131+
retry_on_auth_failure: Whether to retry once on 401 errors
126132
127133
Returns:
128134
Response data dictionary
@@ -158,6 +164,42 @@ async def _make_request(
158164
msg = response_data.get("msg", "")
159165

160166
if code != 200 or not response.ok:
167+
# If we get a 401 and haven't retried yet, try refreshing
168+
# token
169+
if code == 401 and retry_on_auth_failure:
170+
_logger.warning(
171+
"Received 401 Unauthorized. "
172+
"Attempting to refresh token..."
173+
)
174+
try:
175+
# Try to refresh the token
176+
tokens = self._auth_client.current_tokens
177+
if tokens and tokens.refresh_token:
178+
await self._auth_client.refresh_token(
179+
tokens.refresh_token
180+
)
181+
# Retry the request once with new token
182+
return await self._make_request(
183+
method,
184+
endpoint,
185+
json_data,
186+
params,
187+
retry_on_auth_failure=False,
188+
)
189+
else:
190+
_logger.error(
191+
"Cannot refresh token: "
192+
"refresh_token not available"
193+
)
194+
except (
195+
TokenRefreshError,
196+
AuthenticationError,
197+
) as refresh_error:
198+
_logger.error(
199+
f"Token refresh failed: {refresh_error}"
200+
)
201+
# Fall through to raise original error
202+
161203
_logger.error(f"API error: {code} - {msg}")
162204
raise APIError(
163205
f"API request failed: {msg}",

src/nwp500/auth.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,30 @@ async def refresh_token(self, refresh_token: str) -> AuthTokens:
400400
data = response_data.get("data", {})
401401
new_tokens = AuthTokens.from_dict(data)
402402

403+
# Preserve AWS credentials from old tokens if not in refresh
404+
# response
405+
if self._auth_response and self._auth_response.tokens:
406+
old_tokens = self._auth_response.tokens
407+
if (
408+
not new_tokens.access_key_id
409+
and old_tokens.access_key_id
410+
):
411+
new_tokens.access_key_id = old_tokens.access_key_id
412+
if not new_tokens.secret_key and old_tokens.secret_key:
413+
new_tokens.secret_key = old_tokens.secret_key
414+
if (
415+
not new_tokens.session_token
416+
and old_tokens.session_token
417+
):
418+
new_tokens.session_token = old_tokens.session_token
419+
if (
420+
not new_tokens.authorization_expires_in
421+
and old_tokens.authorization_expires_in
422+
):
423+
new_tokens.authorization_expires_in = (
424+
old_tokens.authorization_expires_in
425+
)
426+
403427
# Update stored auth response if we have one
404428
if self._auth_response:
405429
self._auth_response.tokens = new_tokens

src/nwp500/cli/__main__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ async def async_main(args: argparse.Namespace) -> int:
130130
_logger.info("Fetching device information...")
131131
device = await api_client.get_first_device()
132132

133+
# Save tokens if they were refreshed during API call
134+
if auth_client.current_tokens and auth_client.user_email:
135+
save_tokens(auth_client.current_tokens, auth_client.user_email)
136+
133137
if not device:
134138
_logger.error("No devices found for this account.")
135139
return 1

0 commit comments

Comments
 (0)