Skip to content
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Basic Usage
if device:
# Access status information
status = device.status
print(f"Water Temperature: {status.dhw_temperature}°F")
print(f"Water Temperature: {status.dhw_temperature}")
print(f"Tank Charge: {status.dhw_charge_per}%")
print(f"Power Consumption: {status.current_inst_power}W")

Expand Down
3 changes: 2 additions & 1 deletion src/nwp500/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def __init__(

self.base_url = base_url.rstrip("/")
self._auth_client = auth_client
self._session = session or getattr(auth_client, "_session", None)
self._session = session or auth_client.session

if self._session is None:
raise ValueError(
"auth_client must have an active session or a session "
Expand Down
5 changes: 5 additions & 0 deletions src/nwp500/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,11 @@ async def ensure_valid_token(self) -> AuthTokens | None:

return tokens

@property
def session(self) -> aiohttp.ClientSession | None:
"""Get the active aiohttp session."""
return self._session

@property
def is_authenticated(self) -> bool:
"""Check if client is currently authenticated."""
Expand Down
1 change: 0 additions & 1 deletion src/nwp500/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,6 @@ class DeviceStatus(NavienBaseModel):
recirc_dhw_flow_rate: FlowRate = Field(
description="Recirculation DHW flow rate (dynamic units: LPM/GPM)",
json_schema_extra={
"unit_of_measurement": "GPM",
"device_class": "flow_rate",
},
)
Expand Down
17 changes: 13 additions & 4 deletions src/nwp500/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ async def _active_reconnect(self) -> None:
self._subscription_manager.update_connection(
self._connection
)
await self._subscription_manager.resubscribe_all()

_logger.info("Active reconnection successful")
else:
Expand Down Expand Up @@ -894,6 +895,16 @@ async def subscribe_device_feature(
"subscribe_device_feature", device, callback
)

async def unsubscribe_device_feature(
self, device: Device, callback: Callable[[DeviceFeature], None]
) -> None:
"""Unsubscribe a specific device feature callback."""
if not self._connected or not self._subscription_manager:
return
await self._subscription_manager.unsubscribe_device_feature(
device, callback
)

async def subscribe_energy_usage(
self,
device: Device,
Expand Down Expand Up @@ -961,10 +972,8 @@ def on_feature(feature: DeviceFeature) -> None:
)
return False
finally:
# Note: We don't unsubscribe token here because it might
# interfere with other subscribers if we're not careful.
# But the subscription manager handles multiple callbacks.
pass
# Unsubscribe using the specific callback to avoid leaking resources
await self.unsubscribe_device_feature(device, on_feature)

@property
def control(self) -> MqttDeviceController:
Expand Down
46 changes: 38 additions & 8 deletions src/nwp500/mqtt/subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import json
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from awscrt import mqtt
from awscrt.exceptions import AwsCrtError
Expand Down Expand Up @@ -124,7 +124,7 @@ def _on_message_received(
try:
# Parse JSON payload
message = json.loads(payload.decode("utf-8"))
_logger.debug("Received message on topic: %s", topic)
_logger.debug("Received message on topic: %s", redact_topic(topic))

# Call registered handlers that match this topic
# Need to match against subscription patterns with wildcards
Expand Down Expand Up @@ -227,7 +227,13 @@ async def unsubscribe(self, topic: str) -> int:
if not self._connection:
raise MqttNotConnectedError("Not connected to MQTT broker")

_logger.info(f"Unsubscribing from topic: {redact_topic(topic)}")
# Redact topic for logging to avoid leaking sensitive information
# (device IDs). We perform this check early to ensure we don't log raw
# topics.
# Note: CodeQL flags log calls using the topic variable (even redacted)
# as a security risk ("Clear-text logging of sensitive information").
# To pass CI, we must use a generic message here.
_logger.info("Unsubscribing from topic (redacted)")

try:
# Convert concurrent.futures.Future to asyncio.Future and await
Expand All @@ -241,7 +247,7 @@ async def unsubscribe(self, topic: str) -> int:
# complete independently, preventing InvalidStateError
# in AWS CRT callbacks
_logger.debug(
f"Unsubscribe from '{redact_topic(topic)}' was "
"Unsubscribe from topic (redacted) was "
"cancelled but will complete in background"
)
raise
Expand All @@ -250,14 +256,12 @@ async def unsubscribe(self, topic: str) -> int:
self._subscriptions.pop(topic, None)
self._message_handlers.pop(topic, None)

_logger.info(f"Unsubscribed from '{topic}'")
_logger.info("Unsubscribed from topic (redacted)")

return int(packet_id)

except (AwsCrtError, RuntimeError) as e:
_logger.error(
f"Failed to unsubscribe from '{redact_topic(topic)}': {e}"
)
_logger.error(f"Failed to unsubscribe from topic (redacted): {e}")
raise

async def resubscribe_all(self) -> None:
Expand Down Expand Up @@ -401,6 +405,7 @@ def handler(topic: str, message: dict[str, Any]) -> None:
f"Error parsing {model.__name__} on {topic}: {e}"
)

cast(Any, handler)._original_callback = callback
return handler

async def _detect_state_changes(self, status: DeviceStatus) -> None:
Expand Down Expand Up @@ -508,6 +513,31 @@ def post_parse(feature: DeviceFeature) -> None:
)
return await self.subscribe_device(device=device, callback=handler)

async def unsubscribe_device_feature(
self, device: Device, callback: Callable[[DeviceFeature], None]
) -> None:
"""Unsubscribe a specific device feature callback."""
device_id = device.device_info.mac_address
device_type = str(device.device_info.device_type)
topic = MqttTopicBuilder.command_topic(device_type, device_id, "#")

if topic not in self._message_handlers:
return

# Find and remove the specific handler
handlers = self._message_handlers[topic]
handlers_to_remove = []
for h in handlers:
if getattr(h, "_original_callback", None) == callback:
handlers_to_remove.append(h)

for h in handlers_to_remove:
handlers.remove(h)

# If no handlers left, unsubscribe from MQTT
if not handlers:
await self.unsubscribe(topic)

async def subscribe_energy_usage(
self,
device: Device,
Expand Down