diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 280714b..285532d 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -43,6 +43,14 @@ Report the results of these checks in your final summary. - **MQTT topics**: `cmd/{deviceType}/{deviceId}/ctrl` for control, `cmd/{deviceType}/{deviceId}/st` for status - **Command queuing**: Commands sent while disconnected are queued and sent when reconnected - **No base64 encoding/decoding** of MQTT payloads; all payloads are JSON-encoded/decoded +- **Exception handling**: Use specific exception types instead of catch-all `except Exception`. Common types: + - `AwsCrtError` - AWS IoT Core/MQTT errors + - `AuthenticationError`, `TokenRefreshError` - Authentication errors + - `RuntimeError` - Runtime state errors (not connected, etc.) + - `ValueError` - Invalid values or parameters + - `TypeError`, `AttributeError`, `KeyError` - Data structure errors + - `asyncio.CancelledError` - Task cancellation + - Only catch exceptions you can handle; let unexpected exceptions propagate ## Integration Points - **AWS IoT Core**: MQTT client uses `awscrt` and `awsiot` libraries for connection and messaging diff --git a/src/nwp500/auth.py b/src/nwp500/auth.py index 78d101b..51f165c 100644 --- a/src/nwp500/auth.py +++ b/src/nwp500/auth.py @@ -471,6 +471,34 @@ async def refresh_token(self, refresh_token: str) -> AuthTokens: _logger.error(f"Failed to parse refresh response: {e}") raise TokenRefreshError(f"Invalid response format: {str(e)}") + async def re_authenticate(self) -> AuthenticationResponse: + """ + Re-authenticate using stored credentials. + + This is a convenience method that uses the stored user_id and password + from initialization to perform a fresh sign-in. Useful for recovering + from expired tokens or connection issues. + + Returns: + AuthenticationResponse with fresh tokens and user info + + Raises: + ValueError: If stored credentials are not available + AuthenticationError: If authentication fails + + Example: + >>> client = NavienAuthClient(email, password) + >>> await client.re_authenticate() # Uses stored credentials + """ + if not self.has_stored_credentials: + raise ValueError( + "No stored credentials available for re-authentication. " + "Credentials must be provided during initialization." + ) + + _logger.info("Re-authenticating with stored credentials") + return await self.sign_in(self._user_id, self._password) + async def ensure_valid_token(self) -> Optional[AuthTokens]: """ Ensure we have a valid access token, refreshing if necessary. @@ -526,6 +554,15 @@ def user_email(self) -> Optional[str]: """Get the email address of the authenticated user.""" return self._user_email + @property + def has_stored_credentials(self) -> bool: + """Check if user credentials are stored for re-authentication. + + Returns: + True if both user_id and password are available for re-auth + """ + return bool(self._user_id and self._password) + async def close(self) -> None: """Close the aiohttp session if we own it.""" if self._owned_session and self._session: diff --git a/src/nwp500/events.py b/src/nwp500/events.py index 11cdbc1..b2e1b3b 100644 --- a/src/nwp500/events.py +++ b/src/nwp500/events.py @@ -253,6 +253,11 @@ async def emit(self, event: str, *args: Any, **kwargs: Any) -> int: self._once_callbacks.discard((event, listener.callback)) except Exception as e: + # Catch all exceptions from user callbacks to ensure + # resilience. We intentionally catch Exception here because: + # 1. User callbacks can raise any exception type + # 2. One bad callback shouldn't break other callbacks + # 3. This is an event emitter pattern where resilience is key _logger.error( f"Error in '{event}' event handler: {e}", exc_info=True, diff --git a/src/nwp500/mqtt_client.py b/src/nwp500/mqtt_client.py index 0fbf7c0..8f69e21 100644 --- a/src/nwp500/mqtt_client.py +++ b/src/nwp500/mqtt_client.py @@ -19,7 +19,11 @@ from awscrt import mqtt from awscrt.exceptions import AwsCrtError -from .auth import NavienAuthClient +from .auth import ( + AuthenticationError, + NavienAuthClient, + TokenRefreshError, +) from .events import EventEmitter from .models import ( Device, @@ -205,7 +209,8 @@ def _schedule_coroutine(self, coro: Any) -> None: # Schedule the coroutine in the stored loop using thread-safe method try: asyncio.run_coroutine_threadsafe(coro, self._loop) - except Exception as e: + except RuntimeError as e: + # Event loop is closed or not running _logger.error(f"Failed to schedule coroutine: {e}", exc_info=True) def _on_connection_interrupted_internal( @@ -218,7 +223,6 @@ def _on_connection_interrupted_internal( error: Error that caused the interruption **kwargs: Forward-compatibility kwargs from AWS SDK """ - _logger.warning(f"Connection interrupted: {error}") self._connected = False # Emit event @@ -232,7 +236,7 @@ def _on_connection_interrupted_internal( # Fallback for callbacks expecting no arguments try: self._on_connection_interrupted() # type: ignore - except Exception as e: + except (TypeError, AttributeError) as e: _logger.error( f"Error in connection_interrupted callback: {e}" ) @@ -339,12 +343,113 @@ async def _active_reconnect(self) -> None: "No connection manager available for reconnection" ) - except Exception as e: + except (AwsCrtError, AuthenticationError, RuntimeError) as e: _logger.error( f"Error during active reconnection: {e}", exc_info=True ) raise + async def _deep_reconnect(self) -> None: + """ + Perform a deep reconnection by completely rebuilding the connection. + + This method is called after multiple quick reconnection failures. + It performs a full teardown and rebuild: + - Disconnects existing connection + - Refreshes authentication tokens + - Creates new connection manager + - Re-establishes all subscriptions + + This is more expensive but can recover from issues that a simple + reconnection cannot fix (e.g., stale credentials, corrupted state). + """ + if self._connected: + _logger.debug("Already connected, skipping deep reconnection") + return + + _logger.warning( + "Performing deep reconnection (full rebuild)... " + "This may take longer." + ) + + try: + # Step 1: Clean up existing connection if any + if self._connection_manager: + _logger.debug("Cleaning up old connection...") + try: + if self._connection_manager.is_connected: + await self._connection_manager.disconnect() + except (AwsCrtError, RuntimeError) as e: + # Expected: connection already dead or in bad state + _logger.debug(f"Error during cleanup: {e} (expected)") + + # Step 2: Force token refresh to get fresh AWS credentials + _logger.debug("Refreshing authentication tokens...") + try: + # Use the stored refresh token from current tokens + current_tokens = self._auth_client.current_tokens + if current_tokens and current_tokens.refresh_token: + await self._auth_client.refresh_token( + current_tokens.refresh_token + ) + else: + _logger.warning("No refresh token available") + raise ValueError("No refresh token available for refresh") + except (TokenRefreshError, ValueError, AuthenticationError) as e: + # If refresh fails, try full re-authentication with stored + # credentials + if self._auth_client.has_stored_credentials: + _logger.warning( + f"Token refresh failed: {e}. Attempting full " + "re-authentication..." + ) + await self._auth_client.re_authenticate() + else: + _logger.error( + "Cannot re-authenticate: no stored credentials" + ) + raise + + # Step 3: Create completely new connection manager + _logger.debug("Creating new connection manager...") + self._connection_manager = MqttConnection( + config=self.config, + auth_client=self._auth_client, + on_connection_interrupted=self._on_connection_interrupted_internal, + on_connection_resumed=self._on_connection_resumed_internal, + ) + + # Step 4: Attempt connection + success = await self._connection_manager.connect() + + if success: + # Update connection references + self._connection = self._connection_manager.connection + self._connected = True + + # Step 5: Re-establish subscriptions + if self._subscription_manager and self._connection: + _logger.debug("Re-establishing subscriptions...") + self._subscription_manager.update_connection( + self._connection + ) + await self._subscription_manager.resubscribe_all() + + _logger.info( + "Deep reconnection successful - fully rebuilt connection" + ) + else: + _logger.error("Deep reconnection failed to connect") + + except ( + AwsCrtError, + AuthenticationError, + RuntimeError, + ValueError, + ) as e: + _logger.error(f"Error during deep reconnection: {e}", exc_info=True) + raise + async def connect(self) -> bool: """ Establish connection to AWS IoT Core. @@ -394,6 +499,7 @@ async def connect(self) -> bool: is_connected_func=lambda: self._connected, schedule_coroutine_func=self._schedule_coroutine, reconnect_func=self._active_reconnect, + deep_reconnect_func=self._deep_reconnect, emit_event_func=self.emit, ) self._reconnection_handler.enable() @@ -428,7 +534,12 @@ async def connect(self) -> bool: return False - except Exception as e: + except ( + AwsCrtError, + AuthenticationError, + RuntimeError, + ValueError, + ) as e: _logger.error(f"Failed to connect: {e}") raise @@ -473,7 +584,7 @@ async def disconnect(self) -> None: self._connection = None _logger.info("Disconnected successfully") - except Exception as e: + except (AwsCrtError, RuntimeError) as e: _logger.error(f"Error during disconnect: {e}") raise @@ -493,7 +604,7 @@ def _on_message_received( except json.JSONDecodeError as e: _logger.error(f"Failed to parse message payload: {e}") - except Exception as e: + except (AttributeError, KeyError, TypeError) as e: _logger.error(f"Error processing message: {e}") def _topic_matches_pattern(self, topic: str, pattern: str) -> bool: @@ -618,12 +729,11 @@ async def publish( try: return await self._connection_manager.publish(topic, payload, qos) - except Exception as e: + except AwsCrtError as e: # Handle clean session cancellation gracefully - # Check exception type and name attribute for proper - # error identification + # Safely check e.name attribute (may not exist or be None) if ( - isinstance(e, AwsCrtError) + hasattr(e, "name") and e.name == "AWS_ERROR_MQTT_CANCELLED_FOR_CLEAN_SESSION" ): _logger.warning( @@ -641,9 +751,9 @@ async def publish( raise RuntimeError( "Publish cancelled due to clean session and " "command queue is disabled" - ) + ) from e - # Note: redact_topic is already used elsewhere in the file + # Other AWS CRT errors _logger.error(f"Failed to publish to topic: {e}") raise diff --git a/src/nwp500/mqtt_connection.py b/src/nwp500/mqtt_connection.py index 00b4c16..9eff060 100644 --- a/src/nwp500/mqtt_connection.py +++ b/src/nwp500/mqtt_connection.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Union from awscrt import mqtt +from awscrt.exceptions import AwsCrtError from awsiot import mqtt_connection_builder if TYPE_CHECKING: @@ -147,7 +148,7 @@ async def connect(self) -> bool: return True - except Exception as e: + except (AwsCrtError, RuntimeError, ValueError) as e: _logger.error(f"Failed to connect: {e}") raise @@ -195,7 +196,7 @@ async def disconnect(self) -> None: self._connected = False self._connection = None _logger.info("Disconnected successfully") - except Exception as e: + except (AwsCrtError, RuntimeError) as e: _logger.error(f"Error during disconnect: {e}") raise diff --git a/src/nwp500/mqtt_periodic.py b/src/nwp500/mqtt_periodic.py index c52958f..7926048 100644 --- a/src/nwp500/mqtt_periodic.py +++ b/src/nwp500/mqtt_periodic.py @@ -186,13 +186,13 @@ async def periodic_request() -> None: f"for {redacted_device_id}" ) break - except Exception as e: + except (AwsCrtError, RuntimeError) as e: # Handle clean session cancellation gracefully (expected # during reconnection) - # Check exception type and name attribute for proper error - # identification + # Safely check exception name attribute if ( isinstance(e, AwsCrtError) + and hasattr(e, "name") and e.name == "AWS_ERROR_MQTT_CANCELLED_FOR_CLEAN_SESSION" ): diff --git a/src/nwp500/mqtt_reconnection.py b/src/nwp500/mqtt_reconnection.py index f223a5d..efde312 100644 --- a/src/nwp500/mqtt_reconnection.py +++ b/src/nwp500/mqtt_reconnection.py @@ -11,6 +11,8 @@ from collections.abc import Awaitable from typing import TYPE_CHECKING, Any, Callable, Optional +from awscrt.exceptions import AwsCrtError + if TYPE_CHECKING: from .mqtt_utils import MqttConnectionConfig @@ -35,6 +37,7 @@ def __init__( is_connected_func: Callable[[], bool], schedule_coroutine_func: Callable[[Any], None], reconnect_func: Callable[[], Awaitable[None]], + deep_reconnect_func: Optional[Callable[[], Awaitable[None]]] = None, emit_event_func: Optional[Callable[..., Awaitable[Any]]] = None, ): """ @@ -46,6 +49,8 @@ def __init__( schedule_coroutine_func: Function to schedule coroutines from any thread reconnect_func: Async function to trigger active reconnection + deep_reconnect_func: Optional async function to trigger deep + reconnection (full rebuild) emit_event_func: Optional async function to emit events (e.g., EventEmitter.emit) """ @@ -53,6 +58,7 @@ def __init__( self._is_connected_func = is_connected_func self._schedule_coroutine = schedule_coroutine_func self._reconnect_func = reconnect_func + self._deep_reconnect_func = deep_reconnect_func self._emit_event = emit_event_func self._reconnect_attempts = 0 @@ -135,15 +141,38 @@ async def _reconnect_with_backoff(self) -> None: Attempt to reconnect with exponential backoff. This method is called automatically when connection is interrupted - if auto_reconnect is enabled. + if auto_reconnect is enabled. Supports unlimited retries when + max_reconnect_attempts is -1. + + Uses a two-tier strategy: + - Quick reconnects (attempts 1-N): Fast reconnection with existing setup + - Deep reconnects (attempts N+): Full rebuild including token refresh """ + unlimited_retries = self.config.max_reconnect_attempts < 0 + while ( not self._is_connected_func() and not self._manual_disconnect - and self._reconnect_attempts < self.config.max_reconnect_attempts + and ( + unlimited_retries + or self._reconnect_attempts < self.config.max_reconnect_attempts + ) ): self._reconnect_attempts += 1 + # Determine if we should do a deep reconnection + has_deep_reconnect = self._deep_reconnect_func is not None + is_at_threshold = ( + self._reconnect_attempts >= self.config.deep_reconnect_threshold + ) + is_threshold_multiple = ( + self._reconnect_attempts % self.config.deep_reconnect_threshold + == 0 + ) + use_deep_reconnect = ( + has_deep_reconnect and is_at_threshold and is_threshold_multiple + ) + # Calculate delay with exponential backoff delay = min( self.config.initial_reconnect_delay @@ -154,12 +183,21 @@ async def _reconnect_with_backoff(self) -> None: self.config.max_reconnect_delay, ) - _logger.info( - "Reconnection attempt %d/%d in %.1f seconds...", - self._reconnect_attempts, - self.config.max_reconnect_attempts, - delay, - ) + if unlimited_retries: + reconnect_type = "deep" if use_deep_reconnect else "quick" + _logger.info( + "Reconnection attempt %d (%s) in %.1f seconds...", + self._reconnect_attempts, + reconnect_type, + delay, + ) + else: + _logger.info( + "Reconnection attempt %d/%d in %.1f seconds...", + self._reconnect_attempts, + self.config.max_reconnect_attempts, + delay, + ) try: await asyncio.sleep(delay) @@ -171,30 +209,50 @@ async def _reconnect_with_backoff(self) -> None: ) break - # Trigger active reconnection - _logger.info("Triggering active reconnection...") - try: - await self._reconnect_func() - if self._is_connected_func(): - _logger.info("Successfully reconnected") - break - except Exception as e: - _logger.warning( - f"Active reconnection failed: {e}. " - "Will retry if attempts remain." + # Trigger appropriate reconnection type + if use_deep_reconnect and self._deep_reconnect_func is not None: + _logger.info( + "Triggering deep reconnection " + "(full rebuild with token refresh)..." ) + try: + await self._deep_reconnect_func() + if self._is_connected_func(): + _logger.info( + "Successfully reconnected via deep reconnection" + ) + break + except (AwsCrtError, RuntimeError, ValueError) as e: + _logger.warning( + f"Deep reconnection failed: {e}. Will retry..." + ) + else: + _logger.info("Triggering quick reconnection...") + try: + await self._reconnect_func() + if self._is_connected_func(): + _logger.info( + "Successfully reconnected via " + "quick reconnection" + ) + break + except (AwsCrtError, RuntimeError) as e: + _logger.warning( + f"Quick reconnection failed: {e}. Will retry..." + ) except asyncio.CancelledError: _logger.info("Reconnection task cancelled") break - except Exception as e: + except (AwsCrtError, RuntimeError) as e: _logger.error( f"Error during reconnection attempt: {e}", exc_info=True ) - # Check final state + # Check final state (only if not unlimited retries) if ( - self._reconnect_attempts >= self.config.max_reconnect_attempts + not unlimited_retries + and self._reconnect_attempts >= self.config.max_reconnect_attempts and not self._is_connected_func() ): _logger.error( @@ -208,7 +266,7 @@ async def _reconnect_with_backoff(self) -> None: await self._emit_event( "reconnection_failed", self._reconnect_attempts ) - except Exception as e: + except (TypeError, RuntimeError) as e: _logger.error( f"Error emitting reconnection_failed event: {e}" ) diff --git a/src/nwp500/mqtt_subscriptions.py b/src/nwp500/mqtt_subscriptions.py index c74999c..3bf74cf 100644 --- a/src/nwp500/mqtt_subscriptions.py +++ b/src/nwp500/mqtt_subscriptions.py @@ -15,6 +15,7 @@ from typing import Any, Callable, Optional from awscrt import mqtt +from awscrt.exceptions import AwsCrtError from .events import EventEmitter from .models import Device, DeviceFeature, DeviceStatus, EnergyUsageResponse @@ -117,12 +118,12 @@ def _on_message_received( for handler in handlers: try: handler(topic, message) - except Exception as e: + except (TypeError, AttributeError, KeyError) as e: _logger.error(f"Error in message handler: {e}") except json.JSONDecodeError as e: _logger.error(f"Failed to parse message payload: {e}") - except Exception as e: + except (AttributeError, KeyError, TypeError) as e: _logger.error(f"Error processing message: {e}") def _topic_matches_pattern(self, topic: str, pattern: str) -> bool: @@ -230,7 +231,7 @@ async def subscribe( return int(packet_id) - except Exception as e: + except (AwsCrtError, RuntimeError) as e: _logger.error( f"Failed to subscribe to '{redact_topic(topic)}': {e}" ) @@ -268,12 +269,75 @@ async def unsubscribe(self, topic: str) -> int: return int(packet_id) - except Exception as e: + except (AwsCrtError, RuntimeError) as e: _logger.error( f"Failed to unsubscribe from '{redact_topic(topic)}': {e}" ) raise + async def resubscribe_all(self) -> None: + """ + Re-establish all subscriptions after a connection rebuild. + + This method is called after a deep reconnection to restore all + active subscriptions. It uses the stored subscription information + to re-subscribe to all topics with their original QoS settings + and handlers. + + Note: + This is typically called automatically during deep reconnection + and should not need to be called manually. + + Raises: + RuntimeError: If not connected to MQTT broker + Exception: If any subscription fails + """ + if not self._connection: + raise RuntimeError("Not connected to MQTT broker") + + if not self._subscriptions: + _logger.debug("No subscriptions to restore") + return + + subscription_count = len(self._subscriptions) + _logger.info(f"Re-establishing {subscription_count} subscription(s)...") + + # Store subscriptions to re-establish (avoid modifying dict during + # iteration) + subscriptions_to_restore = list(self._subscriptions.items()) + handlers_to_restore = { + topic: handlers.copy() + for topic, handlers in self._message_handlers.items() + } + + # Clear current subscriptions (will be re-added by subscribe()) + self._subscriptions.clear() + self._message_handlers.clear() + + # Re-establish each subscription + failed_subscriptions = set() + for topic, qos in subscriptions_to_restore: + handlers = handlers_to_restore.get(topic, []) + for handler in handlers: + try: + await self.subscribe(topic, handler, qos) + except (AwsCrtError, RuntimeError) as e: + _logger.error( + f"Failed to re-subscribe to " + f"'{redact_topic(topic)}': {e}" + ) + # Mark topic as failed and skip remaining handlers + # since they will fail for the same reason + failed_subscriptions.add(topic) + break # Exit handler loop, move to next topic + + if failed_subscriptions: + _logger.warning( + f"Failed to restore {len(failed_subscriptions)} subscription(s)" + ) + else: + _logger.info("All subscriptions re-established successfully") + async def subscribe_device( self, device: Device, callback: Callable[[str, dict[str, Any]], None] ) -> int: @@ -405,7 +469,7 @@ def status_message_handler(topic: str, message: dict[str, Any]) -> None: _logger.warning( f"Invalid value in status message: {e}", exc_info=True ) - except Exception as e: + except (TypeError, AttributeError) as e: _logger.error( f"Error parsing device status: {e}", exc_info=True ) @@ -492,7 +556,7 @@ async def _detect_state_changes(self, status: DeviceStatus) -> None: await self._event_emitter.emit("error_cleared", prev.errorCode) _logger.info(f"Error cleared: {prev.errorCode}") - except Exception as e: + except (TypeError, AttributeError, RuntimeError) as e: _logger.error(f"Error detecting state changes: {e}", exc_info=True) finally: # Always update previous status @@ -594,7 +658,7 @@ def feature_message_handler( _logger.warning( f"Invalid value in feature message: {e}", exc_info=True ) - except Exception as e: + except (TypeError, AttributeError) as e: _logger.error( f"Error parsing device feature: {e}", exc_info=True ) @@ -684,7 +748,7 @@ def energy_message_handler(topic: str, message: dict[str, Any]) -> None: _logger.warning( "Failed to parse energy usage message - missing key: %s", e ) - except Exception as e: + except (TypeError, ValueError, AttributeError) as e: _logger.error( "Error in energy usage message handler: %s", e, diff --git a/src/nwp500/mqtt_utils.py b/src/nwp500/mqtt_utils.py index 72fdc45..3cdc5db 100644 --- a/src/nwp500/mqtt_utils.py +++ b/src/nwp500/mqtt_utils.py @@ -153,9 +153,12 @@ class MqttConnectionConfig: auto_reconnect: Enable automatic reconnection max_reconnect_attempts: Maximum reconnection attempts + (-1 for unlimited) initial_reconnect_delay: Initial delay between reconnect attempts max_reconnect_delay: Maximum delay between reconnect attempts reconnect_backoff_multiplier: Exponential backoff multiplier + deep_reconnect_threshold: Attempt count to trigger full + connection rebuild enable_command_queue: Enable command queueing when disconnected max_queued_commands: Maximum number of queued commands @@ -169,10 +172,13 @@ class MqttConnectionConfig: # Reconnection settings auto_reconnect: bool = True - max_reconnect_attempts: int = 10 + max_reconnect_attempts: int = -1 # -1 = unlimited retries initial_reconnect_delay: float = 1.0 # seconds max_reconnect_delay: float = 120.0 # seconds reconnect_backoff_multiplier: float = 2.0 + deep_reconnect_threshold: int = ( + 10 # Switch to full rebuild after N attempts + ) # Command queue settings enable_command_queue: bool = True