diff --git a/roborock/cli.py b/roborock/cli.py index 4532ca21..0e6881eb 100644 --- a/roborock/cli.py +++ b/roborock/cli.py @@ -12,9 +12,9 @@ from pyshark.packet.packet import Packet # type: ignore from roborock import RoborockException -from roborock.containers import DeviceData, HomeDataProduct, LoginData -from roborock.mqtt.roborock_session import create_mqtt_session -from roborock.protocol import MessageParser, create_mqtt_params +from roborock.containers import DeviceData, HomeData, HomeDataProduct, LoginData +from roborock.devices.device_manager import create_device_manager, create_home_data_api +from roborock.protocol import MessageParser from roborock.util import run_sync from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1 from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1 @@ -101,44 +101,25 @@ async def session(ctx, duration: int): context: RoborockContext = ctx.obj login_data = context.login_data() - # Discovery devices if not already available - if not login_data.home_data: - await _discover(ctx) - login_data = context.login_data() - if not login_data.home_data or not login_data.home_data.devices: - raise RoborockException("Unable to discover devices") - - all_devices = login_data.home_data.devices + login_data.home_data.received_devices - click.echo(f"Discovered devices: {', '.join([device.name for device in all_devices])}") - - rriot = login_data.user_data.rriot - params = create_mqtt_params(rriot) - - mqtt_session = await create_mqtt_session(params) - click.echo("Starting MQTT session...") - if not mqtt_session.connected: - raise RoborockException("Failed to connect to MQTT broker") + home_data_api = create_home_data_api(login_data.email, login_data.user_data) - def on_message(bytes: bytes): - """Callback function to handle incoming MQTT messages.""" - # Decode the first 20 bytes of the message for display - bytes = bytes[:20] + async def home_data_cache() -> HomeData: + if login_data.home_data is None: + login_data.home_data = await home_data_api() + context.update(login_data) + return login_data.home_data - click.echo(f"Received message: {bytes}...") + # Create device manager + device_manager = await create_device_manager(login_data.user_data, home_data_cache) - unsubs = [] - for device in all_devices: - device_topic = f"rr/m/o/{rriot.u}/{params.username}/{device.duid}" - unsub = await mqtt_session.subscribe(device_topic, on_message) - unsubs.append(unsub) + devices = await device_manager.get_devices() + click.echo(f"Discovered devices: {', '.join([device.name for device in devices])}") click.echo("MQTT session started. Listening for messages...") await asyncio.sleep(duration) - click.echo("Stopping MQTT session...") - for unsub in unsubs: - unsub() - await mqtt_session.close() + # Close the device manager (this will close all devices and MQTT session) + await device_manager.close() async def _discover(ctx): diff --git a/roborock/devices/device.py b/roborock/devices/device.py index 926be8c4..44cdfd01 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -6,9 +6,13 @@ import enum import logging +from collections.abc import Callable from functools import cached_property from roborock.containers import HomeDataDevice, HomeDataProduct, UserData +from roborock.roborock_message import RoborockMessage + +from .mqtt_channel import MqttChannel _LOGGER = logging.getLogger(__name__) @@ -29,11 +33,25 @@ class DeviceVersion(enum.StrEnum): class RoborockDevice: """Unified Roborock device class with automatic connection setup.""" - def __init__(self, user_data: UserData, device_info: HomeDataDevice, product_info: HomeDataProduct) -> None: - """Initialize the RoborockDevice with device info, user data, and capabilities.""" + def __init__( + self, + user_data: UserData, + device_info: HomeDataDevice, + product_info: HomeDataProduct, + mqtt_channel: MqttChannel, + ) -> None: + """Initialize the RoborockDevice. + + The device takes ownership of the MQTT channel for communication with the device. + Use `connect()` to establish the connection, which will set up the MQTT channel + for receiving messages from the device. Use `close()` to unsubscribe from the MQTT + channel. + """ self._user_data = user_data self._device_info = device_info self._product_info = product_info + self._mqtt_channel = mqtt_channel + self._unsub: Callable[[], None] | None = None @property def duid(self) -> str: @@ -63,3 +81,28 @@ def device_version(self) -> str: self._device_info.name, ) return DeviceVersion.UNKNOWN + + async def connect(self) -> None: + """Connect to the device using MQTT. + + This method will set up the MQTT channel for communication with the device. + """ + if self._unsub: + raise ValueError("Already connected to the device") + self._unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) + + async def close(self) -> None: + """Close the MQTT connection to the device. + + This method will unsubscribe from the MQTT channel and clean up resources. + """ + if self._unsub: + self._unsub() + self._unsub = None + + def _on_mqtt_message(self, message: RoborockMessage) -> None: + """Handle incoming MQTT messages from the device. + + This method should be overridden in subclasses to handle specific device messages. + """ + _LOGGER.debug("Received message from device %s: %s", self.duid, message) diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index 3a95dd13..3244b261 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -1,5 +1,6 @@ """Module for discovering Roborock devices.""" +import asyncio import logging from collections.abc import Awaitable, Callable @@ -10,8 +11,13 @@ UserData, ) from roborock.devices.device import RoborockDevice +from roborock.mqtt.roborock_session import create_mqtt_session +from roborock.mqtt.session import MqttSession +from roborock.protocol import create_mqtt_params from roborock.web_api import RoborockApiClient +from .mqtt_channel import MqttChannel + _LOGGER = logging.getLogger(__name__) __all__ = [ @@ -34,11 +40,16 @@ def __init__( self, home_data_api: HomeDataApi, device_creator: DeviceCreator, + mqtt_session: MqttSession, ) -> None: - """Initialize the DeviceManager with user data and optional cache storage.""" + """Initialize the DeviceManager with user data and optional cache storage. + + This takes ownership of the MQTT session and will close it when the manager is closed. + """ self._home_data_api = home_data_api self._device_creator = device_creator self._devices: dict[str, RoborockDevice] = {} + self._mqtt_session = mqtt_session async def discover_devices(self) -> list[RoborockDevice]: """Discover all devices for the logged-in user.""" @@ -46,9 +57,16 @@ async def discover_devices(self) -> list[RoborockDevice]: device_products = home_data.device_products _LOGGER.debug("Discovered %d devices %s", len(device_products), home_data) - self._devices = { - duid: self._device_creator(device, product) for duid, (device, product) in device_products.items() - } + # These are connected serially to avoid overwhelming the MQTT broker + new_devices = {} + for duid, (device, product) in device_products.items(): + if duid in self._devices: + continue + new_device = self._device_creator(device, product) + await new_device.connect() + new_devices[duid] = new_device + + self._devices.update(new_devices) return list(self._devices.values()) async def get_device(self, duid: str) -> RoborockDevice | None: @@ -59,6 +77,13 @@ async def get_devices(self) -> list[RoborockDevice]: """Get all discovered devices.""" return list(self._devices.values()) + async def close(self) -> None: + """Close all MQTT connections and clean up resources.""" + tasks = [device.close() for device in self._devices.values()] + self._devices.clear() + tasks.append(self._mqtt_session.close()) + await asyncio.gather(*tasks) + def create_home_data_api(email: str, user_data: UserData) -> HomeDataApi: """Create a home data API wrapper. @@ -67,7 +92,9 @@ def create_home_data_api(email: str, user_data: UserData) -> HomeDataApi: home data for the user. """ - client = RoborockApiClient(email, user_data) + # Note: This will auto discover the API base URL. This can be improved + # by caching this next to `UserData` if needed to avoid unnecessary API calls. + client = RoborockApiClient(email) async def home_data_api() -> HomeData: return await client.get_home_data(user_data) @@ -83,9 +110,13 @@ async def create_device_manager(user_data: UserData, home_data_api: HomeDataApi) include caching or other optimizations. """ + mqtt_params = create_mqtt_params(user_data.rriot) + mqtt_session = await create_mqtt_session(mqtt_params) + def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice: - return RoborockDevice(user_data, device, product) + mqtt_channel = MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params) + return RoborockDevice(user_data, device, product, mqtt_channel) - manager = DeviceManager(home_data_api, device_creator) + manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session) await manager.discover_devices() return manager diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py new file mode 100644 index 00000000..00a01210 --- /dev/null +++ b/roborock/devices/mqtt_channel.py @@ -0,0 +1,115 @@ +"""Modules for communicating with specific Roborock devices over MQTT.""" + +import asyncio +import logging +from collections.abc import Callable +from json import JSONDecodeError + +from roborock.containers import RRiot +from roborock.exceptions import RoborockException +from roborock.mqtt.session import MqttParams, MqttSession +from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder +from roborock.roborock_message import RoborockMessage + +_LOGGER = logging.getLogger(__name__) + + +class MqttChannel: + """Simple RPC-style channel for communicating with a device over MQTT. + + Handles request/response correlation and timeouts, but leaves message + format most parsing to higher-level components. + """ + + def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot: RRiot, mqtt_params: MqttParams): + self._mqtt_session = mqtt_session + self._duid = duid + self._local_key = local_key + self._rriot = rriot + self._mqtt_params = mqtt_params + + # RPC support + self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {} + self._decoder = create_mqtt_decoder(local_key) + self._encoder = create_mqtt_encoder(local_key) + self._queue_lock = asyncio.Lock() + + @property + def _publish_topic(self) -> str: + """Topic to send commands to the device.""" + return f"rr/m/i/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}" + + @property + def _subscribe_topic(self) -> str: + """Topic to receive responses from the device.""" + return f"rr/m/o/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}" + + async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]: + """Subscribe to the device's response topic. + + The callback will be called with the message payload when a message is received. + + All messages received will be processed through the provided callback, even + those sent in response to the `send_command` command. + + Returns a callable that can be used to unsubscribe from the topic. + """ + + def message_handler(payload: bytes) -> None: + if not (messages := self._decoder(payload)): + _LOGGER.warning("Failed to decode MQTT message: %s", payload) + return + for message in messages: + _LOGGER.debug("Received message: %s", message) + asyncio.create_task(self._resolve_future_with_lock(message)) + try: + callback(message) + except Exception as e: + _LOGGER.exception("Uncaught error in message handler callback: %s", e) + + return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler) + + async def _resolve_future_with_lock(self, message: RoborockMessage) -> None: + """Resolve waiting future with proper locking.""" + if (request_id := message.get_request_id()) is None: + _LOGGER.debug("Received message with no request_id") + return + async with self._queue_lock: + if (future := self._waiting_queue.pop(request_id, None)) is not None: + future.set_result(message) + else: + _LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id) + + async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage: + """Send a command message and wait for the response message. + + Returns the raw response message - caller is responsible for parsing. + """ + try: + if (request_id := message.get_request_id()) is None: + raise RoborockException("Message must have a request_id for RPC calls") + except (ValueError, JSONDecodeError) as err: + _LOGGER.exception("Error getting request_id from message: %s", err) + raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err + + future: asyncio.Future[RoborockMessage] = asyncio.Future() + async with self._queue_lock: + if request_id in self._waiting_queue: + raise RoborockException(f"Request ID {request_id} already pending, cannot send command") + self._waiting_queue[request_id] = future + + try: + encoded_msg = self._encoder(message) + await self._mqtt_session.publish(self._publish_topic, encoded_msg) + + return await asyncio.wait_for(future, timeout=timeout) + + except asyncio.TimeoutError as ex: + async with self._queue_lock: + self._waiting_queue.pop(request_id, None) + raise RoborockException(f"Command timed out after {timeout}s") from ex + except Exception: + logging.exception("Uncaught error sending command") + async with self._queue_lock: + self._waiting_queue.pop(request_id, None) + raise diff --git a/tests/devices/test_device.py b/tests/devices/test_device.py new file mode 100644 index 00000000..6e2b5d5f --- /dev/null +++ b/tests/devices/test_device.py @@ -0,0 +1,40 @@ +"""Tests for the Device class.""" + +from unittest.mock import AsyncMock, Mock + +from roborock.containers import HomeData, UserData +from roborock.devices.device import DeviceVersion, RoborockDevice + +from .. import mock_data + +USER_DATA = UserData.from_dict(mock_data.USER_DATA) +HOME_DATA = HomeData.from_dict(mock_data.HOME_DATA_RAW) + + +async def test_device_connection() -> None: + """Test the Device connection setup.""" + + unsub = Mock() + subscribe = AsyncMock() + subscribe.return_value = unsub + mqtt_channel = AsyncMock() + mqtt_channel.subscribe = subscribe + + device = RoborockDevice( + USER_DATA, + device_info=HOME_DATA.devices[0], + product_info=HOME_DATA.products[0], + mqtt_channel=mqtt_channel, + ) + assert device.duid == "abc123" + assert device.name == "Roborock S7 MaxV" + assert device.device_version == DeviceVersion.V1 + + assert not subscribe.called + + await device.connect() + assert subscribe.called + assert not unsub.called + + await device.close() + assert unsub.called diff --git a/tests/devices/test_device_manager.py b/tests/devices/test_device_manager.py index fac33344..e09087f5 100644 --- a/tests/devices/test_device_manager.py +++ b/tests/devices/test_device_manager.py @@ -1,5 +1,6 @@ """Tests for the DeviceManager class.""" +from collections.abc import Generator from unittest.mock import patch import pytest @@ -14,6 +15,13 @@ USER_DATA = UserData.from_dict(mock_data.USER_DATA) +@pytest.fixture(autouse=True) +def setup_mqtt_session() -> Generator[None, None, None]: + """Fixture to set up the MQTT session for the tests.""" + with patch("roborock.devices.device_manager.create_mqtt_session"): + yield + + async def home_home_data_no_devices() -> HomeData: """Mock home data API that returns no devices.""" return HomeData( @@ -52,12 +60,15 @@ async def test_with_device() -> None: assert device.name == "Roborock S7 MaxV" assert device.device_version == DeviceVersion.V1 + await device_manager.close() + async def test_get_non_existent_device() -> None: """Test getting a non-existent device.""" device_manager = await create_device_manager(USER_DATA, mock_home_data) device = await device_manager.get_device("non_existent_duid") assert device is None + await device_manager.close() async def test_home_data_api_exception() -> None: diff --git a/tests/devices/test_mqtt_channel.py b/tests/devices/test_mqtt_channel.py new file mode 100644 index 00000000..8efa5664 --- /dev/null +++ b/tests/devices/test_mqtt_channel.py @@ -0,0 +1,277 @@ +"""Tests for the MqttChannel class.""" + +import asyncio +import json +from collections.abc import Callable, Generator +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from roborock.containers import HomeData, UserData +from roborock.devices.mqtt_channel import MqttChannel +from roborock.exceptions import RoborockException +from roborock.mqtt.session import MqttParams +from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder +from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol + +from .. import mock_data + +USER_DATA = UserData.from_dict(mock_data.USER_DATA) +TEST_MQTT_PARAMS = MqttParams( + host="localhost", + port=1883, + tls=False, + username="username", + password="password", + timeout=10.0, +) +TEST_LOCAL_KEY = "local_key" + +TEST_REQUEST = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + payload=json.dumps({"dps": {"101": json.dumps({"id": 12345, "method": "get_status"})}}).encode(), +) +TEST_RESPONSE = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=json.dumps({"dps": {"102": json.dumps({"id": 12345, "result": {"state": "cleaning"}})}}).encode(), +) +TEST_REQUEST2 = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + payload=json.dumps({"dps": {"101": json.dumps({"id": 54321, "method": "get_status"})}}).encode(), +) +TEST_RESPONSE2 = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=json.dumps({"dps": {"102": json.dumps({"id": 54321, "result": {"state": "cleaning"}})}}).encode(), +) +ENCODER = create_mqtt_encoder(TEST_LOCAL_KEY) +DECODER = create_mqtt_decoder(TEST_LOCAL_KEY) + + +@pytest.fixture(name="mqtt_session", autouse=True) +def setup_mqtt_session() -> Generator[Mock, None, None]: + """Fixture to set up the MQTT session for the tests.""" + mock_session = AsyncMock() + with patch("roborock.devices.device_manager.create_mqtt_session", return_value=mock_session): + yield mock_session + + +@pytest.fixture(name="mqtt_channel", autouse=True) +def setup_mqtt_channel(mqtt_session: Mock) -> MqttChannel: + """Fixture to set up the MQTT channel for the tests.""" + return MqttChannel( + mqtt_session, duid="abc123", local_key=TEST_LOCAL_KEY, rriot=USER_DATA.rriot, mqtt_params=TEST_MQTT_PARAMS + ) + + +@pytest.fixture(name="received_messages", autouse=True) +async def setup_subscribe_callback(mqtt_channel: MqttChannel) -> list[RoborockMessage]: + """Fixture to record messages received by the subscriber.""" + messages: list[RoborockMessage] = [] + await mqtt_channel.subscribe(messages.append) + return messages + + +@pytest.fixture(name="mqtt_message_handler") +async def setup_message_handler(mqtt_session: Mock, mqtt_channel: MqttChannel) -> Callable[[bytes], None]: + """Fixture to allow simulating incoming MQTT messages.""" + # Subscribe to set up message handling. We grab the message handler callback + # and use it to simulate receiving a response. + assert mqtt_session.subscribe + subscribe_call_args = mqtt_session.subscribe.call_args + message_handler = subscribe_call_args[0][1] + return message_handler + + +async def home_home_data_no_devices() -> HomeData: + """Mock home data API that returns no devices.""" + return HomeData( + id=1, + name="Test Home", + devices=[], + products=[], + ) + + +async def mock_home_data() -> HomeData: + """Mock home data API that returns devices.""" + return HomeData.from_dict(mock_data.HOME_DATA_RAW) + + +async def test_mqtt_channel(mqtt_session: Mock, mqtt_channel: MqttChannel) -> None: + """Test MQTT channel setup.""" + + unsub = Mock() + mqtt_session.subscribe.return_value = unsub + + callback = Mock() + result = await mqtt_channel.subscribe(callback) + + assert mqtt_session.subscribe.called + assert mqtt_session.subscribe.call_args[0][0] == "rr/m/o/user123/username/abc123" + + assert result == unsub + + +async def test_send_command_success( + mqtt_session: Mock, + mqtt_channel: MqttChannel, + mqtt_message_handler: Callable[[bytes], None], +) -> None: + """Test successful RPC command sending and response handling.""" + # Send a test request. We use a task so we can simulate receiving the response + # while the command is still being processed. + command_task = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST)) + await asyncio.sleep(0.01) # yield + + # Simulate receiving the response message via MQTT + mqtt_message_handler(ENCODER(TEST_RESPONSE)) + await asyncio.sleep(0.01) # yield + + # Get the result + result = await command_task + + # Verify the command was sent + assert mqtt_session.publish.called + assert mqtt_session.publish.call_args[0][0] == "rr/m/i/user123/username/abc123" + raw_sent_msg = mqtt_session.publish.call_args[0][1] # == b"encoded_message" + decoded_message = next(iter(DECODER(raw_sent_msg))) + assert decoded_message == TEST_REQUEST + assert decoded_message.protocol == RoborockMessageProtocol.RPC_REQUEST + assert decoded_message.get_request_id() == 12345 + + # Verify we got the response message back + assert result == TEST_RESPONSE + + +async def test_send_command_without_request_id( + mqtt_session: Mock, + mqtt_channel: MqttChannel, + mqtt_message_handler: Callable[[bytes], None], +) -> None: + """Test sending command without request ID raises exception.""" + # Create a message without request ID + test_message = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + payload=b"no_request_id", + ) + + with pytest.raises(RoborockException, match="Message must have a request_id"): + await mqtt_channel.send_command(test_message) + + +async def test_concurrent_commands( + mqtt_session: Mock, + mqtt_channel: MqttChannel, + mqtt_message_handler: Callable[[bytes], None], + caplog: pytest.LogCaptureFixture, +) -> None: + """Test handling multiple concurrent RPC commands.""" + + # Create multiple test messages with different request IDs + # Start both commands concurrently + task1 = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) + task2 = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST2, timeout=5.0)) + await asyncio.sleep(0.01) # yield + + # Create responses for both + mqtt_message_handler(ENCODER(TEST_RESPONSE)) + await asyncio.sleep(0.01) # yield + + mqtt_message_handler(ENCODER(TEST_RESPONSE2)) + await asyncio.sleep(0.01) # yield + + # Both should complete successfully + result1 = await task1 + result2 = await task2 + + assert result1 == TEST_RESPONSE + assert result2 == TEST_RESPONSE2 + + assert not caplog.records + + +async def test_concurrent_commands_same_request_id( + mqtt_session: Mock, + mqtt_channel: MqttChannel, + mqtt_message_handler: Callable[[bytes], None], +) -> None: + """Test that we are not allowed to send two commands with the same request id.""" + + # Create multiple test messages with different request IDs + # Start both commands concurrently + task1 = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) + task2 = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) + await asyncio.sleep(0.01) # yield + + # Create response + mqtt_message_handler(ENCODER(TEST_RESPONSE)) + await asyncio.sleep(0.01) # yield + + # Both should complete successfully + result1 = await task1 + assert result1 == TEST_RESPONSE + + with pytest.raises(RoborockException, match="Request ID 12345 already pending, cannot send command"): + await task2 + + +async def test_handle_completed_future( + mqtt_session: Mock, + mqtt_channel: MqttChannel, + mqtt_message_handler: Callable[[bytes], None], + caplog: pytest.LogCaptureFixture, +) -> None: + """Test handling response for an already completed future.""" + # Send request + task = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) + await asyncio.sleep(0.01) # yield + + # Send the response twice + mqtt_message_handler(ENCODER(TEST_RESPONSE)) + await asyncio.sleep(0.01) # yield + mqtt_message_handler(ENCODER(TEST_RESPONSE)) + await asyncio.sleep(0.01) # yield + + # Task completes and second message is not associated with a waiting handler + result = await task + assert result == TEST_RESPONSE + + +async def test_subscribe_callback_with_rpc_response( + mqtt_session: Mock, + mqtt_channel: MqttChannel, + received_messages: list[RoborockMessage], + mqtt_message_handler: Callable[[bytes], None], +) -> None: + """Test that subscribe callback is called independent of RPC handling.""" + # Send request + task = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) + await asyncio.sleep(0.01) # yield + + assert not received_messages + + # Send the response for this command and an unrelated command + mqtt_message_handler(ENCODER(TEST_RESPONSE)) + await asyncio.sleep(0.01) # yield + mqtt_message_handler(ENCODER(TEST_RESPONSE2)) + await asyncio.sleep(0.01) # yield + + # Task completes + result = await task + assert result == TEST_RESPONSE + + # The subscribe callback should have been called with the same response + assert received_messages == [TEST_RESPONSE, TEST_RESPONSE2] + + +async def test_message_decode_error( + mqtt_message_handler: Callable[[bytes], None], + caplog: pytest.LogCaptureFixture, +) -> None: + """Test an error during message decoding.""" + mqtt_message_handler(b"invalid_payload") + await asyncio.sleep(0.01) # yield + + assert len(caplog.records) == 1 + assert caplog.records[0].levelname == "WARNING" + assert "Failed to decode MQTT message" in caplog.records[0].message