-
Notifications
You must be signed in to change notification settings - Fork 283
feat: Implement stateful retry and resumable stream logic for TopicMessageQuery #2171
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
197289a
a2963eb
a1b2318
47ca92b
57c6f7e
6073c05
bb49b14
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,15 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
| import re | ||
| import threading | ||
| import time | ||
| from collections.abc import Callable | ||
| from dataclasses import dataclass, field | ||
| from datetime import datetime | ||
|
|
||
| import grpc | ||
|
|
||
| from hiero_sdk_python.client.client import Client | ||
| from hiero_sdk_python.consensus.topic_id import TopicId | ||
| from hiero_sdk_python.consensus.topic_message import TopicMessage | ||
|
|
@@ -14,6 +19,19 @@ | |
| from hiero_sdk_python.utils.subscription_handle import SubscriptionHandle | ||
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| RST_STREAM = re.compile(r"\brst[^0-9a-zA-Z]stream\b", re.IGNORECASE | re.DOTALL) | ||
|
|
||
|
|
||
| @dataclass | ||
| class SubscriptionState: | ||
| attempt: int = 0 | ||
| count: int = 0 | ||
| last_message: mirror_proto.ConsensusTopicResponse | None = None | ||
| pending_messages: dict[TransactionId, list[mirror_proto.ConsensusTopicResponse]] = field(default_factory=dict) | ||
|
|
||
|
|
||
| class TopicMessageQuery: | ||
| """ | ||
| A query to subscribe to messages from a specific HCS topic, via a mirror node. | ||
|
|
@@ -31,48 +49,49 @@ def __init__( | |
| chunking_enabled: bool = False, | ||
| ) -> None: | ||
| """Initializes a TopicMessageQuery.""" | ||
| self._topic_id: TopicId | None = self._parse_topic_id(topic_id) if topic_id else None | ||
| self._topic_id: basic_types_pb2.TopicID | None = self._parse_topic_id(topic_id) if topic_id else None | ||
| self._start_time: timestamp_pb2.Timestamp | None = self._parse_timestamp(start_time) if start_time else None | ||
| self._end_time: timestamp_pb2.Timestamp | None = self._parse_timestamp(end_time) if end_time else None | ||
| self._limit: int | None = limit | ||
| self._limit: int = limit if limit is not None else 0 | ||
| self._chunking_enabled: bool = chunking_enabled | ||
| self._completion_handler: Callable[[], None] | None = None | ||
|
|
||
| self._max_attempts: int = 10 | ||
| self._max_backoff: float = 8.0 | ||
|
|
||
| self._completion_handler: Callable[[], None] | None = self._on_complete | ||
| self._error_handler: Callable[[], None] | None = self._on_error | ||
|
|
||
| def set_max_attempts(self, attempts: int) -> TopicMessageQuery: | ||
| """Sets the maximum number of attempts to reconnect on failure.""" | ||
| if attempts <= 0: | ||
| raise ValueError("max_attempts must be greater than 0") | ||
|
|
||
| self._max_attempts = attempts | ||
| return self | ||
|
|
||
| def set_max_backoff(self, backoff: float) -> TopicMessageQuery: | ||
| """Sets the maximum backoff time in seconds for reconnection attempts.""" | ||
| if backoff < 0.5: | ||
| raise ValueError("max_backoff must be at least 500 ms") | ||
|
|
||
| self._max_backoff = backoff | ||
| return self | ||
|
|
||
| def set_completion_handler(self, handler: Callable[[], None]) -> TopicMessageQuery: | ||
| """Sets a completion handler that is called when the subscription completes.""" | ||
| if not callable(handler): | ||
| raise TypeError("handler must be a callable object") | ||
|
|
||
| self._completion_handler = handler | ||
| return self | ||
|
|
||
| def _parse_topic_id(self, topic_id: str | TopicId) -> basic_types_pb2.TopicID: | ||
| """Parses a topic ID from a string or TopicId object into a protobuf TopicID.""" | ||
| if isinstance(topic_id, str): | ||
| parts = topic_id.strip().split(".") | ||
| if len(parts) != 3: | ||
| raise ValueError(f"Invalid topic ID string: {topic_id}") | ||
| shard, realm, topic = map(int, parts) | ||
| return basic_types_pb2.TopicID(shardNum=shard, realmNum=realm, topicNum=topic) | ||
| if isinstance(topic_id, TopicId): | ||
| return topic_id._to_proto() | ||
| raise TypeError("Invalid topic_id format. Must be a string or TopicId.") | ||
| def set_error_handler(self, handler: Callable[[Exception], None]) -> TopicMessageQuery: | ||
| """Sets an error handler that is called when the subscription causes an error.""" | ||
| if not callable(handler): | ||
| raise TypeError("handler must be a callable object") | ||
|
|
||
| def _parse_timestamp(self, dt: datetime) -> timestamp_pb2.Timestamp: | ||
| """Converts a datetime object to a protobuf Timestamp.""" | ||
| seconds = int(dt.timestamp()) | ||
| nanos = int((dt.timestamp() - seconds) * 1e9) | ||
| return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) | ||
| self._error_handler = handler | ||
| return self | ||
|
|
||
| def set_topic_id(self, topic_id: str | TopicId) -> TopicMessageQuery: | ||
| """Sets the topic ID for the query.""" | ||
|
|
@@ -99,6 +118,94 @@ def set_chunking_enabled(self, enabled: bool) -> TopicMessageQuery: | |
| self._chunking_enabled = enabled | ||
| return self | ||
|
|
||
| def _on_complete(self) -> None: | ||
| logger.info(f"Subscription to topic {self._topic_id} complete") | ||
|
|
||
| def _on_error(self, err: Exception) -> None: | ||
| if isinstance(err, grpc.RpcError) and err.code() == grpc.StatusCode.CANCELLED: | ||
| logger.warning(f"Call is cancelled for topic {self._topic_id}") | ||
| else: | ||
| logger.error(f"Error attempting to subscribe to topic {self._topic_id}: {err}") | ||
|
|
||
| def _should_retry(self, err: Exception) -> bool: | ||
| if isinstance(err, grpc.RpcError): | ||
| return err.code() in ( | ||
| grpc.StatusCode.NOT_FOUND, | ||
| grpc.StatusCode.UNAVAILABLE, | ||
| grpc.StatusCode.RESOURCE_EXHAUSTED, | ||
| ) or (err.code() == grpc.StatusCode.INTERNAL and bool(RST_STREAM.search(err.details()))) | ||
|
|
||
| return True | ||
|
manishdait marked this conversation as resolved.
|
||
|
|
||
| def _parse_topic_id(self, topic_id: str | TopicId) -> basic_types_pb2.TopicID: | ||
| """Parses a topic ID from a string or TopicId object into a protobuf TopicID.""" | ||
| if isinstance(topic_id, str): | ||
| topic_id = TopicId.from_string(topic_id) | ||
|
|
||
| if isinstance(topic_id, TopicId): | ||
| return topic_id._to_proto() | ||
|
|
||
| raise TypeError("Invalid topic_id format. Must be a string or TopicId.") | ||
|
|
||
| def _parse_timestamp(self, dt: datetime) -> timestamp_pb2.Timestamp: | ||
| """Converts a datetime object to a protobuf Timestamp.""" | ||
| seconds = int(dt.timestamp()) | ||
| nanos = int((dt.timestamp() - seconds) * 1e9) | ||
| return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) | ||
|
|
||
| def _build_query_request(self, state: SubscriptionState) -> mirror_proto.ConsensusTopicQuery: | ||
| """Build the request object based on current subscription state.""" | ||
| request = mirror_proto.ConsensusTopicQuery(topicID=self._topic_id) | ||
|
|
||
| if self._end_time is not None: | ||
| request.consensusEndTime.CopyFrom(self._end_time) | ||
|
|
||
| if state.last_message is not None: | ||
| last_message_time = state.last_message.consensusTimestamp | ||
|
|
||
| seconds = last_message_time.seconds | ||
| nanos = last_message_time.nanos + 1 | ||
|
|
||
| if nanos >= 1_000_000_000: | ||
| seconds += 1 | ||
| nanos = 0 | ||
|
|
||
| request.consensusStartTime.seconds = seconds | ||
| request.consensusStartTime.nanos = nanos | ||
|
|
||
| if self._limit > 0: | ||
| request.limit = max(0, self._limit - state.count) | ||
|
Comment on lines
+176
to
+177
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Count emitted
🐛 Proposed fix def _build_query_request(self, state: SubscriptionState) -> mirror_proto.ConsensusTopicQuery:
"""Build the request object based on current subscription state."""
request = mirror_proto.ConsensusTopicQuery(topicID=self._topic_id)
if self._end_time is not None:
request.consensusEndTime.CopyFrom(self._end_time)
if state.last_message is not None:
last_message_time = state.last_message.consensusTimestamp
@@
request.consensusStartTime.seconds = seconds
request.consensusStartTime.nanos = nanos
if self._limit > 0:
- request.limit = max(0, self._limit - state.count)
+ request.limit = self._limit - state.count
else:
if self._start_time is not None:
request.consensusStartTime.CopyFrom(self._start_time)
request.limit = self._limit
return request
def _handle_response(self, response, state: SubscriptionState, on_message: Callable[[TopicMessage], None]) -> None:
"""Handles single or chunked messages."""
- state.count += 1
state.last_message = response
if not self._chunking_enabled or not response.HasField("chunkInfo") or response.chunkInfo.total <= 1:
message = TopicMessage.of_single(response)
on_message(message)
+ state.count += 1
return
@@
if len(chunks) == response.chunkInfo.total:
del state.pending_messages[initial_tx_id]
message = TopicMessage.of_many(chunks)
on_message(message)
+ state.count += 1
def subscribe(
@@
def run_stream():
while state.attempt < self._max_attempts and not subscription_handle.is_cancelled():
+ if self._limit > 0 and state.count >= self._limit:
+ if self._completion_handler:
+ self._completion_handler()
+ return
+
state.attempt += 1
request = self._build_query_request(state)As per coding guidelines, "Compare the SDK class against the proto schema" and "Ensure Query code remains: Backward-compatible." Also applies to: 185-207 |
||
| else: | ||
| if self._start_time is not None: | ||
| request.consensusStartTime.CopyFrom(self._start_time) | ||
| request.limit = self._limit | ||
|
|
||
| return request | ||
|
|
||
| def _handle_response(self, response, state: SubscriptionState, on_message: Callable[[TopicMessage], None]) -> None: | ||
| """Handles single or chunked messages.""" | ||
| state.count += 1 | ||
| state.last_message = response | ||
|
|
||
| if not self._chunking_enabled or not response.HasField("chunkInfo") or response.chunkInfo.total <= 1: | ||
| message = TopicMessage.of_single(response) | ||
| on_message(message) | ||
| return | ||
|
|
||
| initial_tx_id = TransactionId._from_proto(response.chunkInfo.initialTransactionID) | ||
|
|
||
| if initial_tx_id not in state.pending_messages: | ||
| state.pending_messages[initial_tx_id] = [] | ||
|
|
||
| chunks = state.pending_messages[initial_tx_id] | ||
| chunks.append(response) | ||
|
|
||
| if len(chunks) == response.chunkInfo.total: | ||
| del state.pending_messages[initial_tx_id] | ||
|
|
||
| message = TopicMessage.of_many(chunks) | ||
| on_message(message) | ||
|
|
||
| def subscribe( | ||
| self, | ||
| client: Client, | ||
|
|
@@ -111,49 +218,23 @@ def subscribe( | |
| if not client.mirror_stub: | ||
| raise ValueError("Client has no mirror_stub. Did you configure a mirror node address?") | ||
|
|
||
| request = mirror_proto.ConsensusTopicQuery(topicID=self._topic_id) | ||
| if self._start_time: | ||
| request.consensusStartTime.CopyFrom(self._start_time) | ||
| if self._end_time: | ||
| request.consensusEndTime.CopyFrom(self._end_time) | ||
| if self._limit is not None: | ||
| request.limit = self._limit | ||
|
|
||
| subscription_handle = SubscriptionHandle() | ||
|
|
||
| pending_chunks: dict[str, list[mirror_proto.ConsensusTopicResponse]] = {} | ||
| state = SubscriptionState() | ||
|
|
||
| def run_stream(): | ||
| attempt = 0 | ||
| while attempt < self._max_attempts and not subscription_handle.is_cancelled(): | ||
| while state.attempt < self._max_attempts and not subscription_handle.is_cancelled(): | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| state.attempt += 1 | ||
| request = self._build_query_request(state) | ||
|
|
||
| try: | ||
| message_stream = client.mirror_stub.subscribeTopic(request) | ||
| subscription_handle._set_call(message_stream) | ||
|
|
||
| for response in message_stream: | ||
| if subscription_handle.is_cancelled(): | ||
| return | ||
|
|
||
| if ( | ||
| not self._chunking_enabled | ||
| or not response.HasField("chunkInfo") | ||
| or response.chunkInfo.total <= 1 | ||
| ): | ||
| msg_obj = TopicMessage.of_single(response) | ||
| on_message(msg_obj) | ||
| continue | ||
|
|
||
| initial_tx_id = TransactionId._from_proto(response.chunkInfo.initialTransactionID) | ||
|
|
||
| if initial_tx_id not in pending_chunks: | ||
| pending_chunks[initial_tx_id] = [] | ||
|
|
||
| pending_chunks[initial_tx_id].append(response) | ||
|
|
||
| if len(pending_chunks[initial_tx_id]) == response.chunkInfo.total: | ||
| chunk_list = pending_chunks.pop(initial_tx_id) | ||
|
|
||
| msg_obj = TopicMessage.of_many(chunk_list) | ||
| on_message(msg_obj) | ||
| self._handle_response(response, state, on_message) | ||
|
|
||
| if self._completion_handler: | ||
| self._completion_handler() | ||
|
|
@@ -163,13 +244,16 @@ def run_stream(): | |
| if subscription_handle.is_cancelled(): | ||
| return | ||
|
|
||
| attempt += 1 | ||
| if attempt >= self._max_attempts: | ||
| if state.attempt >= self._max_attempts or not self._should_retry(e): | ||
| if self._error_handler: | ||
| self._error_handler(e) | ||
| if on_error: | ||
| on_error(e) | ||
| return | ||
|
|
||
| delay = min(0.5 * (2 ** (attempt - 1)), self._max_backoff) | ||
| delay = min(0.5 * (2 ** (state.attempt)), self._max_backoff) | ||
| logger.warning(f"Error subscribing to topic attempt {state.attempt}. Retrying in {int(delay)}s...") | ||
|
|
||
| time.sleep(delay) | ||
|
|
||
| thread = threading.Thread(target=run_stream, daemon=True) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import threading | ||
| from typing import Any | ||
|
|
||
|
|
||
| class SubscriptionHandle: | ||
|
|
@@ -12,11 +13,25 @@ class SubscriptionHandle: | |
|
|
||
| def __init__(self): | ||
| self._cancelled = threading.Event() | ||
| self._thread = None | ||
| self._thread: threading.Thread | None = None | ||
| self._call: Any | None = None | ||
| self._lock = threading.Lock() | ||
|
|
||
| def _set_call(self, call: Any): | ||
| """Sets the active gRPC call so it can be cancelled.""" | ||
| with self._lock: | ||
| self._call = call | ||
|
|
||
| if self._cancelled.is_set(): | ||
| self._call.cancel() | ||
|
Comment on lines
+20
to
+26
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid calling Line 26 and Line 34 invoke Suggested fix def _set_call(self, call: Any):
"""Sets the active gRPC call so it can be cancelled."""
+ call_to_cancel = None
with self._lock:
self._call = call
if self._cancelled.is_set():
- self._call.cancel()
+ call_to_cancel = call
+ if call_to_cancel is not None:
+ call_to_cancel.cancel()
def cancel(self):
"""Signals to cancel the subscription."""
+ call_to_cancel = None
with self._lock:
self._cancelled.set()
- if self._call:
- self._call.cancel()
+ if self._call is not None:
+ call_to_cancel = self._call
+ if call_to_cancel is not None:
+ call_to_cancel.cancel()Also applies to: 30-34 |
||
|
|
||
| def cancel(self): | ||
| """Signals to cancel the subscription.""" | ||
| self._cancelled.set() | ||
| with self._lock: | ||
| self._cancelled.set() | ||
|
|
||
| if self._call: | ||
| self._call.cancel() | ||
|
|
||
| def is_cancelled(self) -> bool: | ||
| """Returns True if this subscription is already cancelled.""" | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.