From 197289a52f842725f7d13ff26ba6bbe459ee88b7 Mon Sep 17 00:00:00 2001 From: Manish Dait Date: Fri, 17 Apr 2026 00:42:20 +0530 Subject: [PATCH 1/7] feat: updated topic message query to handle message Signed-off-by: Manish Dait --- .../query/topic_message_query.py | 150 +++++++++++++----- .../utils/subscription_handle.py | 19 ++- 2 files changed, 124 insertions(+), 45 deletions(-) diff --git a/src/hiero_sdk_python/query/topic_message_query.py b/src/hiero_sdk_python/query/topic_message_query.py index fca9f9472..af42d79fd 100644 --- a/src/hiero_sdk_python/query/topic_message_query.py +++ b/src/hiero_sdk_python/query/topic_message_query.py @@ -1,10 +1,15 @@ from __future__ import annotations +import logging +import re import threading import time from collections.abc import Callable +from dataclasses import dataclass, field from datetime import datetime +import grpc + from hiero_sdk_python.client.client import Client from hiero_sdk_python.consensus.topic_id import TopicId from hiero_sdk_python.consensus.topic_message import TopicMessage @@ -14,6 +19,19 @@ from hiero_sdk_python.utils.subscription_handle import SubscriptionHandle +logger = logging.getLogger(__name__) + +RST_STREAM = re.compile(r"\brst[^0-9a-zA-Z]stream\b", re.IGNORECASE | re.DOTALL) + + +@dataclass +class SubscriptionState: + attempt: int = 0 + count: int = 0 + last_message: mirror_proto.ConsensusTopicResponse | None = None + pending_messages: dict[str, list[mirror_proto.ConsensusTopicResponse]] = field(default_factory=dict) + + class TopicMessageQuery: """ A query to subscribe to messages from a specific HCS topic, via a mirror node. @@ -31,23 +49,31 @@ def __init__( chunking_enabled: bool = False, ) -> None: """Initializes a TopicMessageQuery.""" - self._topic_id: TopicId | None = self._parse_topic_id(topic_id) if topic_id else None + self._topic_id: basic_types_pb2.TopicID | None = self._parse_topic_id(topic_id) if topic_id else None self._start_time: timestamp_pb2.Timestamp | None = self._parse_timestamp(start_time) if start_time else None self._end_time: timestamp_pb2.Timestamp | None = self._parse_timestamp(end_time) if end_time else None self._limit: int | None = limit self._chunking_enabled: bool = chunking_enabled - self._completion_handler: Callable[[], None] | None = None self._max_attempts: int = 10 self._max_backoff: float = 8.0 + self._completion_handler: Callable[[], None] | None = self._on_complete + self._error_handler: Callable[[], None] | None = self._on_error + def set_max_attempts(self, attempts: int) -> TopicMessageQuery: """Sets the maximum number of attempts to reconnect on failure.""" + if attempts <= 0: + raise ValueError("max_attempts must be greater than 0") + self._max_attempts = attempts return self def set_max_backoff(self, backoff: float) -> TopicMessageQuery: """Sets the maximum backoff time in seconds for reconnection attempts.""" + if backoff < 0.5: + raise ValueError("max_backoff must be at least 500 ms") + self._max_backoff = backoff return self @@ -56,23 +82,10 @@ def set_completion_handler(self, handler: Callable[[], None]) -> TopicMessageQue self._completion_handler = handler return self - def _parse_topic_id(self, topic_id: str | TopicId) -> basic_types_pb2.TopicID: - """Parses a topic ID from a string or TopicId object into a protobuf TopicID.""" - if isinstance(topic_id, str): - parts = topic_id.strip().split(".") - if len(parts) != 3: - raise ValueError(f"Invalid topic ID string: {topic_id}") - shard, realm, topic = map(int, parts) - return basic_types_pb2.TopicID(shardNum=shard, realmNum=realm, topicNum=topic) - if isinstance(topic_id, TopicId): - return topic_id._to_proto() - raise TypeError("Invalid topic_id format. Must be a string or TopicId.") - - def _parse_timestamp(self, dt: datetime) -> timestamp_pb2.Timestamp: - """Converts a datetime object to a protobuf Timestamp.""" - seconds = int(dt.timestamp()) - nanos = int((dt.timestamp() - seconds) * 1e9) - return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) + def set_error_handler(self, handler: Callable[[], None]) -> TopicMessageQuery: + """Sets a completion handler that is called when the subscription completes.""" + self._error_handler = handler + return self def set_topic_id(self, topic_id: str | TopicId) -> TopicMessageQuery: """Sets the topic ID for the query.""" @@ -99,6 +112,41 @@ def set_chunking_enabled(self, enabled: bool) -> TopicMessageQuery: self._chunking_enabled = enabled return self + def _on_complete(self) -> None: + logger.info(f"Subscription to topic {self._topic_id} complete") + + def _on_error(self, err: Exception) -> None: + if isinstance(err, grpc.RpcError) and err.code() == grpc.StatusCode.CANCELLED: + logger.warning(f"Call is cancelled for topic {self._topic_id}") + else: + logger.error(f"Error attempting to subscribe to topic {self._topic_id}: {err}") + + def _should_retry(self, err: Exception) -> bool: + if isinstance(err, grpc.RpcError): + return err.code() in ( + grpc.StatusCode.NOT_FOUND, + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.RESOURCE_EXHAUSTED, + ) or (err.code() == grpc.StatusCode.INTERNAL and bool(RST_STREAM.search(err.details()))) + + return True + + def _parse_topic_id(self, topic_id: str | TopicId) -> basic_types_pb2.TopicID: + """Parses a topic ID from a string or TopicId object into a protobuf TopicID.""" + if isinstance(topic_id, str): + topic_id = TopicId.from_string(topic_id) + + if isinstance(topic_id, TopicId): + return topic_id._to_proto() + + raise TypeError("Invalid topic_id format. Must be a string or TopicId.") + + def _parse_timestamp(self, dt: datetime) -> timestamp_pb2.Timestamp: + """Converts a datetime object to a protobuf Timestamp.""" + seconds = int(dt.timestamp()) + nanos = int((dt.timestamp() - seconds) * 1e9) + return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) + def subscribe( self, client: Client, @@ -111,49 +159,61 @@ def subscribe( if not client.mirror_stub: raise ValueError("Client has no mirror_stub. Did you configure a mirror node address?") - request = mirror_proto.ConsensusTopicQuery(topicID=self._topic_id) - if self._start_time: - request.consensusStartTime.CopyFrom(self._start_time) - if self._end_time: - request.consensusEndTime.CopyFrom(self._end_time) - if self._limit is not None: - request.limit = self._limit - subscription_handle = SubscriptionHandle() - - pending_chunks: dict[str, list[mirror_proto.ConsensusTopicResponse]] = {} + state = SubscriptionState() def run_stream(): - attempt = 0 - while attempt < self._max_attempts and not subscription_handle.is_cancelled(): + while state.attempt < self._max_attempts and not subscription_handle.is_cancelled(): + request = mirror_proto.ConsensusTopicQuery(topicID=self._topic_id) + + if self._end_time is not None: + request.consensusEndTime.CopyFrom(self._end_time) + + if state.last_message is not None: + last_message_time = state.last_message.consensusTimestamp + request.consensusStartTime.seconds = last_message_time.seconds + request.consensusStartTime.nanos = last_message_time.nanos + 1 + + if self._limit > 0: + request.limit = max(0, self._limit - state.count) + else: + if self._start_time is not None: + request.consensusStartTime.CopyFrom(self._start_time) + request.limit = self._limit + try: message_stream = client.mirror_stub.subscribeTopic(request) + subscription_handle._set_call(message_stream) for response in message_stream: if subscription_handle.is_cancelled(): return + state.count += 1 + state.last_message = response + if ( not self._chunking_enabled or not response.HasField("chunkInfo") or response.chunkInfo.total <= 1 ): - msg_obj = TopicMessage.of_single(response) - on_message(msg_obj) + message = TopicMessage.of_single(response) + on_message(message) continue initial_tx_id = TransactionId._from_proto(response.chunkInfo.initialTransactionID) - if initial_tx_id not in pending_chunks: - pending_chunks[initial_tx_id] = [] + if initial_tx_id not in state.pending_messages: + state.pending_messages[initial_tx_id] = [] - pending_chunks[initial_tx_id].append(response) + chunks = state.pending_messages[initial_tx_id] + chunks.append(response) - if len(pending_chunks[initial_tx_id]) == response.chunkInfo.total: - chunk_list = pending_chunks.pop(initial_tx_id) + if len(chunks) == response.chunkInfo.total: + del state.pending_messages[initial_tx_id] - msg_obj = TopicMessage.of_many(chunk_list) - on_message(msg_obj) + message = TopicMessage.of_many(chunks) + on_message(message) if self._completion_handler: self._completion_handler() @@ -163,13 +223,17 @@ def run_stream(): if subscription_handle.is_cancelled(): return - attempt += 1 - if attempt >= self._max_attempts: + if state.attempt >= self._max_attempts or not self._should_retry(e): + if self._error_handler: + self._error_handler(e) if on_error: on_error(e) return - delay = min(0.5 * (2 ** (attempt - 1)), self._max_backoff) + delay = min(0.5 * (2 ** (state.attempt)), self._max_backoff) + logger.warning(f"Error subscribing to topic attempt {state.attempt}. Retrying in {int(delay)}s...") + + state.attempt += 1 time.sleep(delay) thread = threading.Thread(target=run_stream, daemon=True) diff --git a/src/hiero_sdk_python/utils/subscription_handle.py b/src/hiero_sdk_python/utils/subscription_handle.py index f4ad199c1..b29f07cfb 100644 --- a/src/hiero_sdk_python/utils/subscription_handle.py +++ b/src/hiero_sdk_python/utils/subscription_handle.py @@ -1,6 +1,7 @@ from __future__ import annotations import threading +from typing import Any class SubscriptionHandle: @@ -12,11 +13,25 @@ class SubscriptionHandle: def __init__(self): self._cancelled = threading.Event() - self._thread = None + self._thread: threading.Thread | None = None + self._call: Any | None = None + self._lock = threading.Lock() + + def _set_call(self, call: Any): + """Sets the active gRPC call so it can be cancelled.""" + with self._lock: + self._call = call + + if self._cancelled.is_set(): + self._call.cancel() def cancel(self): """Signals to cancel the subscription.""" - self._cancelled.set() + with self._lock: + self._cancelled.set() + + if self._call: + self._call.cancel() def is_cancelled(self) -> bool: """Returns True if this subscription is already cancelled.""" From a2963eb95dae2e36e7a64856ce548b83b1a30a75 Mon Sep 17 00:00:00 2001 From: Manish Dait Date: Fri, 17 Apr 2026 00:46:26 +0530 Subject: [PATCH 2/7] chore: test workflow Signed-off-by: Manish Dait --- ...-check-secondary-unit-integration-test.yml | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/.github/workflows/pr-check-secondary-unit-integration-test.yml b/.github/workflows/pr-check-secondary-unit-integration-test.yml index 737df59c7..ba1ace34d 100644 --- a/.github/workflows/pr-check-secondary-unit-integration-test.yml +++ b/.github/workflows/pr-check-secondary-unit-integration-test.yml @@ -2,25 +2,6 @@ name: Secondary PR Check - Hiero Solo Integration & Unit Tests on: push: branches: - - "main" - paths-ignore: - - "**/*.md" - - "docs/**" - - "examples/**" - - "tck/**" - - "tests/fuzz/**" - - ".github/**" - - "!.github/workflows/pr-check-secondary-unit-integration-test.yml" - pull_request: - paths-ignore: - - "**/*.md" - - "docs/**" - - "examples/**" - - "tck/**" - - "tests/fuzz/**" - - ".github/**" - - "!.github/workflows/pr-check-secondary-unit-integration-test.yml" - workflow_dispatch: {} permissions: contents: read From a1b23188d2808f1c183d3ec46c6a8f7e8dd4bb4f Mon Sep 17 00:00:00 2001 From: Manish Dait Date: Fri, 17 Apr 2026 16:16:25 +0530 Subject: [PATCH 3/7] chore: test Signed-off-by: Manish Dait --- tests/unit/subscription_handle_test.py | 22 ++++++++++++ tests/unit/topic_message_query_test.py | 48 +++++++++++++++++++++++--- 2 files changed, 66 insertions(+), 4 deletions(-) diff --git a/tests/unit/subscription_handle_test.py b/tests/unit/subscription_handle_test.py index 218da4c2d..5262c3bf1 100644 --- a/tests/unit/subscription_handle_test.py +++ b/tests/unit/subscription_handle_test.py @@ -6,17 +6,20 @@ def test_not_cancelled_by_default(): + """Test a new handle starts in a non-cancelled state.""" handle = SubscriptionHandle() assert not handle.is_cancelled() def test_cancel_marks_as_cancelled(): + """Test calling cancel updates the is_cancelled status.""" handle = SubscriptionHandle() handle.cancel() assert handle.is_cancelled() def test_set_thread_and_join_calls_thread_join_with_timeout(): + """Test that join correctly forwards the timeout to the underlying thread.""" handle = SubscriptionHandle() mock_thread = Mock() handle.set_thread(mock_thread) @@ -25,6 +28,25 @@ def test_set_thread_and_join_calls_thread_join_with_timeout(): def test_join_without_thread_raises_nothing(): + """Test join is a no-op if no thread has been associated.""" handle = SubscriptionHandle() # should not raise handle.join() + + +def test_cancel_triggers_grpc_termination(): + """Test that cancelling the handle terminates the active gRPC call.""" + handle = SubscriptionHandle() + mock_call = Mock() + handle._set_call(mock_call) + handle.cancel() + mock_call.cancel.assert_called_once() + + +def test_immediate_cancellation_of_late_call(): + """Test a gRPC call is cancelled immediately if set after the handle was cancelled.""" + handle = SubscriptionHandle() + mock_call = Mock() + handle.cancel() + handle._set_call(mock_call) + mock_call.cancel.assert_called_once() diff --git a/tests/unit/topic_message_query_test.py b/tests/unit/topic_message_query_test.py index 270aba881..0bca3c2bb 100644 --- a/tests/unit/topic_message_query_test.py +++ b/tests/unit/topic_message_query_test.py @@ -1,15 +1,19 @@ from __future__ import annotations +import time from datetime import datetime, timezone from unittest.mock import MagicMock, patch import pytest +from hiero_sdk_python.account.account_id import AccountId from hiero_sdk_python.client.client import Client from hiero_sdk_python.consensus.topic_id import TopicId from hiero_sdk_python.hapi.mirror import consensus_service_pb2 as mirror_proto from hiero_sdk_python.hapi.services import timestamp_pb2 as hapi_timestamp_pb2 +from hiero_sdk_python.hapi.services.consensus_submit_message_pb2 import ConsensusMessageChunkInfo from hiero_sdk_python.query.topic_message_query import TopicMessageQuery +from hiero_sdk_python.transaction.transaction_id import TransactionId pytestmark = pytest.mark.unit @@ -19,7 +23,8 @@ def mock_client(): """Fixture to provide a mock Client instance.""" client = MagicMock(spec=Client) - client.operator_account_id = "0.0.12345" + client.operator_account_id = AccountId(0, 0, 12345) + return client @@ -40,11 +45,20 @@ def mock_subscription_response(): ) +def test_topic_message_query_initialization(): + """Test initializing the query with various parameter types and setters.""" + start = datetime(2023, 1, 1, tzinfo=timezone.utc) + query = TopicMessageQuery().set_topic_id("0.0.123").set_start_time(start).set_limit(5).set_chunking_enabled(True) + + assert query._topic_id.topicNum == 123 + assert query._start_time.seconds == int(start.timestamp()) + assert query._limit == 5 + assert query._chunking_enabled is True + + # This test uses fixtures (mock_client, mock_topic_id, mock_subscription_response) as parameters def test_topic_message_query_subscription(mock_client, mock_topic_id, mock_subscription_response): - """ - Test subscribing to topic messages using TopicMessageQuery. - """ + """Test subscribing to topic messages using TopicMessageQuery.""" query = TopicMessageQuery().set_topic_id(mock_topic_id).set_start_time(datetime.now(tz=timezone.utc)) with patch("hiero_sdk_python.query.topic_message_query.TopicMessageQuery.subscribe") as mock_subscribe: @@ -69,3 +83,29 @@ def side_effect(client, on_message, on_error): # noqa: ARG001 on_error.assert_not_called() print("Test passed: Subscription handled messages correctly.") + + +def test_chunk_message_handling(mock_client): + """Test that multiple chunks are correctly buffered and released as a single message.""" + query = TopicMessageQuery(topic_id="0.0.123", chunking_enabled=True) + + # Mocking two chunks for the same transaction + tx_id = TransactionId.generate(mock_client.operator_account_id)._to_proto() + chunk1 = mirror_proto.ConsensusTopicResponse( + message=b"part1", chunkInfo=ConsensusMessageChunkInfo(initialTransactionID=tx_id, total=2, number=1) + ) + chunk2 = mirror_proto.ConsensusTopicResponse( + message=b"part2", chunkInfo=ConsensusMessageChunkInfo(initialTransactionID=tx_id, total=2, number=2) + ) + + mock_client.mirror_stub.subscribeTopic.return_value = [chunk1, chunk2] + + received_messages = [] + query.subscribe(mock_client, on_message=lambda m: received_messages.append(m)) + + # Wait for thread execution + time.sleep(0.1) + + assert len(received_messages) == 1 + # Assuming TopicMessage.of_many joins messages + assert b"part1" in received_messages[0].message From 47ca92b4a67bd6f9afc889e4075fba3762e2a362 Mon Sep 17 00:00:00 2001 From: Manish Dait Date: Sat, 18 Apr 2026 13:48:35 +0530 Subject: [PATCH 4/7] chore: added test for the changes Signed-off-by: Manish Dait --- .../query/topic_message_query.py | 98 ++++++----- tests/unit/topic_message_query_test.py | 152 ++++++++++++++++-- 2 files changed, 192 insertions(+), 58 deletions(-) diff --git a/src/hiero_sdk_python/query/topic_message_query.py b/src/hiero_sdk_python/query/topic_message_query.py index af42d79fd..959396640 100644 --- a/src/hiero_sdk_python/query/topic_message_query.py +++ b/src/hiero_sdk_python/query/topic_message_query.py @@ -52,7 +52,7 @@ def __init__( self._topic_id: basic_types_pb2.TopicID | None = self._parse_topic_id(topic_id) if topic_id else None self._start_time: timestamp_pb2.Timestamp | None = self._parse_timestamp(start_time) if start_time else None self._end_time: timestamp_pb2.Timestamp | None = self._parse_timestamp(end_time) if end_time else None - self._limit: int | None = limit + self._limit: int = limit if limit is not None else 0 self._chunking_enabled: bool = chunking_enabled self._max_attempts: int = 10 @@ -147,6 +147,59 @@ def _parse_timestamp(self, dt: datetime) -> timestamp_pb2.Timestamp: nanos = int((dt.timestamp() - seconds) * 1e9) return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) + def _build_query_request(self, state: SubscriptionState) -> mirror_proto.ConsensusTopicQuery: + """Build the request object based on current subscription state.""" + request = mirror_proto.ConsensusTopicQuery(topicID=self._topic_id) + + if self._end_time is not None: + request.consensusEndTime.CopyFrom(self._end_time) + + if state.last_message is not None: + last_message_time = state.last_message.consensusTimestamp + + seconds = last_message_time.seconds + nanos = last_message_time.nanos + 1 + + if nanos >= 1_000_000_000: + seconds += 1 + nanos = 0 + + request.consensusStartTime.seconds = seconds + request.consensusStartTime.nanos = nanos + + if self._limit > 0: + request.limit = max(0, self._limit - state.count) + else: + if self._start_time is not None: + request.consensusStartTime.CopyFrom(self._start_time) + request.limit = self._limit + + return request + + def _handle_response(self, response, state: SubscriptionState, on_message: Callable[[TopicMessage], None]) -> None: + """Handles single or chunked messages.""" + state.count += 1 + state.last_message = response + + if not self._chunking_enabled or not response.HasField("chunkInfo") or response.chunkInfo.total <= 1: + message = TopicMessage.of_single(response) + on_message(message) + return + + initial_tx_id = TransactionId._from_proto(response.chunkInfo.initialTransactionID) + + if initial_tx_id not in state.pending_messages: + state.pending_messages[initial_tx_id] = [] + + chunks = state.pending_messages[initial_tx_id] + chunks.append(response) + + if len(chunks) == response.chunkInfo.total: + del state.pending_messages[initial_tx_id] + + message = TopicMessage.of_many(chunks) + on_message(message) + def subscribe( self, client: Client, @@ -164,22 +217,7 @@ def subscribe( def run_stream(): while state.attempt < self._max_attempts and not subscription_handle.is_cancelled(): - request = mirror_proto.ConsensusTopicQuery(topicID=self._topic_id) - - if self._end_time is not None: - request.consensusEndTime.CopyFrom(self._end_time) - - if state.last_message is not None: - last_message_time = state.last_message.consensusTimestamp - request.consensusStartTime.seconds = last_message_time.seconds - request.consensusStartTime.nanos = last_message_time.nanos + 1 - - if self._limit > 0: - request.limit = max(0, self._limit - state.count) - else: - if self._start_time is not None: - request.consensusStartTime.CopyFrom(self._start_time) - request.limit = self._limit + request = self._build_query_request(state) try: message_stream = client.mirror_stub.subscribeTopic(request) @@ -189,31 +227,7 @@ def run_stream(): if subscription_handle.is_cancelled(): return - state.count += 1 - state.last_message = response - - if ( - not self._chunking_enabled - or not response.HasField("chunkInfo") - or response.chunkInfo.total <= 1 - ): - message = TopicMessage.of_single(response) - on_message(message) - continue - - initial_tx_id = TransactionId._from_proto(response.chunkInfo.initialTransactionID) - - if initial_tx_id not in state.pending_messages: - state.pending_messages[initial_tx_id] = [] - - chunks = state.pending_messages[initial_tx_id] - chunks.append(response) - - if len(chunks) == response.chunkInfo.total: - del state.pending_messages[initial_tx_id] - - message = TopicMessage.of_many(chunks) - on_message(message) + self._handle_response(response, state, on_message) if self._completion_handler: self._completion_handler() diff --git a/tests/unit/topic_message_query_test.py b/tests/unit/topic_message_query_test.py index 0bca3c2bb..74dd3d476 100644 --- a/tests/unit/topic_message_query_test.py +++ b/tests/unit/topic_message_query_test.py @@ -4,6 +4,7 @@ from datetime import datetime, timezone from unittest.mock import MagicMock, patch +import grpc import pytest from hiero_sdk_python.account.account_id import AccountId @@ -14,6 +15,7 @@ from hiero_sdk_python.hapi.services.consensus_submit_message_pb2 import ConsensusMessageChunkInfo from hiero_sdk_python.query.topic_message_query import TopicMessageQuery from hiero_sdk_python.transaction.transaction_id import TransactionId +from tests.unit.mock_server import RealRpcError pytestmark = pytest.mark.unit @@ -24,6 +26,7 @@ def mock_client(): """Fixture to provide a mock Client instance.""" client = MagicMock(spec=Client) client.operator_account_id = AccountId(0, 0, 12345) + client.mirror_stub = MagicMock() return client @@ -39,7 +42,7 @@ def mock_subscription_response(): """Fixture to provide a mock response from a topic subscription.""" return mirror_proto.ConsensusTopicResponse( consensusTimestamp=hapi_timestamp_pb2.Timestamp(seconds=12345, nanos=67890), - message=b"Hello, world!", + message=b"Hello Hiero!", runningHash=b"\x00" * 48, sequenceNumber=1, ) @@ -56,7 +59,49 @@ def test_topic_message_query_initialization(): assert query._chunking_enabled is True -# This test uses fixtures (mock_client, mock_topic_id, mock_subscription_response) as parameters +def test_topic_message_query_invalid_max_backoff(): + """Test that invalid max_backoff raises errors.""" + query = TopicMessageQuery() + + with pytest.raises(ValueError, match="max_backoff must be at least 500 ms"): + query.set_max_backoff(0.1) + + +def test_topic_message_query_invalid_max_attempts(): + """Test that invalid max_attempts raises errors.""" + query = TopicMessageQuery() + + with pytest.raises(ValueError, match="max_attempts must be greater than 0"): + query.set_max_attempts(0) + + +def test_topic_message_query_invalid_topic_id(): + """Test that invalid topic_id raises errors.""" + query = TopicMessageQuery() + + # Invalid TopicId type + with pytest.raises(TypeError, match="Invalid topic_id format"): + query.set_topic_id(12345) + + # Invalid TopicId format + with pytest.raises(ValueError, match="Invalid topic ID string"): + query.set_topic_id("12345") + + +def test_subscribe_missing_config(mock_client): + """Test that subscribe fails if Topic ID or Mirror Stub is missing.""" + # No TopicId + query_no_id = TopicMessageQuery() + with pytest.raises(ValueError, match="Topic ID must be set before subscribing"): + query_no_id.subscribe(mock_client, on_message=MagicMock()) + + # No MirrorStub + query_ok = TopicMessageQuery(topic_id="0.0.123") + mock_client.mirror_stub = None + with pytest.raises(ValueError, match="Client has no mirror_stub"): + query_ok.subscribe(mock_client, on_message=MagicMock()) + + def test_topic_message_query_subscription(mock_client, mock_topic_id, mock_subscription_response): """Test subscribing to topic messages using TopicMessageQuery.""" query = TopicMessageQuery().set_topic_id(mock_topic_id).set_start_time(datetime.now(tz=timezone.utc)) @@ -77,35 +122,110 @@ def side_effect(client, on_message, on_error): # noqa: ARG001 called_args = on_message.call_args[0][0] assert called_args.consensusTimestamp.seconds == 12345 assert called_args.consensusTimestamp.nanos == 67890 - assert called_args.message == b"Hello, world!" + assert called_args.message == b"Hello Hiero!" assert called_args.sequenceNumber == 1 on_error.assert_not_called() - print("Test passed: Subscription handled messages correctly.") - -def test_chunk_message_handling(mock_client): +def test_chunk_message_handling_improved(mock_client): """Test that multiple chunks are correctly buffered and released as a single message.""" query = TopicMessageQuery(topic_id="0.0.123", chunking_enabled=True) - # Mocking two chunks for the same transaction - tx_id = TransactionId.generate(mock_client.operator_account_id)._to_proto() + tx_id = TransactionId.generate(mock_client.operator_account_id) + tx_id_proto = tx_id._to_proto() + chunk1 = mirror_proto.ConsensusTopicResponse( - message=b"part1", chunkInfo=ConsensusMessageChunkInfo(initialTransactionID=tx_id, total=2, number=1) + message=b"chunk-1", chunkInfo=ConsensusMessageChunkInfo(initialTransactionID=tx_id_proto, total=2, number=1) ) chunk2 = mirror_proto.ConsensusTopicResponse( - message=b"part2", chunkInfo=ConsensusMessageChunkInfo(initialTransactionID=tx_id, total=2, number=2) + message=b"chunk-2", chunkInfo=ConsensusMessageChunkInfo(initialTransactionID=tx_id_proto, total=2, number=2) ) - mock_client.mirror_stub.subscribeTopic.return_value = [chunk1, chunk2] + mock_client.mirror_stub.subscribeTopic.return_value = iter([chunk1, chunk2]) received_messages = [] - query.subscribe(mock_client, on_message=lambda m: received_messages.append(m)) + handle = query.subscribe(mock_client, on_message=lambda m: received_messages.append(m)) - # Wait for thread execution - time.sleep(0.1) + handle._thread.join(timeout=1.0) assert len(received_messages) == 1 - # Assuming TopicMessage.of_many joins messages - assert b"part1" in received_messages[0].message + assert b"chunk-1" in received_messages[0].contents + assert b"chunk-2" in received_messages[0].contents + + +@pytest.mark.parametrize( + "error", + [ + RealRpcError(grpc.StatusCode.NOT_FOUND, "unavailable"), + RealRpcError(grpc.StatusCode.UNAVAILABLE, "unavailable"), + RealRpcError(grpc.StatusCode.RESOURCE_EXHAUSTED, "busy"), + RealRpcError(grpc.StatusCode.INTERNAL, "received rst stream"), # internal with rst stream + Exception("non grpc exception"), # non grpc exception + ], +) +def test_retry_logic_on_retryable_error(mock_client, error): + """Test that the query retries on retryable errors but stops after max_attempts.""" + query = TopicMessageQuery(topic_id="0.0.123").set_max_attempts(2).set_max_backoff(0.5) + + mock_client.mirror_stub.subscribeTopic.side_effect = [error, error] + + handle = query.subscribe(mock_client, on_message=MagicMock(), on_error=MagicMock()) + + handle._thread.join(timeout=2.0) + + assert mock_client.mirror_stub.subscribeTopic.call_count == 2 + + +@pytest.mark.parametrize( + "non_retryable_error", + [ + RealRpcError(grpc.StatusCode.PERMISSION_DENIED, "permission denied"), + RealRpcError(grpc.StatusCode.INVALID_ARGUMENT, "invalid argument"), + RealRpcError(grpc.StatusCode.UNAUTHENTICATED, "unauthenticated"), + RealRpcError(grpc.StatusCode.INTERNAL, "internal error"), + ], +) +def test_retry_logic_on_non_retryable_error(mock_client, non_retryable_error): + """Test that the query stops immediately on non-transient errors.""" + query = TopicMessageQuery(topic_id="0.0.123").set_max_attempts(5).set_max_backoff(0.5) + + mock_client.mirror_stub.subscribeTopic.side_effect = [non_retryable_error] * 5 + + on_error = MagicMock() + handle = query.subscribe(mock_client, on_message=MagicMock(), on_error=on_error) + + handle._thread.join(timeout=1.0) + + assert mock_client.mirror_stub.subscribeTopic.call_count == 1 + on_error.assert_called_once_with(non_retryable_error) + + assert not handle._thread.is_alive() + + +def test_subscription_cancellation(mock_client): + """Test that cancelling a handle stops the subscription thread.""" + query = TopicMessageQuery(topic_id="0.0.123") + + def infinite_stream(): + while True: + yield mirror_proto.ConsensusTopicResponse(message=b"ping") + time.sleep(0.1) + + mock_call = MagicMock() + mock_call.__iter__.return_value = infinite_stream() + + mock_client.mirror_stub.subscribeTopic.return_value = mock_call + + on_message = MagicMock() + handle = query.subscribe(mock_client, on_message=on_message) + + time.sleep(0.2) + assert handle._thread.is_alive() + + handle.cancel() + + handle._thread.join(timeout=1.0) + + assert not handle._thread.is_alive() + mock_call.cancel.assert_called() From 57c6f7e712c3b891b51f633198ba7338040479b1 Mon Sep 17 00:00:00 2001 From: Manish Dait Date: Sat, 18 Apr 2026 14:01:04 +0530 Subject: [PATCH 5/7] chore: remove test Signed-off-by: Manish Dait --- ...-check-secondary-unit-integration-test.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/.github/workflows/pr-check-secondary-unit-integration-test.yml b/.github/workflows/pr-check-secondary-unit-integration-test.yml index ba1ace34d..737df59c7 100644 --- a/.github/workflows/pr-check-secondary-unit-integration-test.yml +++ b/.github/workflows/pr-check-secondary-unit-integration-test.yml @@ -2,6 +2,25 @@ name: Secondary PR Check - Hiero Solo Integration & Unit Tests on: push: branches: + - "main" + paths-ignore: + - "**/*.md" + - "docs/**" + - "examples/**" + - "tck/**" + - "tests/fuzz/**" + - ".github/**" + - "!.github/workflows/pr-check-secondary-unit-integration-test.yml" + pull_request: + paths-ignore: + - "**/*.md" + - "docs/**" + - "examples/**" + - "tck/**" + - "tests/fuzz/**" + - ".github/**" + - "!.github/workflows/pr-check-secondary-unit-integration-test.yml" + workflow_dispatch: {} permissions: contents: read From 6073c05966820ea061bb5cbf91113b190ed15a2e Mon Sep 17 00:00:00 2001 From: Manish Dait Date: Wed, 20 May 2026 15:10:45 +0530 Subject: [PATCH 6/7] chore: address some review sugesstions Signed-off-by: Manish Dait --- src/hiero_sdk_python/query/topic_message_query.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/hiero_sdk_python/query/topic_message_query.py b/src/hiero_sdk_python/query/topic_message_query.py index 959396640..2b10a8527 100644 --- a/src/hiero_sdk_python/query/topic_message_query.py +++ b/src/hiero_sdk_python/query/topic_message_query.py @@ -29,7 +29,7 @@ class SubscriptionState: attempt: int = 0 count: int = 0 last_message: mirror_proto.ConsensusTopicResponse | None = None - pending_messages: dict[str, list[mirror_proto.ConsensusTopicResponse]] = field(default_factory=dict) + pending_messages: dict[TransactionId, list[mirror_proto.ConsensusTopicResponse]] = field(default_factory=dict) class TopicMessageQuery: @@ -82,8 +82,8 @@ def set_completion_handler(self, handler: Callable[[], None]) -> TopicMessageQue self._completion_handler = handler return self - def set_error_handler(self, handler: Callable[[], None]) -> TopicMessageQuery: - """Sets a completion handler that is called when the subscription completes.""" + def set_error_handler(self, handler: Callable[[Exception], None]) -> TopicMessageQuery: + """Sets an error handler that is called when the subscription causes an error.""" self._error_handler = handler return self @@ -217,6 +217,7 @@ def subscribe( def run_stream(): while state.attempt < self._max_attempts and not subscription_handle.is_cancelled(): + state.attempt += 1 request = self._build_query_request(state) try: @@ -247,7 +248,6 @@ def run_stream(): delay = min(0.5 * (2 ** (state.attempt)), self._max_backoff) logger.warning(f"Error subscribing to topic attempt {state.attempt}. Retrying in {int(delay)}s...") - state.attempt += 1 time.sleep(delay) thread = threading.Thread(target=run_stream, daemon=True) From bb49b1484c506688da6a78571ad2b991edc34517 Mon Sep 17 00:00:00 2001 From: Manish Dait Date: Thu, 21 May 2026 16:27:59 +0530 Subject: [PATCH 7/7] chore: extend test coverage Signed-off-by: Manish Dait --- .../query/topic_message_query.py | 6 + tests/unit/topic_message_query_test.py | 185 +++++++++++++++++- 2 files changed, 187 insertions(+), 4 deletions(-) diff --git a/src/hiero_sdk_python/query/topic_message_query.py b/src/hiero_sdk_python/query/topic_message_query.py index 2b10a8527..ad0269d70 100644 --- a/src/hiero_sdk_python/query/topic_message_query.py +++ b/src/hiero_sdk_python/query/topic_message_query.py @@ -79,11 +79,17 @@ def set_max_backoff(self, backoff: float) -> TopicMessageQuery: def set_completion_handler(self, handler: Callable[[], None]) -> TopicMessageQuery: """Sets a completion handler that is called when the subscription completes.""" + if not callable(handler): + raise TypeError("handler must be a callable object") + self._completion_handler = handler return self def set_error_handler(self, handler: Callable[[Exception], None]) -> TopicMessageQuery: """Sets an error handler that is called when the subscription causes an error.""" + if not callable(handler): + raise TypeError("handler must be a callable object") + self._error_handler = handler return self diff --git a/tests/unit/topic_message_query_test.py b/tests/unit/topic_message_query_test.py index 74dd3d476..9a1828265 100644 --- a/tests/unit/topic_message_query_test.py +++ b/tests/unit/topic_message_query_test.py @@ -13,7 +13,7 @@ from hiero_sdk_python.hapi.mirror import consensus_service_pb2 as mirror_proto from hiero_sdk_python.hapi.services import timestamp_pb2 as hapi_timestamp_pb2 from hiero_sdk_python.hapi.services.consensus_submit_message_pb2 import ConsensusMessageChunkInfo -from hiero_sdk_python.query.topic_message_query import TopicMessageQuery +from hiero_sdk_python.query.topic_message_query import SubscriptionState, TopicMessageQuery from hiero_sdk_python.transaction.transaction_id import TransactionId from tests.unit.mock_server import RealRpcError @@ -48,15 +48,35 @@ def mock_subscription_response(): ) +# Initialization + + def test_topic_message_query_initialization(): """Test initializing the query with various parameter types and setters.""" start = datetime(2023, 1, 1, tzinfo=timezone.utc) - query = TopicMessageQuery().set_topic_id("0.0.123").set_start_time(start).set_limit(5).set_chunking_enabled(True) + + def mock_complete(): + pass + + def mock_error(e): + pass + + query = ( + TopicMessageQuery() + .set_topic_id("0.0.123") + .set_start_time(start) + .set_limit(5) + .set_chunking_enabled(True) + .set_completion_handler(mock_complete) + .set_error_handler(mock_error) + ) assert query._topic_id.topicNum == 123 assert query._start_time.seconds == int(start.timestamp()) assert query._limit == 5 assert query._chunking_enabled is True + assert query._completion_handler == mock_complete + assert query._error_handler == mock_error def test_topic_message_query_invalid_max_backoff(): @@ -102,6 +122,138 @@ def test_subscribe_missing_config(mock_client): query_ok.subscribe(mock_client, on_message=MagicMock()) +@pytest.mark.parametrize( + "handler", + ["string", 1, True, None, [], {}], +) +def test_topic_message_query_invalid_handler_param(handler): + """Test that a non-callable handler raises a TypeError.""" + query = TopicMessageQuery() + + # For complete_handler + with pytest.raises(TypeError, match="handler must be a callable object"): + query.set_completion_handler(handler) + # For error_handler + with pytest.raises(TypeError, match="handler must be a callable object"): + query.set_error_handler(handler) + + +# build_query_request + + +def test_build_query_request_uses_provided_start_time(): + """Test that the request uses the provided start_time when no last_message exists.""" + start = datetime(2023, 1, 1, tzinfo=timezone.utc) + query = TopicMessageQuery(topic_id="0.0.123").set_start_time(start) + state = SubscriptionState() + + state.last_message = None + + expected_start = query._start_time + request = query._build_query_request(state) + + assert request.consensusStartTime.seconds == expected_start.seconds + assert request.consensusStartTime.nanos == expected_start.nanos + + +def test_build_query_request_from_last_message_timestamp(): + """Test that a reconnection request overrides start_time using the last message timestamp + 1 nano""" + start = datetime(2023, 1, 1, tzinfo=timezone.utc) + query = TopicMessageQuery(topic_id="0.0.123").set_start_time(start) + state = SubscriptionState() + + state.last_message = state.last_message = mirror_proto.ConsensusTopicResponse( + consensusTimestamp=hapi_timestamp_pb2.Timestamp(seconds=50, nanos=10) + ) + + request = query._build_query_request(state) + + assert request.consensusStartTime.seconds == 50 + assert request.consensusStartTime.nanos == 11 + + +def test_build_query_request_nanosecond_rollover(): + """Test that nanos reaching 1_000_000_000 correctly increments the seconds field.""" + query = TopicMessageQuery(topic_id="0.0.123") + state = SubscriptionState() + + # Mock message arriving at the 999_999_999 nanosecond of second 50 + state.last_message = mirror_proto.ConsensusTopicResponse( + consensusTimestamp=hapi_timestamp_pb2.Timestamp(seconds=50, nanos=999_999_999) + ) + + request = query._build_query_request(state) + + assert request.consensusStartTime.seconds == 51 + assert request.consensusStartTime.nanos == 0 + + +def test_build_query_request_limit_decrements_on_retry(): + """Test that retry requests ask only for the remaining messages within the limit.""" + query = TopicMessageQuery(topic_id="0.0.123").set_limit(10) + state = SubscriptionState() + + # Mock state already have collect 4 message + state.count = 4 + state.last_message = mirror_proto.ConsensusTopicResponse( + consensusTimestamp=hapi_timestamp_pb2.Timestamp(seconds=100, nanos=500) + ) + + request = query._build_query_request(state) + + # New request only ask for the remaining 6 message (10 - 4) + assert request.limit == 6 + + +def test_build_query_request_limit_floor_at_zero(): + """Test that remaining request limits never drop below zero.""" + query = TopicMessageQuery(topic_id="0.0.123").set_limit(5) + state = SubscriptionState() + + state.count = 6 + state.last_message = mirror_proto.ConsensusTopicResponse( + consensusTimestamp=hapi_timestamp_pb2.Timestamp(seconds=100, nanos=500) + ) + + request = query._build_query_request(state) + + # The limit should be 0 rather than becoming negative + assert request.limit == 0 + + +def test_build_query_request_set_end_time_if_provided(): + """Test that request is created with the end_time if provided.""" + end = datetime(2023, 1, 1, tzinfo=timezone.utc) + query = TopicMessageQuery(topic_id="0.0.123").set_end_time(end) + + state = SubscriptionState() + state.last_message = None + + expected_end = query._end_time + request = query._build_query_request(state) + + assert request.consensusEndTime.seconds == expected_end.seconds + assert request.consensusEndTime.nanos == expected_end.nanos + + +def test_build_query_request_set_start_end_time_to_default(): + """Test that request is created with start_time and end_time as None if not present.""" + query = TopicMessageQuery(topic_id="0.0.123") + state = SubscriptionState() + state.last_message = None + + request = query._build_query_request(state) + + assert request.consensusStartTime.seconds == 0 + assert request.consensusStartTime.nanos == 0 + + assert request.consensusEndTime.seconds == 0 + assert request.consensusEndTime.nanos == 0 + + +# handle_response / subscribe + + def test_topic_message_query_subscription(mock_client, mock_topic_id, mock_subscription_response): """Test subscribing to topic messages using TopicMessageQuery.""" query = TopicMessageQuery().set_topic_id(mock_topic_id).set_start_time(datetime.now(tz=timezone.utc)) @@ -128,7 +280,7 @@ def side_effect(client, on_message, on_error): # noqa: ARG001 on_error.assert_not_called() -def test_chunk_message_handling_improved(mock_client): +def test_chunk_message_handling(mock_client): """Test that multiple chunks are correctly buffered and released as a single message.""" query = TopicMessageQuery(topic_id="0.0.123", chunking_enabled=True) @@ -154,6 +306,32 @@ def test_chunk_message_handling_improved(mock_client): assert b"chunk-2" in received_messages[0].contents +def test_chunk_message_handling_when_chunking_is_disabled(mock_client): + """Test that when chunking is disabled only single chunk is released as a single message.""" + query = TopicMessageQuery(topic_id="0.0.123", chunking_enabled=False) + + tx_id = TransactionId.generate(mock_client.operator_account_id) + tx_id_proto = tx_id._to_proto() + + chunk1 = mirror_proto.ConsensusTopicResponse( + message=b"chunk-1", chunkInfo=ConsensusMessageChunkInfo(initialTransactionID=tx_id_proto, total=2, number=1) + ) + chunk2 = mirror_proto.ConsensusTopicResponse( + message=b"chunk-2", chunkInfo=ConsensusMessageChunkInfo(initialTransactionID=tx_id_proto, total=2, number=2) + ) + + mock_client.mirror_stub.subscribeTopic.return_value = iter([chunk1, chunk2]) + + received_messages = [] + handle = query.subscribe(mock_client, on_message=lambda m: received_messages.append(m)) + + handle._thread.join(timeout=1.0) + + assert len(received_messages) == 2 # since we will get 2 seperate message + assert b"chunk-1" in received_messages[0].contents + assert b"chunk-2" not in received_messages[0].contents + + @pytest.mark.parametrize( "error", [ @@ -169,7 +347,6 @@ def test_retry_logic_on_retryable_error(mock_client, error): query = TopicMessageQuery(topic_id="0.0.123").set_max_attempts(2).set_max_backoff(0.5) mock_client.mirror_stub.subscribeTopic.side_effect = [error, error] - handle = query.subscribe(mock_client, on_message=MagicMock(), on_error=MagicMock()) handle._thread.join(timeout=2.0)