From 8e1025bf951e5af4e18ed2b1a856d15e6b951763 Mon Sep 17 00:00:00 2001 From: Emmanuel Levijarvi Date: Mon, 27 Oct 2025 14:18:59 -0700 Subject: [PATCH 1/4] Add token restoration support for session persistence - Add stored_tokens parameter to NavienAuthClient for restoring saved tokens - Add AuthTokens.to_dict() for serializing tokens with issued_at timestamp - Enhance AuthTokens.from_dict() to support both API and stored data formats - Skip authentication when valid stored tokens provided - Auto-refresh expired JWT tokens or re-authenticate if AWS creds expired - Add 7 new tests covering token serialization and restoration flows - Add examples/token_restoration_example.py demonstrating workflow - Update authentication documentation with token restoration guide Benefits: - Reduces API load and improves startup time - Prevents rate limiting for frequently restarting applications - Enables session persistence across application restarts --- CHANGELOG.rst | 19 +++ docs/python_api/auth_client.rst | 106 +++++++++++- examples/token_restoration_example.py | 154 +++++++++++++++++ src/nwp500/auth.py | 133 +++++++++++++-- tests/test_auth.py | 237 ++++++++++++++++++++++++++ 5 files changed, 635 insertions(+), 14 deletions(-) create mode 100644 examples/token_restoration_example.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b2dd316..b360f37 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,25 @@ Changelog ========= +Version 4.8.0 (2025-10-27) +========================== + +Added +----- + +- **Token Restoration Support**: Enable session persistence across application restarts + + - Added ``stored_tokens`` parameter to ``NavienAuthClient.__init__()`` for restoring saved tokens + - Added ``AuthTokens.to_dict()`` method for serializing tokens (includes ``issued_at`` timestamp) + - Enhanced ``AuthTokens.from_dict()`` to support both API responses (camelCase) and stored data (snake_case) + - Modified ``NavienAuthClient.__aenter__()`` to skip authentication when valid stored tokens are provided + - Automatically refreshes expired JWT tokens or re-authenticates if AWS credentials expired + - Added 7 new tests for token serialization, deserialization, and restoration flows + - Added ``examples/token_restoration_example.py`` demonstrating save/restore workflow + - Updated authentication documentation with token restoration guide + +- **Benefits**: Reduces API load, improves startup time, prevents rate limiting for frequently restarting applications (e.g., Home Assistant) + Version 4.7.1 (2025-10-27) ========================== diff --git a/docs/python_api/auth_client.rst b/docs/python_api/auth_client.rst index ffc7417..a3049d2 100644 --- a/docs/python_api/auth_client.rst +++ b/docs/python_api/auth_client.rst @@ -60,7 +60,7 @@ API Reference NavienAuthClient ---------------- -.. py:class:: NavienAuthClient(email=None, password=None, base_url=API_BASE_URL) +.. py:class:: NavienAuthClient(email=None, password=None, base_url=API_BASE_URL, stored_tokens=None) JWT-based authentication client for Navien Smart Control API. @@ -70,6 +70,8 @@ NavienAuthClient :type password: str or None :param base_url: API base URL :type base_url: str + :param stored_tokens: Previously saved tokens to restore session + :type stored_tokens: AuthTokens or None **Example:** @@ -81,11 +83,24 @@ NavienAuthClient # From environment variables auth = NavienAuthClient() + # With stored tokens (skip re-authentication) + stored = AuthTokens.from_dict(saved_data) + auth = NavienAuthClient( + "email@example.com", + "password", + stored_tokens=stored + ) + # Always use as context manager async with auth: # Authenticated pass + .. note:: + If ``stored_tokens`` are provided and still valid, the initial + sign-in is skipped. If tokens are expired, they're automatically + refreshed or re-authenticated as needed. + Authentication Methods ---------------------- @@ -318,6 +333,7 @@ AuthTokens :param access_key_id: AWS access key (for MQTT) :param secret_key: AWS secret key (for MQTT) :param session_token: AWS session token (for MQTT) + :param issued_at: Token issue timestamp (auto-set if not provided) **Properties:** @@ -325,6 +341,39 @@ AuthTokens * ``is_expired`` - Check if expired * ``time_until_expiry`` - Time remaining * ``bearer_token`` - Formatted bearer token + * ``are_aws_credentials_expired`` - Check if AWS credentials expired + + **Methods:** + + .. py:method:: from_dict(data) + :classmethod: + + Create AuthTokens from dictionary (API response or saved data). + + :param data: Token data dictionary + :type data: dict[str, Any] + :return: AuthTokens instance + :rtype: AuthTokens + + Supports both camelCase keys (API response) and snake_case keys (saved data). + + .. py:method:: to_dict() + + Serialize tokens to dictionary for storage. + + :return: Dictionary with all token data including issued_at timestamp + :rtype: dict[str, Any] + + **Example:** + + .. code-block:: python + + # Save tokens + tokens = auth.current_tokens + token_data = tokens.to_dict() + + # Later, restore tokens + restored = AuthTokens.from_dict(token_data) AuthenticationResponse ---------------------- @@ -418,6 +467,61 @@ Example 4: Long-Running Application # Sleep await asyncio.sleep(3600) +Example 5: Token Restoration (Skip Re-authentication) +------------------------------------------------------ + +.. code-block:: python + + import json + from nwp500 import NavienAuthClient + from nwp500.auth import AuthTokens + + async def save_tokens(): + """Save tokens for later reuse.""" + async with NavienAuthClient(email, password) as auth: + tokens = auth.current_tokens + + # Serialize tokens to dictionary + token_data = tokens.to_dict() + + # Save to file (or database, cache, etc.) + with open('tokens.json', 'w') as f: + json.dump(token_data, f) + + print("Tokens saved for future use") + + async def restore_tokens(): + """Restore authentication from saved tokens.""" + # Load saved tokens + with open('tokens.json') as f: + token_data = json.load(f) + + # Deserialize tokens + stored_tokens = AuthTokens.from_dict(token_data) + + # Initialize client with stored tokens + # This skips initial authentication if tokens are still valid + async with NavienAuthClient( + email, password, + stored_tokens=stored_tokens + ) as auth: + # If tokens were expired, they're automatically refreshed + # If AWS credentials expired, re-authentication occurs + print(f"Authenticated (from stored tokens): {auth.user_email}") + + # Always save updated tokens after refresh + new_tokens = auth.current_tokens + if new_tokens.issued_at != stored_tokens.issued_at: + token_data = new_tokens.to_dict() + with open('tokens.json', 'w') as f: + json.dump(token_data, f) + print("Tokens were refreshed and re-saved") + +.. note:: + Token restoration is especially useful for applications that restart + frequently (like Home Assistant) to avoid unnecessary authentication + requests on every restart. + Error Handling ============== diff --git a/examples/token_restoration_example.py b/examples/token_restoration_example.py new file mode 100644 index 0000000..6815858 --- /dev/null +++ b/examples/token_restoration_example.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +"""Example demonstrating token restoration/persistence. + +This example shows how to save and restore authentication tokens to avoid +re-authenticating on every application restart. This is especially useful +for applications like Home Assistant that restart frequently. + +Usage: + # First run - authenticate and save tokens + python3 token_restoration_example.py --save + + # Subsequent runs - restore from saved tokens + python3 token_restoration_example.py --restore +""" + +import argparse +import asyncio +import json +import logging +import os +from pathlib import Path + +from nwp500 import NavienAuthClient + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +# Token storage file +TOKEN_FILE = Path.home() / ".navien_tokens.json" + + +async def save_tokens_example(): + """Authenticate and save tokens for future use.""" + email = os.getenv("NAVIEN_EMAIL") + password = os.getenv("NAVIEN_PASSWORD") + + if not email or not password: + raise ValueError( + "Please set NAVIEN_EMAIL and NAVIEN_PASSWORD environment variables" + ) + + logger.info("Authenticating with Navien API...") + + # Authenticate normally + async with NavienAuthClient(email, password) as auth_client: + tokens = auth_client.current_tokens + if not tokens: + raise RuntimeError("Failed to obtain tokens") + + logger.info("✓ Authentication successful") + logger.info(f"Token expires at: {tokens.expires_at}") + + # Serialize tokens to dictionary + token_data = tokens.to_dict() + + # Save to file + with open(TOKEN_FILE, "w") as f: + json.dump(token_data, f, indent=2) + + logger.info(f"✓ Tokens saved to {TOKEN_FILE}") + logger.info("You can now use --restore to skip authentication on future runs") + + +async def restore_tokens_example(): + """Restore authentication from saved tokens.""" + if not TOKEN_FILE.exists(): + raise FileNotFoundError( + f"Token file not found: {TOKEN_FILE}\n" + "Please run with --save first to authenticate and save tokens" + ) + + email = os.getenv("NAVIEN_EMAIL") + password = os.getenv("NAVIEN_PASSWORD") + + if not email or not password: + raise ValueError( + "Please set NAVIEN_EMAIL and NAVIEN_PASSWORD environment variables" + ) + + # Load saved tokens + with open(TOKEN_FILE) as f: + token_data = json.load(f) + + logger.info(f"Loading tokens from {TOKEN_FILE}...") + + # Import after getting token_data to avoid circular import issues + from nwp500.auth import AuthTokens + + stored_tokens = AuthTokens.from_dict(token_data) + + logger.info(f"Stored tokens issued at: {stored_tokens.issued_at}") + logger.info(f"Stored tokens expire at: {stored_tokens.expires_at}") + + if stored_tokens.is_expired: + logger.warning("⚠ Stored tokens are expired, will refresh...") + elif stored_tokens.are_aws_credentials_expired: + logger.warning("⚠ AWS credentials expired, will re-authenticate...") + else: + logger.info("✓ Stored tokens are still valid") + + # Use stored tokens to initialize client + async with NavienAuthClient( + email, password, stored_tokens=stored_tokens + ) as auth_client: + tokens = auth_client.current_tokens + if not tokens: + raise RuntimeError("Failed to restore authentication") + + logger.info("✓ Successfully authenticated using stored tokens") + logger.info(f"Current token expires at: {tokens.expires_at}") + + # If tokens were refreshed, save them + if tokens.issued_at != stored_tokens.issued_at: + logger.info("Tokens were refreshed, updating stored copy...") + token_data = tokens.to_dict() + with open(TOKEN_FILE, "w") as f: + json.dump(token_data, f, indent=2) + logger.info(f"✓ Updated tokens saved to {TOKEN_FILE}") + + +async def main(): + """Main entry point.""" + parser = argparse.ArgumentParser( + description="Token restoration example for nwp500-python" + ) + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( + "--save", + action="store_true", + help="Authenticate and save tokens for future use", + ) + group.add_argument( + "--restore", + action="store_true", + help="Restore authentication from saved tokens", + ) + + args = parser.parse_args() + + try: + if args.save: + await save_tokens_example() + else: + await restore_tokens_example() + except Exception as e: + logger.error(f"Error: {e}") + raise + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/nwp500/auth.py b/src/nwp500/auth.py index 51f165c..5b162de 100644 --- a/src/nwp500/auth.py +++ b/src/nwp500/auth.py @@ -91,18 +91,80 @@ def __post_init__(self) -> None: @classmethod def from_dict(cls, data: dict[str, Any]) -> "AuthTokens": - """Create AuthTokens from API response dictionary.""" + """Create AuthTokens from API response dictionary or stored data. + + Args: + data: Dictionary containing token data. Can be from API response + (using camelCase keys) or from stored data (using snake_case + keys from to_dict()). + + Returns: + AuthTokens instance + + Example: + # From API response + >>> tokens = AuthTokens.from_dict({ + ... "idToken": "...", + ... "accessToken": "...", + ... "refreshToken": "...", + ... "authenticationExpiresIn": 3600 + ... }) + + # From stored data (after to_dict()) + >>> stored = tokens.to_dict() + >>> restored = AuthTokens.from_dict(stored) + """ + # Support both camelCase (API) and snake_case (stored) keys return cls( - id_token=data.get("idToken", ""), - access_token=data.get("accessToken", ""), - refresh_token=data.get("refreshToken", ""), - authentication_expires_in=data.get("authenticationExpiresIn", 3600), - access_key_id=data.get("accessKeyId"), - secret_key=data.get("secretKey"), - session_token=data.get("sessionToken"), - authorization_expires_in=data.get("authorizationExpiresIn"), + id_token=data.get("idToken") or data.get("id_token", ""), + access_token=data.get("accessToken") + or data.get("access_token", ""), + refresh_token=data.get("refreshToken") + or data.get("refresh_token", ""), + authentication_expires_in=data.get("authenticationExpiresIn") + or data.get("authentication_expires_in", 3600), + access_key_id=data.get("accessKeyId") or data.get("access_key_id"), + secret_key=data.get("secretKey") or data.get("secret_key"), + session_token=data.get("sessionToken") or data.get("session_token"), + authorization_expires_in=data.get("authorizationExpiresIn") + or data.get("authorization_expires_in"), + issued_at=datetime.fromisoformat(data["issued_at"]) + if "issued_at" in data + else datetime.now(), ) + def to_dict(self) -> dict[str, Any]: + """Convert AuthTokens to a dictionary for storage. + + Returns a dictionary with all token data including the issued_at + timestamp, which is essential for correctly calculating expiration + times when restoring tokens. + + Returns: + Dictionary with snake_case keys suitable for JSON serialization + + Example: + >>> tokens = auth_client.current_tokens + >>> stored_data = tokens.to_dict() + >>> # Save to file/database + >>> import json + >>> json.dump(stored_data, file) + >>> + >>> # Later, restore tokens + >>> restored_tokens = AuthTokens.from_dict(json.load(file)) + """ + return { + "id_token": self.id_token, + "access_token": self.access_token, + "refresh_token": self.refresh_token, + "authentication_expires_in": self.authentication_expires_in, + "access_key_id": self.access_key_id, + "secret_key": self.secret_key, + "session_token": self.session_token, + "authorization_expires_in": self.authorization_expires_in, + "issued_at": self.issued_at.isoformat(), + } + @property def expires_at(self) -> datetime: """Get the cached expiration timestamp.""" @@ -236,7 +298,7 @@ class NavienAuthClient: - AWS credentials (if provided by API) Authentication is performed automatically when entering the async context - manager. + manager, unless valid stored tokens are provided. Example: >>> async with NavienAuthClient(user_id="user@example.com", @@ -251,6 +313,16 @@ class NavienAuthClient: ... if client.current_tokens.is_expired: ... new_tokens = await client.refresh_token(client.current_tokens.refresh_token) + + Restore session from stored tokens: + >>> stored_tokens = AuthTokens.from_dict(saved_data) + >>> async with NavienAuthClient( + ... user_id="user@example.com", + ... password="password", + ... stored_tokens=stored_tokens + ... ) as client: + ... # Authentication skipped if tokens are still valid + ... print(f"Access token: {client.current_tokens.access_token}") """ def __init__( @@ -260,6 +332,7 @@ def __init__( base_url: str = API_BASE_URL, session: Optional[aiohttp.ClientSession] = None, timeout: int = 30, + stored_tokens: Optional[AuthTokens] = None, ): """ Initialize the authentication client. @@ -270,10 +343,13 @@ def __init__( base_url: Base URL for the API (default: official Navien API) session: Optional aiohttp ClientSession to use timeout: Request timeout in seconds + stored_tokens: Previously saved tokens to restore session. + If provided and valid, skips initial sign_in. Note: Authentication is performed automatically when entering the - async context manager (using async with statement). + async context manager (using async with statement), unless + valid stored_tokens are provided. """ self.base_url = base_url.rstrip("/") self._session = session @@ -288,13 +364,44 @@ def __init__( self._auth_response: Optional[AuthenticationResponse] = None self._user_email: Optional[str] = None + # Restore tokens if provided + if stored_tokens: + # Create a minimal AuthenticationResponse with stored tokens + # UserInfo will be populated on first API call if needed + self._auth_response = AuthenticationResponse( + user_info=UserInfo( + user_type="", + user_first_name="", + user_last_name="", + user_status="", + user_seq=0, + ), + tokens=stored_tokens, + ) + self._user_email = user_id + async def __aenter__(self) -> "NavienAuthClient": """Async context manager entry.""" if self._owned_session: self._session = aiohttp.ClientSession(timeout=self.timeout) - # Automatically authenticate - await self.sign_in(self._user_id, self._password) + # Check if we have valid stored tokens + if self._auth_response and self._auth_response.tokens: + tokens = self._auth_response.tokens + # If tokens are expired, refresh or re-authenticate + if tokens.are_aws_credentials_expired: + _logger.info( + "Stored AWS credentials expired, re-authenticating..." + ) + await self.sign_in(self._user_id, self._password) + elif tokens.is_expired: + _logger.info("Stored JWT token expired, refreshing...") + await self.refresh_token(tokens.refresh_token) + else: + _logger.info("Using stored tokens, skipping authentication") + else: + # No stored tokens, perform full authentication + await self.sign_in(self._user_id, self._password) return self diff --git a/tests/test_auth.py b/tests/test_auth.py index 421ee54..121ae7f 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -835,3 +835,240 @@ def test_aws_credentials_preservation_in_token_refresh(): assert new_tokens.session_token == "old_session" assert new_tokens.authorization_expires_in == 3600 assert new_tokens._aws_expires_at == old_tokens._aws_expires_at + + +# Test token restoration functionality +def test_auth_tokens_to_dict(): + """Test AuthTokens.to_dict serialization.""" + issued_at = datetime.now() + tokens = AuthTokens( + id_token="test_id", + access_token="test_access", + refresh_token="test_refresh", + authentication_expires_in=3600, + access_key_id="test_key", + secret_key="test_secret", + session_token="test_session", + authorization_expires_in=1800, + issued_at=issued_at, + ) + + result = tokens.to_dict() + + assert result["id_token"] == "test_id" + assert result["access_token"] == "test_access" + assert result["refresh_token"] == "test_refresh" + assert result["authentication_expires_in"] == 3600 + assert result["access_key_id"] == "test_key" + assert result["secret_key"] == "test_secret" + assert result["session_token"] == "test_session" + assert result["authorization_expires_in"] == 1800 + assert result["issued_at"] == issued_at.isoformat() + + +def test_auth_tokens_from_dict_with_issued_at(): + """Test AuthTokens.from_dict with issued_at timestamp.""" + issued_at = datetime.now() - timedelta(seconds=1800) + data = { + "id_token": "test_id", + "access_token": "test_access", + "refresh_token": "test_refresh", + "authentication_expires_in": 3600, + "access_key_id": "test_key", + "secret_key": "test_secret", + "session_token": "test_session", + "authorization_expires_in": 1800, + "issued_at": issued_at.isoformat(), + } + + tokens = AuthTokens.from_dict(data) + + assert tokens.id_token == "test_id" + assert tokens.access_token == "test_access" + assert tokens.refresh_token == "test_refresh" + assert tokens.authentication_expires_in == 3600 + assert tokens.access_key_id == "test_key" + assert tokens.secret_key == "test_secret" + assert tokens.session_token == "test_session" + assert tokens.authorization_expires_in == 1800 + # Check that issued_at was correctly restored + assert abs((tokens.issued_at - issued_at).total_seconds()) < 1 + + +def test_auth_tokens_serialization_roundtrip(): + """Test that tokens can be serialized and deserialized without data loss.""" + issued_at = datetime.now() - timedelta(seconds=1800) + original = AuthTokens( + id_token="test_id", + access_token="test_access", + refresh_token="test_refresh", + authentication_expires_in=3600, + access_key_id="test_key", + secret_key="test_secret", + session_token="test_session", + authorization_expires_in=1800, + issued_at=issued_at, + ) + + # Serialize and deserialize + serialized = original.to_dict() + restored = AuthTokens.from_dict(serialized) + + # Verify all fields match + assert restored.id_token == original.id_token + assert restored.access_token == original.access_token + assert restored.refresh_token == original.refresh_token + assert ( + restored.authentication_expires_in == original.authentication_expires_in + ) + assert restored.access_key_id == original.access_key_id + assert restored.secret_key == original.secret_key + assert restored.session_token == original.session_token + assert ( + restored.authorization_expires_in == original.authorization_expires_in + ) + # Verify issued_at is preserved (critical for expiration calculations) + assert abs((restored.issued_at - original.issued_at).total_seconds()) < 1 + # Verify expiration calculations are the same + assert abs((restored.expires_at - original.expires_at).total_seconds()) < 1 + assert restored.is_expired == original.is_expired + + +def test_navien_auth_client_initialization_with_stored_tokens(): + """Test NavienAuthClient initialization with stored tokens.""" + stored_tokens = AuthTokens( + id_token="stored_id", + access_token="stored_access", + refresh_token="stored_refresh", + authentication_expires_in=3600, + access_key_id="stored_key", + secret_key="stored_secret", + session_token="stored_session", + authorization_expires_in=1800, + ) + + client = NavienAuthClient( + user_id="test@example.com", + password="test_password", + stored_tokens=stored_tokens, + ) + + # Should have auth response set up with stored tokens + assert client.is_authenticated is True + assert client.current_tokens == stored_tokens + assert client.user_email == "test@example.com" + + +@pytest.mark.asyncio +async def test_context_manager_with_valid_stored_tokens(): + """Test async context manager skips auth with valid stored tokens.""" + stored_tokens = AuthTokens( + id_token="stored_id", + access_token="stored_access", + refresh_token="stored_refresh", + authentication_expires_in=3600, # Valid for 1 hour + access_key_id="stored_key", + secret_key="stored_secret", + session_token="stored_session", + authorization_expires_in=3600, # AWS creds valid for 1 hour + ) + + with patch.object( + NavienAuthClient, "sign_in", new_callable=AsyncMock + ) as mock_sign_in: + async with NavienAuthClient( + user_id="test@example.com", + password="test_password", + stored_tokens=stored_tokens, + ) as client: + # Should NOT have called sign_in since tokens are valid + mock_sign_in.assert_not_called() + assert client.current_tokens == stored_tokens + assert client._session is not None + + +@pytest.mark.asyncio +async def test_context_manager_with_expired_jwt_stored_tokens(): + """Test async context manager with expired JWT refreshes tokens.""" + old_time = datetime.now() - timedelta(seconds=3900) # 65 minutes ago + stored_tokens = AuthTokens( + id_token="stored_id", + access_token="stored_access", + refresh_token="stored_refresh", + authentication_expires_in=3600, # Expired 5 minutes ago + issued_at=old_time, + ) + + new_tokens = AuthTokens( + id_token="new_id", + access_token="new_access", + refresh_token="new_refresh", + authentication_expires_in=3600, + ) + + with patch.object( + NavienAuthClient, "refresh_token", new_callable=AsyncMock + ) as mock_refresh: + mock_refresh.return_value = new_tokens + + async with NavienAuthClient( + user_id="test@example.com", + password="test_password", + stored_tokens=stored_tokens, + ) as client: + # Should have called refresh_token + mock_refresh.assert_called_once_with("stored_refresh") + assert client._session is not None + + +@pytest.mark.asyncio +async def test_context_manager_with_expired_aws_credentials(): + """Test async context manager re-authenticates on AWS creds expiry.""" + old_time = datetime.now() - timedelta(seconds=3900) # 65 minutes ago + stored_tokens = AuthTokens( + id_token="stored_id", + access_token="stored_access", + refresh_token="stored_refresh", + authentication_expires_in=7200, # JWT still valid for 55 minutes + access_key_id="stored_key", + secret_key="stored_secret", + session_token="stored_session", + authorization_expires_in=3600, # AWS creds expired 5 minutes ago + issued_at=old_time, + ) + + new_tokens = AuthTokens( + id_token="new_id", + access_token="new_access", + refresh_token="new_refresh", + authentication_expires_in=3600, + access_key_id="new_key", + secret_key="new_secret", + session_token="new_session", + authorization_expires_in=3600, + ) + + with patch.object( + NavienAuthClient, "sign_in", new_callable=AsyncMock + ) as mock_sign_in: + mock_sign_in.return_value = AuthenticationResponse( + user_info=UserInfo( + user_type="test", + user_first_name="Test", + user_last_name="User", + user_status="active", + user_seq=1, + ), + tokens=new_tokens, + ) + + async with NavienAuthClient( + user_id="test@example.com", + password="test_password", + stored_tokens=stored_tokens, + ) as client: + # Should have called sign_in due to expired AWS credentials + mock_sign_in.assert_called_once_with( + "test@example.com", "test_password" + ) + assert client._session is not None From 87b8bdb7a5ae62cc189a00ac7c9bf8fa3fb30268 Mon Sep 17 00:00:00 2001 From: Emmanuel Levijarvi Date: Mon, 27 Oct 2025 15:24:36 -0700 Subject: [PATCH 2/4] Update CLI to use token restoration API - Replace manual token reconstruction with stored_tokens parameter - Use AuthTokens.to_dict() and from_dict() for serialization - Simplify token storage code by leveraging built-in methods - Remove unnecessary manual validation logic - Properly manage context manager lifecycle for auth client Benefits: - Cleaner, more maintainable code - Automatic token refresh and AWS credential re-authentication - Consistent with new token restoration pattern --- src/nwp500/cli/__main__.py | 55 ++++++++++++--------------------- src/nwp500/cli/token_storage.py | 47 +++++++--------------------- 2 files changed, 31 insertions(+), 71 deletions(-) diff --git a/src/nwp500/cli/__main__.py b/src/nwp500/cli/__main__.py index 4fe9e8d..3a2cc84 100644 --- a/src/nwp500/cli/__main__.py +++ b/src/nwp500/cli/__main__.py @@ -12,11 +12,7 @@ from typing import Optional from nwp500 import NavienAPIClient, NavienAuthClient, __version__ -from nwp500.auth import ( - AuthenticationResponse, - InvalidCredentialsError, - UserInfo, -) +from nwp500.auth import InvalidCredentialsError from .commands import ( handle_device_feature_request, @@ -55,37 +51,15 @@ async def get_authenticated_client( Returns: NavienAuthClient instance or None if authentication fails """ + # Get credentials + email = args.email or os.getenv("NAVIEN_EMAIL") + password = args.password or os.getenv("NAVIEN_PASSWORD") + # Try loading cached tokens tokens, cached_email = load_tokens() - # Check if cached tokens are valid and complete - if ( - tokens - and cached_email - and not tokens.is_expired - and tokens.access_key_id - and tokens.secret_key - and tokens.session_token - ): - _logger.info("Using valid cached tokens.") - # The password argument is unused when cached tokens are present. - auth_client = NavienAuthClient(cached_email, "cached_auth") - auth_client._user_email = cached_email - await auth_client._ensure_session() - - # Manually construct the auth response since we are not signing in - auth_client._auth_response = AuthenticationResponse( - user_info=UserInfo.from_dict({}), tokens=tokens - ) - return auth_client - - _logger.info( - "Cached tokens are invalid, expired, or incomplete. " - "Re-authenticating..." - ) - # Fallback to email/password - email = args.email or os.getenv("NAVIEN_EMAIL") - password = args.password or os.getenv("NAVIEN_PASSWORD") + # Use cached email if available, otherwise fall back to provided email + email = cached_email or email if not email or not password: _logger.error( @@ -95,11 +69,19 @@ async def get_authenticated_client( return None try: - auth_client = NavienAuthClient(email, password) - await auth_client.sign_in(email, password) + # Use the new stored_tokens parameter for token restoration + # This will automatically handle token validation and refresh + auth_client = NavienAuthClient(email, password, stored_tokens=tokens) + + # Enter the context manager to authenticate/restore session + await auth_client.__aenter__() + + # Save refreshed/new tokens if auth_client.current_tokens and auth_client.user_email: save_tokens(auth_client.current_tokens, auth_client.user_email) + return auth_client + except InvalidCredentialsError: _logger.error("Invalid email or password.") return None @@ -243,7 +225,8 @@ async def async_main(args: argparse.Namespace) -> int: return 1 finally: # Auth client close will close the underlying aiohttp session - await auth_client.close() + # Also call __aexit__ to properly clean up context manager + await auth_client.__aexit__(None, None, None) _logger.info("Cleanup complete.") return 0 diff --git a/src/nwp500/cli/token_storage.py b/src/nwp500/cli/token_storage.py index 4204819..676a9f7 100644 --- a/src/nwp500/cli/token_storage.py +++ b/src/nwp500/cli/token_storage.py @@ -2,7 +2,6 @@ import json import logging -from datetime import datetime from pathlib import Path from typing import Optional @@ -23,24 +22,10 @@ def save_tokens(tokens: AuthTokens, email: str) -> None: """ try: with open(TOKEN_FILE, "w") as f: - json.dump( - { - "email": email, - "id_token": tokens.id_token, - "access_token": tokens.access_token, - "refresh_token": tokens.refresh_token, - "authentication_expires_in": ( - tokens.authentication_expires_in - ), - "issued_at": tokens.issued_at.isoformat(), - # AWS Credentials - "access_key_id": tokens.access_key_id, - "secret_key": tokens.secret_key, - "session_token": tokens.session_token, - "authorization_expires_in": tokens.authorization_expires_in, - }, - f, - ) + # Use the built-in to_dict() method for serialization + token_data = tokens.to_dict() + token_data["email"] = email + json.dump(token_data, f) _logger.info(f"Tokens saved to {TOKEN_FILE}") except OSError as e: _logger.error(f"Failed to save tokens: {e}") @@ -58,24 +43,16 @@ def load_tokens() -> tuple[Optional[AuthTokens], Optional[str]]: try: with open(TOKEN_FILE) as f: data = json.load(f) - email = data["email"] - # Reconstruct the AuthTokens object - tokens = AuthTokens( - id_token=data["id_token"], - access_token=data["access_token"], - refresh_token=data["refresh_token"], - authentication_expires_in=data["authentication_expires_in"], - # AWS Credentials (use .get for backward compatibility) - access_key_id=data.get("access_key_id"), - secret_key=data.get("secret_key"), - session_token=data.get("session_token"), - authorization_expires_in=data.get("authorization_expires_in"), - ) - # Manually set the issued_at from the stored ISO format string - tokens.issued_at = datetime.fromisoformat(data["issued_at"]) + email = data.get("email") + if not email: + _logger.error("No email found in token file") + return None, None + + # Use the built-in from_dict() method for deserialization + tokens = AuthTokens.from_dict(data) _logger.info(f"Tokens loaded from {TOKEN_FILE} for user {email}") return tokens, email - except (OSError, json.JSONDecodeError, KeyError) as e: + except (OSError, json.JSONDecodeError, KeyError, ValueError) as e: _logger.error( f"Failed to load or parse tokens, will re-authenticate: {e}" ) From 2b5e571c6fdd75c6322ac4a0811b0556122a6954 Mon Sep 17 00:00:00 2001 From: Emmanuel Levijarvi Date: Mon, 27 Oct 2025 15:45:44 -0700 Subject: [PATCH 3/4] Fix AuthTokens.from_dict to properly handle empty strings - Replace 'or' operator with explicit None/empty string checks - Add helper function get_value() to correctly check both camelCase and snake_case - Prevent empty strings in camelCase keys from blocking snake_case fallback - Add test for empty string and None value handling This fixes the issue where empty strings would be treated as truthy, preventing the fallback to snake_case alternatives. --- src/nwp500/auth.py | 38 ++++++++++++++++++++++++++------------ tests/test_auth.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 12 deletions(-) diff --git a/src/nwp500/auth.py b/src/nwp500/auth.py index 5b162de..3d4b22e 100644 --- a/src/nwp500/auth.py +++ b/src/nwp500/auth.py @@ -114,20 +114,34 @@ def from_dict(cls, data: dict[str, Any]) -> "AuthTokens": >>> stored = tokens.to_dict() >>> restored = AuthTokens.from_dict(stored) """ + + # Helper to get value from either camelCase or snake_case key + def get_value( + camel_key: str, snake_key: str, default: Any = None + ) -> Any: + """Get value, checking camelCase first, then snake_case.""" + value = data.get(camel_key) + if value is not None and value != "": + return value + value = data.get(snake_key) + if value is not None and value != "": + return value + return default + # Support both camelCase (API) and snake_case (stored) keys return cls( - id_token=data.get("idToken") or data.get("id_token", ""), - access_token=data.get("accessToken") - or data.get("access_token", ""), - refresh_token=data.get("refreshToken") - or data.get("refresh_token", ""), - authentication_expires_in=data.get("authenticationExpiresIn") - or data.get("authentication_expires_in", 3600), - access_key_id=data.get("accessKeyId") or data.get("access_key_id"), - secret_key=data.get("secretKey") or data.get("secret_key"), - session_token=data.get("sessionToken") or data.get("session_token"), - authorization_expires_in=data.get("authorizationExpiresIn") - or data.get("authorization_expires_in"), + id_token=get_value("idToken", "id_token", ""), + access_token=get_value("accessToken", "access_token", ""), + refresh_token=get_value("refreshToken", "refresh_token", ""), + authentication_expires_in=get_value( + "authenticationExpiresIn", "authentication_expires_in", 3600 + ), + access_key_id=get_value("accessKeyId", "access_key_id"), + secret_key=get_value("secretKey", "secret_key"), + session_token=get_value("sessionToken", "session_token"), + authorization_expires_in=get_value( + "authorizationExpiresIn", "authorization_expires_in" + ), issued_at=datetime.fromisoformat(data["issued_at"]) if "issued_at" in data else datetime.now(), diff --git a/tests/test_auth.py b/tests/test_auth.py index 121ae7f..b32ef1e 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -934,6 +934,35 @@ def test_auth_tokens_serialization_roundtrip(): assert restored.is_expired == original.is_expired +def test_auth_tokens_from_dict_with_empty_strings(): + """Test AuthTokens.from_dict handles empty strings in camelCase.""" + # Simulate API response with empty optional fields (camelCase) + # Should fall back to snake_case alternatives + data = { + "idToken": "test_id", + "accessToken": "", # Empty string - should check snake_case + "refreshToken": "test_refresh", + "authenticationExpiresIn": 3600, + "accessKeyId": "", # Empty string - should check snake_case + "secretKey": None, # None - should check snake_case + "sessionToken": "test_session", + # Provide values in snake_case as fallback + "access_token": "fallback_access", + "access_key_id": "fallback_key", + "secret_key": "fallback_secret", + } + + tokens = AuthTokens.from_dict(data) + + assert tokens.id_token == "test_id" + assert tokens.access_token == "fallback_access" # Should use snake_case + assert tokens.refresh_token == "test_refresh" + assert tokens.authentication_expires_in == 3600 + assert tokens.access_key_id == "fallback_key" # Should use snake_case + assert tokens.secret_key == "fallback_secret" # Should use snake_case + assert tokens.session_token == "test_session" + + def test_navien_auth_client_initialization_with_stored_tokens(): """Test NavienAuthClient initialization with stored tokens.""" stored_tokens = AuthTokens( From fe6ee5193bb95e229430110a9dc70240f656417b Mon Sep 17 00:00:00 2001 From: Emmanuel Levijarvi Date: Mon, 27 Oct 2025 15:49:47 -0700 Subject: [PATCH 4/4] Refactor CLI to use proper async with pattern - Replace manual __aenter__/__aexit__ calls with async with context manager - Eliminate error-prone manual context manager lifecycle management - Move authentication logic directly into async_main() - Remove now-unused get_authenticated_client() helper function - Remove unused Optional import Benefits: - Guaranteed proper cleanup even on exceptions - More Pythonic and follows best practices - Simpler control flow with single entry point - Eliminates potential resource leaks from failed __aenter__ calls --- src/nwp500/cli/__main__.py | 286 +++++++++++++++++-------------------- 1 file changed, 130 insertions(+), 156 deletions(-) diff --git a/src/nwp500/cli/__main__.py b/src/nwp500/cli/__main__.py index 3a2cc84..b452cea 100644 --- a/src/nwp500/cli/__main__.py +++ b/src/nwp500/cli/__main__.py @@ -9,7 +9,6 @@ import logging import os import sys -from typing import Optional from nwp500 import NavienAPIClient, NavienAuthClient, __version__ from nwp500.auth import InvalidCredentialsError @@ -39,17 +38,15 @@ _logger = logging.getLogger(__name__) -async def get_authenticated_client( - args: argparse.Namespace, -) -> Optional[NavienAuthClient]: +async def async_main(args: argparse.Namespace) -> int: """ - Get an authenticated NavienAuthClient using cached tokens or credentials. + Asynchronous main function. Args: args: Parsed command-line arguments Returns: - NavienAuthClient instance or None if authentication fails + Exit code (0 for success, 1 for failure) """ # Get credentials email = args.email or os.getenv("NAVIEN_EMAIL") @@ -66,169 +63,146 @@ async def get_authenticated_client( "Credentials not found. Please provide --email and --password, " "or set NAVIEN_EMAIL and NAVIEN_PASSWORD environment variables." ) - return None + return 1 try: - # Use the new stored_tokens parameter for token restoration - # This will automatically handle token validation and refresh - auth_client = NavienAuthClient(email, password, stored_tokens=tokens) - - # Enter the context manager to authenticate/restore session - await auth_client.__aenter__() + # Use async with to properly manage auth client lifecycle + async with NavienAuthClient( + email, password, stored_tokens=tokens + ) as auth_client: + # Save refreshed/new tokens after authentication + if auth_client.current_tokens and auth_client.user_email: + save_tokens(auth_client.current_tokens, auth_client.user_email) + + api_client = NavienAPIClient(auth_client=auth_client) + _logger.info("Fetching device information...") + device = await api_client.get_first_device() + + # Save tokens if they were refreshed during API call + if auth_client.current_tokens and auth_client.user_email: + save_tokens(auth_client.current_tokens, auth_client.user_email) + + if not device: + _logger.error("No devices found for this account.") + return 1 + + _logger.info(f"Found device: {device.device_info.device_name}") + + from nwp500 import NavienMqttClient + + mqtt = NavienMqttClient(auth_client) + try: + await mqtt.connect() + _logger.info("MQTT client connected.") + + # Route to appropriate handler based on arguments + if args.device_info: + await handle_device_info_request(mqtt, device) + elif args.device_feature: + await handle_device_feature_request(mqtt, device) + elif args.get_controller_serial: + await handle_get_controller_serial_request(mqtt, device) + elif args.power_on: + await handle_power_request(mqtt, device, power_on=True) + if args.status: + _logger.info("Getting updated status after power on...") + await asyncio.sleep(2) + await handle_status_request(mqtt, device) + elif args.power_off: + await handle_power_request(mqtt, device, power_on=False) + if args.status: + _logger.info( + "Getting updated status after power off..." + ) + await asyncio.sleep(2) + await handle_status_request(mqtt, device) + elif args.set_mode: + await handle_set_mode_request(mqtt, device, args.set_mode) + if args.status: + _logger.info( + "Getting updated status after mode change..." + ) + await asyncio.sleep(2) + await handle_status_request(mqtt, device) + elif args.set_dhw_temp: + await handle_set_dhw_temp_request( + mqtt, device, args.set_dhw_temp + ) + if args.status: + _logger.info( + "Getting updated status after temperature change..." + ) + await asyncio.sleep(2) + await handle_status_request(mqtt, device) + elif args.get_reservations: + await handle_get_reservations_request(mqtt, device) + elif args.set_reservations: + await handle_update_reservations_request( + mqtt, + device, + args.set_reservations, + args.reservations_enabled, + ) + elif args.get_tou: + await handle_get_tou_request(mqtt, device, api_client) + elif args.set_tou_enabled: + enabled = args.set_tou_enabled.lower() == "on" + await handle_set_tou_enabled_request(mqtt, device, enabled) + if args.status: + _logger.info( + "Getting updated status after TOU change..." + ) + await asyncio.sleep(2) + await handle_status_request(mqtt, device) + elif args.get_energy: + if not args.energy_year or not args.energy_months: + _logger.error( + "--energy-year and --energy-months are required " + "for --get-energy" + ) + return 1 + try: + months = [ + int(m.strip()) + for m in args.energy_months.split(",") + ] + if not all(1 <= m <= 12 for m in months): + _logger.error("Months must be between 1 and 12") + return 1 + except ValueError: + _logger.error( + "Invalid month format. Use comma-separated " + "numbers (e.g., '9' or '8,9,10')" + ) + return 1 + await handle_get_energy_request( + mqtt, device, args.energy_year, months + ) + elif args.status_raw: + await handle_status_raw_request(mqtt, device) + elif args.status: + await handle_status_request(mqtt, device) + else: # Default to monitor + await handle_monitoring(mqtt, device, args.output) - # Save refreshed/new tokens - if auth_client.current_tokens and auth_client.user_email: - save_tokens(auth_client.current_tokens, auth_client.user_email) + except asyncio.CancelledError: + _logger.info("Monitoring stopped by user.") + finally: + _logger.info("Disconnecting MQTT client...") + await mqtt.disconnect() - return auth_client + _logger.info("Cleanup complete.") + return 0 except InvalidCredentialsError: _logger.error("Invalid email or password.") - return None - except Exception as e: - _logger.error( - f"An unexpected error occurred during authentication: {e}" - ) - return None - - -async def async_main(args: argparse.Namespace) -> int: - """ - Asynchronous main function. - - Args: - args: Parsed command-line arguments - - Returns: - Exit code (0 for success, 1 for failure) - """ - auth_client = await get_authenticated_client(args) - if not auth_client: - return 1 # Authentication failed - - api_client = None - try: - api_client = NavienAPIClient(auth_client=auth_client) - _logger.info("Fetching device information...") - device = await api_client.get_first_device() - - # Save tokens if they were refreshed during API call - if auth_client.current_tokens and auth_client.user_email: - save_tokens(auth_client.current_tokens, auth_client.user_email) - - if not device: - _logger.error("No devices found for this account.") - return 1 - - _logger.info(f"Found device: {device.device_info.device_name}") - - from nwp500 import NavienMqttClient - - mqtt = NavienMqttClient(auth_client) - try: - await mqtt.connect() - _logger.info("MQTT client connected.") - - # Route to appropriate handler based on arguments - if args.device_info: - await handle_device_info_request(mqtt, device) - elif args.device_feature: - await handle_device_feature_request(mqtt, device) - elif args.get_controller_serial: - await handle_get_controller_serial_request(mqtt, device) - elif args.power_on: - await handle_power_request(mqtt, device, power_on=True) - if args.status: - _logger.info("Getting updated status after power on...") - await asyncio.sleep(2) - await handle_status_request(mqtt, device) - elif args.power_off: - await handle_power_request(mqtt, device, power_on=False) - if args.status: - _logger.info("Getting updated status after power off...") - await asyncio.sleep(2) - await handle_status_request(mqtt, device) - elif args.set_mode: - await handle_set_mode_request(mqtt, device, args.set_mode) - if args.status: - _logger.info("Getting updated status after mode change...") - await asyncio.sleep(2) - await handle_status_request(mqtt, device) - elif args.set_dhw_temp: - await handle_set_dhw_temp_request( - mqtt, device, args.set_dhw_temp - ) - if args.status: - _logger.info( - "Getting updated status after temperature change..." - ) - await asyncio.sleep(2) - await handle_status_request(mqtt, device) - elif args.get_reservations: - await handle_get_reservations_request(mqtt, device) - elif args.set_reservations: - await handle_update_reservations_request( - mqtt, - device, - args.set_reservations, - args.reservations_enabled, - ) - elif args.get_tou: - await handle_get_tou_request(mqtt, device, api_client) - elif args.set_tou_enabled: - enabled = args.set_tou_enabled.lower() == "on" - await handle_set_tou_enabled_request(mqtt, device, enabled) - if args.status: - _logger.info("Getting updated status after TOU change...") - await asyncio.sleep(2) - await handle_status_request(mqtt, device) - elif args.get_energy: - if not args.energy_year or not args.energy_months: - _logger.error( - "--energy-year and --energy-months are required " - "for --get-energy" - ) - return 1 - try: - months = [ - int(m.strip()) for m in args.energy_months.split(",") - ] - if not all(1 <= m <= 12 for m in months): - _logger.error("Months must be between 1 and 12") - return 1 - except ValueError: - _logger.error( - "Invalid month format. Use comma-separated numbers " - "(e.g., '9' or '8,9,10')" - ) - return 1 - await handle_get_energy_request( - mqtt, device, args.energy_year, months - ) - elif args.status_raw: - await handle_status_raw_request(mqtt, device) - elif args.status: - await handle_status_request(mqtt, device) - else: # Default to monitor - await handle_monitoring(mqtt, device, args.output) - - except asyncio.CancelledError: - _logger.info("Monitoring stopped by user.") - finally: - _logger.info("Disconnecting MQTT client...") - await mqtt.disconnect() + return 1 except asyncio.CancelledError: _logger.info("Operation cancelled by user.") return 1 except Exception as e: _logger.error(f"An unexpected error occurred: {e}", exc_info=True) return 1 - finally: - # Auth client close will close the underlying aiohttp session - # Also call __aexit__ to properly clean up context manager - await auth_client.__aexit__(None, None, None) - _logger.info("Cleanup complete.") - return 0 def parse_args(args: list[str]) -> argparse.Namespace: