diff --git a/src/nwp500/mqtt_connection.py b/src/nwp500/mqtt_connection.py index d149a8b..ad1e515 100644 --- a/src/nwp500/mqtt_connection.py +++ b/src/nwp500/mqtt_connection.py @@ -140,9 +140,23 @@ async def connect(self) -> bool: _logger.info("Establishing MQTT connection...") # Convert concurrent.futures.Future to asyncio.Future and await + # Use shield to prevent cancellation from propagating to + # underlying future if self._connection is not None: connect_future = self._connection.connect() - connect_result = await asyncio.wrap_future(connect_future) + try: + connect_result = await asyncio.shield( + asyncio.wrap_future(connect_future) + ) + except asyncio.CancelledError: + # Shield was cancelled - the underlying connect will + # complete independently, preventing InvalidStateError + # in AWS CRT callbacks + _logger.debug( + "Connect operation was cancelled but will complete " + "in background" + ) + raise else: raise MqttConnectionError("Connection not initialized") @@ -196,8 +210,20 @@ async def disconnect(self) -> None: try: # Convert concurrent.futures.Future to asyncio.Future and await + # Use shield to prevent cancellation from propagating to + # underlying future disconnect_future = self._connection.disconnect() - await asyncio.wrap_future(disconnect_future) + try: + await asyncio.shield(asyncio.wrap_future(disconnect_future)) + except asyncio.CancelledError: + # Shield was cancelled - the underlying disconnect will + # complete independently, preventing InvalidStateError + # in AWS CRT callbacks + _logger.debug( + "Disconnect operation was cancelled but will complete " + "in background" + ) + raise self._connected = False self._connection = None @@ -232,10 +258,22 @@ async def subscribe( _logger.debug(f"Subscribing to topic: {topic}") # Convert concurrent.futures.Future to asyncio.Future and await + # Use shield to prevent cancellation from propagating to + # underlying future subscribe_future, packet_id = self._connection.subscribe( topic=topic, qos=qos, callback=callback ) - await asyncio.wrap_future(subscribe_future) + try: + await asyncio.shield(asyncio.wrap_future(subscribe_future)) + except asyncio.CancelledError: + # Shield was cancelled - the underlying subscribe will + # complete independently, preventing InvalidStateError + # in AWS CRT callbacks + _logger.debug( + f"Subscribe to '{topic}' was cancelled but will complete " + "in background" + ) + raise _logger.info(f"Subscribed to '{topic}' with packet_id {packet_id}") return (subscribe_future, packet_id) @@ -259,10 +297,22 @@ async def unsubscribe(self, topic: str) -> int: _logger.debug(f"Unsubscribing from topic: {topic}") # Convert concurrent.futures.Future to asyncio.Future and await + # Use shield to prevent cancellation from propagating to + # underlying future unsubscribe_future, packet_id = self._connection.unsubscribe( topic=topic ) - await asyncio.wrap_future(unsubscribe_future) + try: + await asyncio.shield(asyncio.wrap_future(unsubscribe_future)) + except asyncio.CancelledError: + # Shield was cancelled - the underlying unsubscribe will + # complete independently, preventing InvalidStateError + # in AWS CRT callbacks + _logger.debug( + f"Unsubscribe from '{topic}' was cancelled but will " + "complete in background" + ) + raise _logger.info(f"Unsubscribed from '{topic}' with packet_id {packet_id}") return int(packet_id) @@ -286,6 +336,7 @@ async def publish( Raises: RuntimeError: If not connected + asyncio.CancelledError: If operation cancelled during disconnect """ if not self._connected or not self._connection: raise MqttNotConnectedError("Not connected to MQTT broker") @@ -303,11 +354,26 @@ async def publish( # Try to JSON encode other types payload_bytes = json.dumps(payload).encode("utf-8") - # Convert concurrent.futures.Future to asyncio.Future and await + # Publish and get the concurrent.futures.Future publish_future, packet_id = self._connection.publish( topic=topic, payload=payload_bytes, qos=qos ) - await asyncio.wrap_future(publish_future) + + # Shield the operation to prevent cancellation from propagating to + # the underlying concurrent.futures.Future. This avoids + # InvalidStateError when AWS CRT tries to set exception on a + # cancelled future. + try: + await asyncio.shield(asyncio.wrap_future(publish_future)) + except asyncio.CancelledError: + # Shield was cancelled - the underlying publish will complete + # independently, preventing InvalidStateError in AWS CRT + # callbacks + _logger.debug( + f"Publish to '{topic}' was cancelled but will complete " + "in background" + ) + raise _logger.debug(f"Published to '{topic}' with packet_id {packet_id}") return int(packet_id) diff --git a/src/nwp500/mqtt_subscriptions.py b/src/nwp500/mqtt_subscriptions.py index 42eb5ab..e4e37a2 100644 --- a/src/nwp500/mqtt_subscriptions.py +++ b/src/nwp500/mqtt_subscriptions.py @@ -214,10 +214,24 @@ async def subscribe( try: # Convert concurrent.futures.Future to asyncio.Future and await + # Use shield to prevent cancellation from propagating to + # underlying future subscribe_future, packet_id = self._connection.subscribe( topic=topic, qos=qos, callback=self._on_message_received ) - subscribe_result = await asyncio.wrap_future(subscribe_future) + try: + subscribe_result = await asyncio.shield( + asyncio.wrap_future(subscribe_future) + ) + except asyncio.CancelledError: + # Shield was cancelled - the underlying subscribe will + # complete independently, preventing InvalidStateError + # in AWS CRT callbacks + _logger.debug( + f"Subscribe to '{redact_topic(topic)}' was cancelled " + "but will complete in background" + ) + raise _logger.info( f"Subscription succeeded (topic redacted) with QoS " @@ -259,8 +273,20 @@ async def unsubscribe(self, topic: str) -> int: try: # Convert concurrent.futures.Future to asyncio.Future and await + # Use shield to prevent cancellation from propagating to + # underlying future unsubscribe_future, packet_id = self._connection.unsubscribe(topic) - await asyncio.wrap_future(unsubscribe_future) + try: + await asyncio.shield(asyncio.wrap_future(unsubscribe_future)) + except asyncio.CancelledError: + # Shield was cancelled - the underlying unsubscribe will + # complete independently, preventing InvalidStateError + # in AWS CRT callbacks + _logger.debug( + f"Unsubscribe from '{redact_topic(topic)}' was " + "cancelled but will complete in background" + ) + raise # Remove from tracking self._subscriptions.pop(topic, None)