diff --git a/src/hiero_sdk_python/query/topic_message_query.py b/src/hiero_sdk_python/query/topic_message_query.py index fca9f9472..ad0269d70 100644 --- a/src/hiero_sdk_python/query/topic_message_query.py +++ b/src/hiero_sdk_python/query/topic_message_query.py @@ -1,10 +1,15 @@ from __future__ import annotations +import logging +import re import threading import time from collections.abc import Callable +from dataclasses import dataclass, field from datetime import datetime +import grpc + from hiero_sdk_python.client.client import Client from hiero_sdk_python.consensus.topic_id import TopicId from hiero_sdk_python.consensus.topic_message import TopicMessage @@ -14,6 +19,19 @@ from hiero_sdk_python.utils.subscription_handle import SubscriptionHandle +logger = logging.getLogger(__name__) + +RST_STREAM = re.compile(r"\brst[^0-9a-zA-Z]stream\b", re.IGNORECASE | re.DOTALL) + + +@dataclass +class SubscriptionState: + attempt: int = 0 + count: int = 0 + last_message: mirror_proto.ConsensusTopicResponse | None = None + pending_messages: dict[TransactionId, list[mirror_proto.ConsensusTopicResponse]] = field(default_factory=dict) + + class TopicMessageQuery: """ A query to subscribe to messages from a specific HCS topic, via a mirror node. @@ -31,48 +49,49 @@ def __init__( chunking_enabled: bool = False, ) -> None: """Initializes a TopicMessageQuery.""" - self._topic_id: TopicId | None = self._parse_topic_id(topic_id) if topic_id else None + self._topic_id: basic_types_pb2.TopicID | None = self._parse_topic_id(topic_id) if topic_id else None self._start_time: timestamp_pb2.Timestamp | None = self._parse_timestamp(start_time) if start_time else None self._end_time: timestamp_pb2.Timestamp | None = self._parse_timestamp(end_time) if end_time else None - self._limit: int | None = limit + self._limit: int = limit if limit is not None else 0 self._chunking_enabled: bool = chunking_enabled - self._completion_handler: Callable[[], None] | None = None self._max_attempts: int = 10 self._max_backoff: float = 8.0 + self._completion_handler: Callable[[], None] | None = self._on_complete + self._error_handler: Callable[[], None] | None = self._on_error + def set_max_attempts(self, attempts: int) -> TopicMessageQuery: """Sets the maximum number of attempts to reconnect on failure.""" + if attempts <= 0: + raise ValueError("max_attempts must be greater than 0") + self._max_attempts = attempts return self def set_max_backoff(self, backoff: float) -> TopicMessageQuery: """Sets the maximum backoff time in seconds for reconnection attempts.""" + if backoff < 0.5: + raise ValueError("max_backoff must be at least 500 ms") + self._max_backoff = backoff return self def set_completion_handler(self, handler: Callable[[], None]) -> TopicMessageQuery: """Sets a completion handler that is called when the subscription completes.""" + if not callable(handler): + raise TypeError("handler must be a callable object") + self._completion_handler = handler return self - def _parse_topic_id(self, topic_id: str | TopicId) -> basic_types_pb2.TopicID: - """Parses a topic ID from a string or TopicId object into a protobuf TopicID.""" - if isinstance(topic_id, str): - parts = topic_id.strip().split(".") - if len(parts) != 3: - raise ValueError(f"Invalid topic ID string: {topic_id}") - shard, realm, topic = map(int, parts) - return basic_types_pb2.TopicID(shardNum=shard, realmNum=realm, topicNum=topic) - if isinstance(topic_id, TopicId): - return topic_id._to_proto() - raise TypeError("Invalid topic_id format. Must be a string or TopicId.") + def set_error_handler(self, handler: Callable[[Exception], None]) -> TopicMessageQuery: + """Sets an error handler that is called when the subscription causes an error.""" + if not callable(handler): + raise TypeError("handler must be a callable object") - def _parse_timestamp(self, dt: datetime) -> timestamp_pb2.Timestamp: - """Converts a datetime object to a protobuf Timestamp.""" - seconds = int(dt.timestamp()) - nanos = int((dt.timestamp() - seconds) * 1e9) - return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) + self._error_handler = handler + return self def set_topic_id(self, topic_id: str | TopicId) -> TopicMessageQuery: """Sets the topic ID for the query.""" @@ -99,6 +118,94 @@ def set_chunking_enabled(self, enabled: bool) -> TopicMessageQuery: self._chunking_enabled = enabled return self + def _on_complete(self) -> None: + logger.info(f"Subscription to topic {self._topic_id} complete") + + def _on_error(self, err: Exception) -> None: + if isinstance(err, grpc.RpcError) and err.code() == grpc.StatusCode.CANCELLED: + logger.warning(f"Call is cancelled for topic {self._topic_id}") + else: + logger.error(f"Error attempting to subscribe to topic {self._topic_id}: {err}") + + def _should_retry(self, err: Exception) -> bool: + if isinstance(err, grpc.RpcError): + return err.code() in ( + grpc.StatusCode.NOT_FOUND, + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.RESOURCE_EXHAUSTED, + ) or (err.code() == grpc.StatusCode.INTERNAL and bool(RST_STREAM.search(err.details()))) + + return True + + def _parse_topic_id(self, topic_id: str | TopicId) -> basic_types_pb2.TopicID: + """Parses a topic ID from a string or TopicId object into a protobuf TopicID.""" + if isinstance(topic_id, str): + topic_id = TopicId.from_string(topic_id) + + if isinstance(topic_id, TopicId): + return topic_id._to_proto() + + raise TypeError("Invalid topic_id format. Must be a string or TopicId.") + + def _parse_timestamp(self, dt: datetime) -> timestamp_pb2.Timestamp: + """Converts a datetime object to a protobuf Timestamp.""" + seconds = int(dt.timestamp()) + nanos = int((dt.timestamp() - seconds) * 1e9) + return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) + + def _build_query_request(self, state: SubscriptionState) -> mirror_proto.ConsensusTopicQuery: + """Build the request object based on current subscription state.""" + request = mirror_proto.ConsensusTopicQuery(topicID=self._topic_id) + + if self._end_time is not None: + request.consensusEndTime.CopyFrom(self._end_time) + + if state.last_message is not None: + last_message_time = state.last_message.consensusTimestamp + + seconds = last_message_time.seconds + nanos = last_message_time.nanos + 1 + + if nanos >= 1_000_000_000: + seconds += 1 + nanos = 0 + + request.consensusStartTime.seconds = seconds + request.consensusStartTime.nanos = nanos + + if self._limit > 0: + request.limit = max(0, self._limit - state.count) + else: + if self._start_time is not None: + request.consensusStartTime.CopyFrom(self._start_time) + request.limit = self._limit + + return request + + def _handle_response(self, response, state: SubscriptionState, on_message: Callable[[TopicMessage], None]) -> None: + """Handles single or chunked messages.""" + state.count += 1 + state.last_message = response + + if not self._chunking_enabled or not response.HasField("chunkInfo") or response.chunkInfo.total <= 1: + message = TopicMessage.of_single(response) + on_message(message) + return + + initial_tx_id = TransactionId._from_proto(response.chunkInfo.initialTransactionID) + + if initial_tx_id not in state.pending_messages: + state.pending_messages[initial_tx_id] = [] + + chunks = state.pending_messages[initial_tx_id] + chunks.append(response) + + if len(chunks) == response.chunkInfo.total: + del state.pending_messages[initial_tx_id] + + message = TopicMessage.of_many(chunks) + on_message(message) + def subscribe( self, client: Client, @@ -111,49 +218,23 @@ def subscribe( if not client.mirror_stub: raise ValueError("Client has no mirror_stub. Did you configure a mirror node address?") - request = mirror_proto.ConsensusTopicQuery(topicID=self._topic_id) - if self._start_time: - request.consensusStartTime.CopyFrom(self._start_time) - if self._end_time: - request.consensusEndTime.CopyFrom(self._end_time) - if self._limit is not None: - request.limit = self._limit - subscription_handle = SubscriptionHandle() - - pending_chunks: dict[str, list[mirror_proto.ConsensusTopicResponse]] = {} + state = SubscriptionState() def run_stream(): - attempt = 0 - while attempt < self._max_attempts and not subscription_handle.is_cancelled(): + while state.attempt < self._max_attempts and not subscription_handle.is_cancelled(): + state.attempt += 1 + request = self._build_query_request(state) + try: message_stream = client.mirror_stub.subscribeTopic(request) + subscription_handle._set_call(message_stream) for response in message_stream: if subscription_handle.is_cancelled(): return - if ( - not self._chunking_enabled - or not response.HasField("chunkInfo") - or response.chunkInfo.total <= 1 - ): - msg_obj = TopicMessage.of_single(response) - on_message(msg_obj) - continue - - initial_tx_id = TransactionId._from_proto(response.chunkInfo.initialTransactionID) - - if initial_tx_id not in pending_chunks: - pending_chunks[initial_tx_id] = [] - - pending_chunks[initial_tx_id].append(response) - - if len(pending_chunks[initial_tx_id]) == response.chunkInfo.total: - chunk_list = pending_chunks.pop(initial_tx_id) - - msg_obj = TopicMessage.of_many(chunk_list) - on_message(msg_obj) + self._handle_response(response, state, on_message) if self._completion_handler: self._completion_handler() @@ -163,13 +244,16 @@ def run_stream(): if subscription_handle.is_cancelled(): return - attempt += 1 - if attempt >= self._max_attempts: + if state.attempt >= self._max_attempts or not self._should_retry(e): + if self._error_handler: + self._error_handler(e) if on_error: on_error(e) return - delay = min(0.5 * (2 ** (attempt - 1)), self._max_backoff) + delay = min(0.5 * (2 ** (state.attempt)), self._max_backoff) + logger.warning(f"Error subscribing to topic attempt {state.attempt}. Retrying in {int(delay)}s...") + time.sleep(delay) thread = threading.Thread(target=run_stream, daemon=True) diff --git a/src/hiero_sdk_python/utils/subscription_handle.py b/src/hiero_sdk_python/utils/subscription_handle.py index f4ad199c1..b29f07cfb 100644 --- a/src/hiero_sdk_python/utils/subscription_handle.py +++ b/src/hiero_sdk_python/utils/subscription_handle.py @@ -1,6 +1,7 @@ from __future__ import annotations import threading +from typing import Any class SubscriptionHandle: @@ -12,11 +13,25 @@ class SubscriptionHandle: def __init__(self): self._cancelled = threading.Event() - self._thread = None + self._thread: threading.Thread | None = None + self._call: Any | None = None + self._lock = threading.Lock() + + def _set_call(self, call: Any): + """Sets the active gRPC call so it can be cancelled.""" + with self._lock: + self._call = call + + if self._cancelled.is_set(): + self._call.cancel() def cancel(self): """Signals to cancel the subscription.""" - self._cancelled.set() + with self._lock: + self._cancelled.set() + + if self._call: + self._call.cancel() def is_cancelled(self) -> bool: """Returns True if this subscription is already cancelled.""" diff --git a/tests/unit/subscription_handle_test.py b/tests/unit/subscription_handle_test.py index 218da4c2d..5262c3bf1 100644 --- a/tests/unit/subscription_handle_test.py +++ b/tests/unit/subscription_handle_test.py @@ -6,17 +6,20 @@ def test_not_cancelled_by_default(): + """Test a new handle starts in a non-cancelled state.""" handle = SubscriptionHandle() assert not handle.is_cancelled() def test_cancel_marks_as_cancelled(): + """Test calling cancel updates the is_cancelled status.""" handle = SubscriptionHandle() handle.cancel() assert handle.is_cancelled() def test_set_thread_and_join_calls_thread_join_with_timeout(): + """Test that join correctly forwards the timeout to the underlying thread.""" handle = SubscriptionHandle() mock_thread = Mock() handle.set_thread(mock_thread) @@ -25,6 +28,25 @@ def test_set_thread_and_join_calls_thread_join_with_timeout(): def test_join_without_thread_raises_nothing(): + """Test join is a no-op if no thread has been associated.""" handle = SubscriptionHandle() # should not raise handle.join() + + +def test_cancel_triggers_grpc_termination(): + """Test that cancelling the handle terminates the active gRPC call.""" + handle = SubscriptionHandle() + mock_call = Mock() + handle._set_call(mock_call) + handle.cancel() + mock_call.cancel.assert_called_once() + + +def test_immediate_cancellation_of_late_call(): + """Test a gRPC call is cancelled immediately if set after the handle was cancelled.""" + handle = SubscriptionHandle() + mock_call = Mock() + handle.cancel() + handle._set_call(mock_call) + mock_call.cancel.assert_called_once() diff --git a/tests/unit/topic_message_query_test.py b/tests/unit/topic_message_query_test.py index 270aba881..9a1828265 100644 --- a/tests/unit/topic_message_query_test.py +++ b/tests/unit/topic_message_query_test.py @@ -1,15 +1,21 @@ from __future__ import annotations +import time from datetime import datetime, timezone from unittest.mock import MagicMock, patch +import grpc import pytest +from hiero_sdk_python.account.account_id import AccountId from hiero_sdk_python.client.client import Client from hiero_sdk_python.consensus.topic_id import TopicId from hiero_sdk_python.hapi.mirror import consensus_service_pb2 as mirror_proto from hiero_sdk_python.hapi.services import timestamp_pb2 as hapi_timestamp_pb2 -from hiero_sdk_python.query.topic_message_query import TopicMessageQuery +from hiero_sdk_python.hapi.services.consensus_submit_message_pb2 import ConsensusMessageChunkInfo +from hiero_sdk_python.query.topic_message_query import SubscriptionState, TopicMessageQuery +from hiero_sdk_python.transaction.transaction_id import TransactionId +from tests.unit.mock_server import RealRpcError pytestmark = pytest.mark.unit @@ -19,7 +25,9 @@ def mock_client(): """Fixture to provide a mock Client instance.""" client = MagicMock(spec=Client) - client.operator_account_id = "0.0.12345" + client.operator_account_id = AccountId(0, 0, 12345) + client.mirror_stub = MagicMock() + return client @@ -34,17 +42,220 @@ def mock_subscription_response(): """Fixture to provide a mock response from a topic subscription.""" return mirror_proto.ConsensusTopicResponse( consensusTimestamp=hapi_timestamp_pb2.Timestamp(seconds=12345, nanos=67890), - message=b"Hello, world!", + message=b"Hello Hiero!", runningHash=b"\x00" * 48, sequenceNumber=1, ) -# This test uses fixtures (mock_client, mock_topic_id, mock_subscription_response) as parameters +# Initialization + + +def test_topic_message_query_initialization(): + """Test initializing the query with various parameter types and setters.""" + start = datetime(2023, 1, 1, tzinfo=timezone.utc) + + def mock_complete(): + pass + + def mock_error(e): + pass + + query = ( + TopicMessageQuery() + .set_topic_id("0.0.123") + .set_start_time(start) + .set_limit(5) + .set_chunking_enabled(True) + .set_completion_handler(mock_complete) + .set_error_handler(mock_error) + ) + + assert query._topic_id.topicNum == 123 + assert query._start_time.seconds == int(start.timestamp()) + assert query._limit == 5 + assert query._chunking_enabled is True + assert query._completion_handler == mock_complete + assert query._error_handler == mock_error + + +def test_topic_message_query_invalid_max_backoff(): + """Test that invalid max_backoff raises errors.""" + query = TopicMessageQuery() + + with pytest.raises(ValueError, match="max_backoff must be at least 500 ms"): + query.set_max_backoff(0.1) + + +def test_topic_message_query_invalid_max_attempts(): + """Test that invalid max_attempts raises errors.""" + query = TopicMessageQuery() + + with pytest.raises(ValueError, match="max_attempts must be greater than 0"): + query.set_max_attempts(0) + + +def test_topic_message_query_invalid_topic_id(): + """Test that invalid topic_id raises errors.""" + query = TopicMessageQuery() + + # Invalid TopicId type + with pytest.raises(TypeError, match="Invalid topic_id format"): + query.set_topic_id(12345) + + # Invalid TopicId format + with pytest.raises(ValueError, match="Invalid topic ID string"): + query.set_topic_id("12345") + + +def test_subscribe_missing_config(mock_client): + """Test that subscribe fails if Topic ID or Mirror Stub is missing.""" + # No TopicId + query_no_id = TopicMessageQuery() + with pytest.raises(ValueError, match="Topic ID must be set before subscribing"): + query_no_id.subscribe(mock_client, on_message=MagicMock()) + + # No MirrorStub + query_ok = TopicMessageQuery(topic_id="0.0.123") + mock_client.mirror_stub = None + with pytest.raises(ValueError, match="Client has no mirror_stub"): + query_ok.subscribe(mock_client, on_message=MagicMock()) + + +@pytest.mark.parametrize( + "handler", + ["string", 1, True, None, [], {}], +) +def test_topic_message_query_invalid_handler_param(handler): + """Test that a non-callable handler raises a TypeError.""" + query = TopicMessageQuery() + + # For complete_handler + with pytest.raises(TypeError, match="handler must be a callable object"): + query.set_completion_handler(handler) + # For error_handler + with pytest.raises(TypeError, match="handler must be a callable object"): + query.set_error_handler(handler) + + +# build_query_request + + +def test_build_query_request_uses_provided_start_time(): + """Test that the request uses the provided start_time when no last_message exists.""" + start = datetime(2023, 1, 1, tzinfo=timezone.utc) + query = TopicMessageQuery(topic_id="0.0.123").set_start_time(start) + state = SubscriptionState() + + state.last_message = None + + expected_start = query._start_time + request = query._build_query_request(state) + + assert request.consensusStartTime.seconds == expected_start.seconds + assert request.consensusStartTime.nanos == expected_start.nanos + + +def test_build_query_request_from_last_message_timestamp(): + """Test that a reconnection request overrides start_time using the last message timestamp + 1 nano""" + start = datetime(2023, 1, 1, tzinfo=timezone.utc) + query = TopicMessageQuery(topic_id="0.0.123").set_start_time(start) + state = SubscriptionState() + + state.last_message = state.last_message = mirror_proto.ConsensusTopicResponse( + consensusTimestamp=hapi_timestamp_pb2.Timestamp(seconds=50, nanos=10) + ) + + request = query._build_query_request(state) + + assert request.consensusStartTime.seconds == 50 + assert request.consensusStartTime.nanos == 11 + + +def test_build_query_request_nanosecond_rollover(): + """Test that nanos reaching 1_000_000_000 correctly increments the seconds field.""" + query = TopicMessageQuery(topic_id="0.0.123") + state = SubscriptionState() + + # Mock message arriving at the 999_999_999 nanosecond of second 50 + state.last_message = mirror_proto.ConsensusTopicResponse( + consensusTimestamp=hapi_timestamp_pb2.Timestamp(seconds=50, nanos=999_999_999) + ) + + request = query._build_query_request(state) + + assert request.consensusStartTime.seconds == 51 + assert request.consensusStartTime.nanos == 0 + + +def test_build_query_request_limit_decrements_on_retry(): + """Test that retry requests ask only for the remaining messages within the limit.""" + query = TopicMessageQuery(topic_id="0.0.123").set_limit(10) + state = SubscriptionState() + + # Mock state already have collect 4 message + state.count = 4 + state.last_message = mirror_proto.ConsensusTopicResponse( + consensusTimestamp=hapi_timestamp_pb2.Timestamp(seconds=100, nanos=500) + ) + + request = query._build_query_request(state) + + # New request only ask for the remaining 6 message (10 - 4) + assert request.limit == 6 + + +def test_build_query_request_limit_floor_at_zero(): + """Test that remaining request limits never drop below zero.""" + query = TopicMessageQuery(topic_id="0.0.123").set_limit(5) + state = SubscriptionState() + + state.count = 6 + state.last_message = mirror_proto.ConsensusTopicResponse( + consensusTimestamp=hapi_timestamp_pb2.Timestamp(seconds=100, nanos=500) + ) + + request = query._build_query_request(state) + + # The limit should be 0 rather than becoming negative + assert request.limit == 0 + + +def test_build_query_request_set_end_time_if_provided(): + """Test that request is created with the end_time if provided.""" + end = datetime(2023, 1, 1, tzinfo=timezone.utc) + query = TopicMessageQuery(topic_id="0.0.123").set_end_time(end) + + state = SubscriptionState() + state.last_message = None + + expected_end = query._end_time + request = query._build_query_request(state) + + assert request.consensusEndTime.seconds == expected_end.seconds + assert request.consensusEndTime.nanos == expected_end.nanos + + +def test_build_query_request_set_start_end_time_to_default(): + """Test that request is created with start_time and end_time as None if not present.""" + query = TopicMessageQuery(topic_id="0.0.123") + state = SubscriptionState() + state.last_message = None + + request = query._build_query_request(state) + + assert request.consensusStartTime.seconds == 0 + assert request.consensusStartTime.nanos == 0 + + assert request.consensusEndTime.seconds == 0 + assert request.consensusEndTime.nanos == 0 + + +# handle_response / subscribe + + def test_topic_message_query_subscription(mock_client, mock_topic_id, mock_subscription_response): - """ - Test subscribing to topic messages using TopicMessageQuery. - """ + """Test subscribing to topic messages using TopicMessageQuery.""" query = TopicMessageQuery().set_topic_id(mock_topic_id).set_start_time(datetime.now(tz=timezone.utc)) with patch("hiero_sdk_python.query.topic_message_query.TopicMessageQuery.subscribe") as mock_subscribe: @@ -63,9 +274,135 @@ def side_effect(client, on_message, on_error): # noqa: ARG001 called_args = on_message.call_args[0][0] assert called_args.consensusTimestamp.seconds == 12345 assert called_args.consensusTimestamp.nanos == 67890 - assert called_args.message == b"Hello, world!" + assert called_args.message == b"Hello Hiero!" assert called_args.sequenceNumber == 1 on_error.assert_not_called() - print("Test passed: Subscription handled messages correctly.") + +def test_chunk_message_handling(mock_client): + """Test that multiple chunks are correctly buffered and released as a single message.""" + query = TopicMessageQuery(topic_id="0.0.123", chunking_enabled=True) + + tx_id = TransactionId.generate(mock_client.operator_account_id) + tx_id_proto = tx_id._to_proto() + + chunk1 = mirror_proto.ConsensusTopicResponse( + message=b"chunk-1", chunkInfo=ConsensusMessageChunkInfo(initialTransactionID=tx_id_proto, total=2, number=1) + ) + chunk2 = mirror_proto.ConsensusTopicResponse( + message=b"chunk-2", chunkInfo=ConsensusMessageChunkInfo(initialTransactionID=tx_id_proto, total=2, number=2) + ) + + mock_client.mirror_stub.subscribeTopic.return_value = iter([chunk1, chunk2]) + + received_messages = [] + handle = query.subscribe(mock_client, on_message=lambda m: received_messages.append(m)) + + handle._thread.join(timeout=1.0) + + assert len(received_messages) == 1 + assert b"chunk-1" in received_messages[0].contents + assert b"chunk-2" in received_messages[0].contents + + +def test_chunk_message_handling_when_chunking_is_disabled(mock_client): + """Test that when chunking is disabled only single chunk is released as a single message.""" + query = TopicMessageQuery(topic_id="0.0.123", chunking_enabled=False) + + tx_id = TransactionId.generate(mock_client.operator_account_id) + tx_id_proto = tx_id._to_proto() + + chunk1 = mirror_proto.ConsensusTopicResponse( + message=b"chunk-1", chunkInfo=ConsensusMessageChunkInfo(initialTransactionID=tx_id_proto, total=2, number=1) + ) + chunk2 = mirror_proto.ConsensusTopicResponse( + message=b"chunk-2", chunkInfo=ConsensusMessageChunkInfo(initialTransactionID=tx_id_proto, total=2, number=2) + ) + + mock_client.mirror_stub.subscribeTopic.return_value = iter([chunk1, chunk2]) + + received_messages = [] + handle = query.subscribe(mock_client, on_message=lambda m: received_messages.append(m)) + + handle._thread.join(timeout=1.0) + + assert len(received_messages) == 2 # since we will get 2 seperate message + assert b"chunk-1" in received_messages[0].contents + assert b"chunk-2" not in received_messages[0].contents + + +@pytest.mark.parametrize( + "error", + [ + RealRpcError(grpc.StatusCode.NOT_FOUND, "unavailable"), + RealRpcError(grpc.StatusCode.UNAVAILABLE, "unavailable"), + RealRpcError(grpc.StatusCode.RESOURCE_EXHAUSTED, "busy"), + RealRpcError(grpc.StatusCode.INTERNAL, "received rst stream"), # internal with rst stream + Exception("non grpc exception"), # non grpc exception + ], +) +def test_retry_logic_on_retryable_error(mock_client, error): + """Test that the query retries on retryable errors but stops after max_attempts.""" + query = TopicMessageQuery(topic_id="0.0.123").set_max_attempts(2).set_max_backoff(0.5) + + mock_client.mirror_stub.subscribeTopic.side_effect = [error, error] + handle = query.subscribe(mock_client, on_message=MagicMock(), on_error=MagicMock()) + + handle._thread.join(timeout=2.0) + + assert mock_client.mirror_stub.subscribeTopic.call_count == 2 + + +@pytest.mark.parametrize( + "non_retryable_error", + [ + RealRpcError(grpc.StatusCode.PERMISSION_DENIED, "permission denied"), + RealRpcError(grpc.StatusCode.INVALID_ARGUMENT, "invalid argument"), + RealRpcError(grpc.StatusCode.UNAUTHENTICATED, "unauthenticated"), + RealRpcError(grpc.StatusCode.INTERNAL, "internal error"), + ], +) +def test_retry_logic_on_non_retryable_error(mock_client, non_retryable_error): + """Test that the query stops immediately on non-transient errors.""" + query = TopicMessageQuery(topic_id="0.0.123").set_max_attempts(5).set_max_backoff(0.5) + + mock_client.mirror_stub.subscribeTopic.side_effect = [non_retryable_error] * 5 + + on_error = MagicMock() + handle = query.subscribe(mock_client, on_message=MagicMock(), on_error=on_error) + + handle._thread.join(timeout=1.0) + + assert mock_client.mirror_stub.subscribeTopic.call_count == 1 + on_error.assert_called_once_with(non_retryable_error) + + assert not handle._thread.is_alive() + + +def test_subscription_cancellation(mock_client): + """Test that cancelling a handle stops the subscription thread.""" + query = TopicMessageQuery(topic_id="0.0.123") + + def infinite_stream(): + while True: + yield mirror_proto.ConsensusTopicResponse(message=b"ping") + time.sleep(0.1) + + mock_call = MagicMock() + mock_call.__iter__.return_value = infinite_stream() + + mock_client.mirror_stub.subscribeTopic.return_value = mock_call + + on_message = MagicMock() + handle = query.subscribe(mock_client, on_message=on_message) + + time.sleep(0.2) + assert handle._thread.is_alive() + + handle.cancel() + + handle._thread.join(timeout=1.0) + + assert not handle._thread.is_alive() + mock_call.cancel.assert_called()