Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 139 additions & 55 deletions src/hiero_sdk_python/query/topic_message_query.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from __future__ import annotations

import logging
import re
import threading
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime

import grpc

from hiero_sdk_python.client.client import Client
from hiero_sdk_python.consensus.topic_id import TopicId
from hiero_sdk_python.consensus.topic_message import TopicMessage
Expand All @@ -14,6 +19,19 @@
from hiero_sdk_python.utils.subscription_handle import SubscriptionHandle


logger = logging.getLogger(__name__)

RST_STREAM = re.compile(r"\brst[^0-9a-zA-Z]stream\b", re.IGNORECASE | re.DOTALL)


@dataclass
class SubscriptionState:
attempt: int = 0
count: int = 0
last_message: mirror_proto.ConsensusTopicResponse | None = None
pending_messages: dict[TransactionId, list[mirror_proto.ConsensusTopicResponse]] = field(default_factory=dict)


class TopicMessageQuery:
"""
A query to subscribe to messages from a specific HCS topic, via a mirror node.
Expand All @@ -31,48 +49,49 @@ def __init__(
chunking_enabled: bool = False,
) -> None:
"""Initializes a TopicMessageQuery."""
self._topic_id: TopicId | None = self._parse_topic_id(topic_id) if topic_id else None
self._topic_id: basic_types_pb2.TopicID | None = self._parse_topic_id(topic_id) if topic_id else None
self._start_time: timestamp_pb2.Timestamp | None = self._parse_timestamp(start_time) if start_time else None
self._end_time: timestamp_pb2.Timestamp | None = self._parse_timestamp(end_time) if end_time else None
self._limit: int | None = limit
self._limit: int = limit if limit is not None else 0
self._chunking_enabled: bool = chunking_enabled
self._completion_handler: Callable[[], None] | None = None

self._max_attempts: int = 10
self._max_backoff: float = 8.0

self._completion_handler: Callable[[], None] | None = self._on_complete
self._error_handler: Callable[[], None] | None = self._on_error
Comment thread
manishdait marked this conversation as resolved.

def set_max_attempts(self, attempts: int) -> TopicMessageQuery:
"""Sets the maximum number of attempts to reconnect on failure."""
if attempts <= 0:
raise ValueError("max_attempts must be greater than 0")

self._max_attempts = attempts
return self

def set_max_backoff(self, backoff: float) -> TopicMessageQuery:
"""Sets the maximum backoff time in seconds for reconnection attempts."""
if backoff < 0.5:
raise ValueError("max_backoff must be at least 500 ms")

self._max_backoff = backoff
return self

def set_completion_handler(self, handler: Callable[[], None]) -> TopicMessageQuery:
"""Sets a completion handler that is called when the subscription completes."""
if not callable(handler):
raise TypeError("handler must be a callable object")

self._completion_handler = handler
return self

def _parse_topic_id(self, topic_id: str | TopicId) -> basic_types_pb2.TopicID:
"""Parses a topic ID from a string or TopicId object into a protobuf TopicID."""
if isinstance(topic_id, str):
parts = topic_id.strip().split(".")
if len(parts) != 3:
raise ValueError(f"Invalid topic ID string: {topic_id}")
shard, realm, topic = map(int, parts)
return basic_types_pb2.TopicID(shardNum=shard, realmNum=realm, topicNum=topic)
if isinstance(topic_id, TopicId):
return topic_id._to_proto()
raise TypeError("Invalid topic_id format. Must be a string or TopicId.")
def set_error_handler(self, handler: Callable[[Exception], None]) -> TopicMessageQuery:
"""Sets an error handler that is called when the subscription causes an error."""
if not callable(handler):
raise TypeError("handler must be a callable object")

def _parse_timestamp(self, dt: datetime) -> timestamp_pb2.Timestamp:
"""Converts a datetime object to a protobuf Timestamp."""
seconds = int(dt.timestamp())
nanos = int((dt.timestamp() - seconds) * 1e9)
return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos)
self._error_handler = handler
return self

def set_topic_id(self, topic_id: str | TopicId) -> TopicMessageQuery:
"""Sets the topic ID for the query."""
Expand All @@ -99,6 +118,94 @@ def set_chunking_enabled(self, enabled: bool) -> TopicMessageQuery:
self._chunking_enabled = enabled
return self

def _on_complete(self) -> None:
logger.info(f"Subscription to topic {self._topic_id} complete")

def _on_error(self, err: Exception) -> None:
if isinstance(err, grpc.RpcError) and err.code() == grpc.StatusCode.CANCELLED:
logger.warning(f"Call is cancelled for topic {self._topic_id}")
else:
logger.error(f"Error attempting to subscribe to topic {self._topic_id}: {err}")

def _should_retry(self, err: Exception) -> bool:
if isinstance(err, grpc.RpcError):
return err.code() in (
grpc.StatusCode.NOT_FOUND,
grpc.StatusCode.UNAVAILABLE,
grpc.StatusCode.RESOURCE_EXHAUSTED,
) or (err.code() == grpc.StatusCode.INTERNAL and bool(RST_STREAM.search(err.details())))

return True
Comment thread
manishdait marked this conversation as resolved.

def _parse_topic_id(self, topic_id: str | TopicId) -> basic_types_pb2.TopicID:
"""Parses a topic ID from a string or TopicId object into a protobuf TopicID."""
if isinstance(topic_id, str):
topic_id = TopicId.from_string(topic_id)

if isinstance(topic_id, TopicId):
return topic_id._to_proto()

raise TypeError("Invalid topic_id format. Must be a string or TopicId.")

def _parse_timestamp(self, dt: datetime) -> timestamp_pb2.Timestamp:
"""Converts a datetime object to a protobuf Timestamp."""
seconds = int(dt.timestamp())
nanos = int((dt.timestamp() - seconds) * 1e9)
return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos)

def _build_query_request(self, state: SubscriptionState) -> mirror_proto.ConsensusTopicQuery:
"""Build the request object based on current subscription state."""
request = mirror_proto.ConsensusTopicQuery(topicID=self._topic_id)

if self._end_time is not None:
request.consensusEndTime.CopyFrom(self._end_time)

if state.last_message is not None:
last_message_time = state.last_message.consensusTimestamp

seconds = last_message_time.seconds
nanos = last_message_time.nanos + 1

if nanos >= 1_000_000_000:
seconds += 1
nanos = 0

request.consensusStartTime.seconds = seconds
request.consensusStartTime.nanos = nanos

if self._limit > 0:
request.limit = max(0, self._limit - state.count)
Comment on lines +176 to +177
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Count emitted TopicMessages before decrementing limit.

  • File and line: src/hiero_sdk_python/query/topic_message_query.py, Lines 176-177 and 185-207
  • Proto field: ConsensusTopicQuery.limit (#4). (raw.githubusercontent.com)
  • Issue type: Wrong default
  • Description: state.count is subtracted from the next request's limit, but it is incremented for every ConsensusTopicResponse, including intermediate chunks. On retries with chunking enabled, that makes the remaining limit track wire chunks instead of delivered TopicMessages. Once the subtraction reaches 0, the retry request serializes limit = 0, and the schema defines zero/unset as "receive indefinitely", so a resumed subscription can either stop before enough logical messages are emitted or run past the caller's cap. (raw.githubusercontent.com)
  • Suggested fix: Increment state.count only after on_message() runs for a complete logical message, and stop retrying once state.count >= self._limit instead of issuing another request with limit = 0.
🐛 Proposed fix
     def _build_query_request(self, state: SubscriptionState) -> mirror_proto.ConsensusTopicQuery:
         """Build the request object based on current subscription state."""
         request = mirror_proto.ConsensusTopicQuery(topicID=self._topic_id)

         if self._end_time is not None:
             request.consensusEndTime.CopyFrom(self._end_time)

         if state.last_message is not None:
             last_message_time = state.last_message.consensusTimestamp
@@
             request.consensusStartTime.seconds = seconds
             request.consensusStartTime.nanos = nanos

             if self._limit > 0:
-                request.limit = max(0, self._limit - state.count)
+                request.limit = self._limit - state.count
         else:
             if self._start_time is not None:
                 request.consensusStartTime.CopyFrom(self._start_time)
             request.limit = self._limit

         return request

     def _handle_response(self, response, state: SubscriptionState, on_message: Callable[[TopicMessage], None]) -> None:
         """Handles single or chunked messages."""
-        state.count += 1
         state.last_message = response

         if not self._chunking_enabled or not response.HasField("chunkInfo") or response.chunkInfo.total <= 1:
             message = TopicMessage.of_single(response)
             on_message(message)
+            state.count += 1
             return
@@
         if len(chunks) == response.chunkInfo.total:
             del state.pending_messages[initial_tx_id]

             message = TopicMessage.of_many(chunks)
             on_message(message)
+            state.count += 1

     def subscribe(
@@
         def run_stream():
             while state.attempt < self._max_attempts and not subscription_handle.is_cancelled():
+                if self._limit > 0 and state.count >= self._limit:
+                    if self._completion_handler:
+                        self._completion_handler()
+                    return
+
                 state.attempt += 1
                 request = self._build_query_request(state)

As per coding guidelines, "Compare the SDK class against the proto schema" and "Ensure Query code remains: Backward-compatible."

Also applies to: 185-207

else:
if self._start_time is not None:
request.consensusStartTime.CopyFrom(self._start_time)
request.limit = self._limit

return request

def _handle_response(self, response, state: SubscriptionState, on_message: Callable[[TopicMessage], None]) -> None:
"""Handles single or chunked messages."""
state.count += 1
state.last_message = response

if not self._chunking_enabled or not response.HasField("chunkInfo") or response.chunkInfo.total <= 1:
message = TopicMessage.of_single(response)
on_message(message)
return

initial_tx_id = TransactionId._from_proto(response.chunkInfo.initialTransactionID)

if initial_tx_id not in state.pending_messages:
state.pending_messages[initial_tx_id] = []

chunks = state.pending_messages[initial_tx_id]
chunks.append(response)

if len(chunks) == response.chunkInfo.total:
del state.pending_messages[initial_tx_id]

message = TopicMessage.of_many(chunks)
on_message(message)

def subscribe(
self,
client: Client,
Expand All @@ -111,49 +218,23 @@ def subscribe(
if not client.mirror_stub:
raise ValueError("Client has no mirror_stub. Did you configure a mirror node address?")

request = mirror_proto.ConsensusTopicQuery(topicID=self._topic_id)
if self._start_time:
request.consensusStartTime.CopyFrom(self._start_time)
if self._end_time:
request.consensusEndTime.CopyFrom(self._end_time)
if self._limit is not None:
request.limit = self._limit

subscription_handle = SubscriptionHandle()

pending_chunks: dict[str, list[mirror_proto.ConsensusTopicResponse]] = {}
state = SubscriptionState()

def run_stream():
attempt = 0
while attempt < self._max_attempts and not subscription_handle.is_cancelled():
while state.attempt < self._max_attempts and not subscription_handle.is_cancelled():
Comment thread
coderabbitai[bot] marked this conversation as resolved.
state.attempt += 1
request = self._build_query_request(state)

try:
message_stream = client.mirror_stub.subscribeTopic(request)
subscription_handle._set_call(message_stream)

for response in message_stream:
if subscription_handle.is_cancelled():
return

if (
not self._chunking_enabled
or not response.HasField("chunkInfo")
or response.chunkInfo.total <= 1
):
msg_obj = TopicMessage.of_single(response)
on_message(msg_obj)
continue

initial_tx_id = TransactionId._from_proto(response.chunkInfo.initialTransactionID)

if initial_tx_id not in pending_chunks:
pending_chunks[initial_tx_id] = []

pending_chunks[initial_tx_id].append(response)

if len(pending_chunks[initial_tx_id]) == response.chunkInfo.total:
chunk_list = pending_chunks.pop(initial_tx_id)

msg_obj = TopicMessage.of_many(chunk_list)
on_message(msg_obj)
self._handle_response(response, state, on_message)

if self._completion_handler:
self._completion_handler()
Expand All @@ -163,13 +244,16 @@ def run_stream():
if subscription_handle.is_cancelled():
return

attempt += 1
if attempt >= self._max_attempts:
if state.attempt >= self._max_attempts or not self._should_retry(e):
if self._error_handler:
self._error_handler(e)
if on_error:
on_error(e)
return

delay = min(0.5 * (2 ** (attempt - 1)), self._max_backoff)
delay = min(0.5 * (2 ** (state.attempt)), self._max_backoff)
logger.warning(f"Error subscribing to topic attempt {state.attempt}. Retrying in {int(delay)}s...")

time.sleep(delay)

thread = threading.Thread(target=run_stream, daemon=True)
Expand Down
19 changes: 17 additions & 2 deletions src/hiero_sdk_python/utils/subscription_handle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import threading
from typing import Any


class SubscriptionHandle:
Expand All @@ -12,11 +13,25 @@ class SubscriptionHandle:

def __init__(self):
self._cancelled = threading.Event()
self._thread = None
self._thread: threading.Thread | None = None
self._call: Any | None = None
self._lock = threading.Lock()

def _set_call(self, call: Any):
"""Sets the active gRPC call so it can be cancelled."""
with self._lock:
self._call = call

if self._cancelled.is_set():
self._call.cancel()
Comment on lines +20 to +26
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Avoid calling cancel() while holding _lock.

Line 26 and Line 34 invoke cancel() inside the critical section. If that call blocks or re-enters related paths, this can stall or deadlock cancellation. Capture the call reference while locked, then cancel outside the lock. Also use is not None at Line 33.

Suggested fix
 def _set_call(self, call: Any):
     """Sets the active gRPC call so it can be cancelled."""
+    call_to_cancel = None
     with self._lock:
         self._call = call
 
         if self._cancelled.is_set():
-            self._call.cancel()
+            call_to_cancel = call
+    if call_to_cancel is not None:
+        call_to_cancel.cancel()
 
 def cancel(self):
     """Signals to cancel the subscription."""
+    call_to_cancel = None
     with self._lock:
         self._cancelled.set()
 
-        if self._call:
-            self._call.cancel()
+        if self._call is not None:
+            call_to_cancel = self._call
+    if call_to_cancel is not None:
+        call_to_cancel.cancel()

Also applies to: 30-34


def cancel(self):
"""Signals to cancel the subscription."""
self._cancelled.set()
with self._lock:
self._cancelled.set()

if self._call:
self._call.cancel()

def is_cancelled(self) -> bool:
"""Returns True if this subscription is already cancelled."""
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/subscription_handle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@


def test_not_cancelled_by_default():
"""Test a new handle starts in a non-cancelled state."""
handle = SubscriptionHandle()
assert not handle.is_cancelled()


def test_cancel_marks_as_cancelled():
"""Test calling cancel updates the is_cancelled status."""
handle = SubscriptionHandle()
handle.cancel()
assert handle.is_cancelled()


def test_set_thread_and_join_calls_thread_join_with_timeout():
"""Test that join correctly forwards the timeout to the underlying thread."""
handle = SubscriptionHandle()
mock_thread = Mock()
handle.set_thread(mock_thread)
Expand All @@ -25,6 +28,25 @@ def test_set_thread_and_join_calls_thread_join_with_timeout():


def test_join_without_thread_raises_nothing():
"""Test join is a no-op if no thread has been associated."""
handle = SubscriptionHandle()
# should not raise
handle.join()


def test_cancel_triggers_grpc_termination():
"""Test that cancelling the handle terminates the active gRPC call."""
handle = SubscriptionHandle()
mock_call = Mock()
handle._set_call(mock_call)
handle.cancel()
mock_call.cancel.assert_called_once()


def test_immediate_cancellation_of_late_call():
"""Test a gRPC call is cancelled immediately if set after the handle was cancelled."""
handle = SubscriptionHandle()
mock_call = Mock()
handle.cancel()
handle._set_call(mock_call)
mock_call.cancel.assert_called_once()
Loading
Loading