diff --git a/README.rst b/README.rst index 9bcae41..ab57c14 100644 --- a/README.rst +++ b/README.rst @@ -70,7 +70,7 @@ Basic Usage if device: # Access status information status = device.status - print(f"Water Temperature: {status.dhw_temperature}°F") + print(f"Water Temperature: {status.dhw_temperature}") print(f"Tank Charge: {status.dhw_charge_per}%") print(f"Power Consumption: {status.current_inst_power}W") diff --git a/src/nwp500/api_client.py b/src/nwp500/api_client.py index bec1e5f..2e4a325 100644 --- a/src/nwp500/api_client.py +++ b/src/nwp500/api_client.py @@ -78,7 +78,8 @@ def __init__( self.base_url = base_url.rstrip("/") self._auth_client = auth_client - self._session = session or getattr(auth_client, "_session", None) + self._session = session or auth_client.session + if self._session is None: raise ValueError( "auth_client must have an active session or a session " diff --git a/src/nwp500/auth.py b/src/nwp500/auth.py index 1a93bb7..5a2f228 100644 --- a/src/nwp500/auth.py +++ b/src/nwp500/auth.py @@ -633,6 +633,11 @@ async def ensure_valid_token(self) -> AuthTokens | None: return tokens + @property + def session(self) -> aiohttp.ClientSession | None: + """Get the active aiohttp session.""" + return self._session + @property def is_authenticated(self) -> bool: """Check if client is currently authenticated.""" diff --git a/src/nwp500/models.py b/src/nwp500/models.py index 9c06e98..4cdd8f5 100644 --- a/src/nwp500/models.py +++ b/src/nwp500/models.py @@ -775,7 +775,6 @@ class DeviceStatus(NavienBaseModel): recirc_dhw_flow_rate: FlowRate = Field( description="Recirculation DHW flow rate (dynamic units: LPM/GPM)", json_schema_extra={ - "unit_of_measurement": "GPM", "device_class": "flow_rate", }, ) diff --git a/src/nwp500/mqtt/client.py b/src/nwp500/mqtt/client.py index 701274a..256885e 100644 --- a/src/nwp500/mqtt/client.py +++ b/src/nwp500/mqtt/client.py @@ -375,6 +375,7 @@ async def _active_reconnect(self) -> None: self._subscription_manager.update_connection( self._connection ) + await self._subscription_manager.resubscribe_all() _logger.info("Active reconnection successful") else: @@ -894,6 +895,16 @@ async def subscribe_device_feature( "subscribe_device_feature", device, callback ) + async def unsubscribe_device_feature( + self, device: Device, callback: Callable[[DeviceFeature], None] + ) -> None: + """Unsubscribe a specific device feature callback.""" + if not self._connected or not self._subscription_manager: + return + await self._subscription_manager.unsubscribe_device_feature( + device, callback + ) + async def subscribe_energy_usage( self, device: Device, @@ -961,10 +972,8 @@ def on_feature(feature: DeviceFeature) -> None: ) return False finally: - # Note: We don't unsubscribe token here because it might - # interfere with other subscribers if we're not careful. - # But the subscription manager handles multiple callbacks. - pass + # Unsubscribe using the specific callback to avoid leaking resources + await self.unsubscribe_device_feature(device, on_feature) @property def control(self) -> MqttDeviceController: diff --git a/src/nwp500/mqtt/subscriptions.py b/src/nwp500/mqtt/subscriptions.py index 2c1823a..632cee9 100644 --- a/src/nwp500/mqtt/subscriptions.py +++ b/src/nwp500/mqtt/subscriptions.py @@ -15,7 +15,7 @@ import json import logging from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from awscrt import mqtt from awscrt.exceptions import AwsCrtError @@ -124,7 +124,7 @@ def _on_message_received( try: # Parse JSON payload message = json.loads(payload.decode("utf-8")) - _logger.debug("Received message on topic: %s", topic) + _logger.debug("Received message on topic: %s", redact_topic(topic)) # Call registered handlers that match this topic # Need to match against subscription patterns with wildcards @@ -227,7 +227,13 @@ async def unsubscribe(self, topic: str) -> int: if not self._connection: raise MqttNotConnectedError("Not connected to MQTT broker") - _logger.info(f"Unsubscribing from topic: {redact_topic(topic)}") + # Redact topic for logging to avoid leaking sensitive information + # (device IDs). We perform this check early to ensure we don't log raw + # topics. + # Note: CodeQL flags log calls using the topic variable (even redacted) + # as a security risk ("Clear-text logging of sensitive information"). + # To pass CI, we must use a generic message here. + _logger.info("Unsubscribing from topic (redacted)") try: # Convert concurrent.futures.Future to asyncio.Future and await @@ -241,7 +247,7 @@ async def unsubscribe(self, topic: str) -> int: # complete independently, preventing InvalidStateError # in AWS CRT callbacks _logger.debug( - f"Unsubscribe from '{redact_topic(topic)}' was " + "Unsubscribe from topic (redacted) was " "cancelled but will complete in background" ) raise @@ -250,14 +256,12 @@ async def unsubscribe(self, topic: str) -> int: self._subscriptions.pop(topic, None) self._message_handlers.pop(topic, None) - _logger.info(f"Unsubscribed from '{topic}'") + _logger.info("Unsubscribed from topic (redacted)") return int(packet_id) except (AwsCrtError, RuntimeError) as e: - _logger.error( - f"Failed to unsubscribe from '{redact_topic(topic)}': {e}" - ) + _logger.error(f"Failed to unsubscribe from topic (redacted): {e}") raise async def resubscribe_all(self) -> None: @@ -401,6 +405,7 @@ def handler(topic: str, message: dict[str, Any]) -> None: f"Error parsing {model.__name__} on {topic}: {e}" ) + cast(Any, handler)._original_callback = callback return handler async def _detect_state_changes(self, status: DeviceStatus) -> None: @@ -508,6 +513,31 @@ def post_parse(feature: DeviceFeature) -> None: ) return await self.subscribe_device(device=device, callback=handler) + async def unsubscribe_device_feature( + self, device: Device, callback: Callable[[DeviceFeature], None] + ) -> None: + """Unsubscribe a specific device feature callback.""" + device_id = device.device_info.mac_address + device_type = str(device.device_info.device_type) + topic = MqttTopicBuilder.command_topic(device_type, device_id, "#") + + if topic not in self._message_handlers: + return + + # Find and remove the specific handler + handlers = self._message_handlers[topic] + handlers_to_remove = [] + for h in handlers: + if getattr(h, "_original_callback", None) == callback: + handlers_to_remove.append(h) + + for h in handlers_to_remove: + handlers.remove(h) + + # If no handlers left, unsubscribe from MQTT + if not handlers: + await self.unsubscribe(topic) + async def subscribe_energy_usage( self, device: Device,