66import threading
77import time
88import uuid
9+ from collections .abc import AsyncIterable , Iterable
910from dataclasses import dataclass
1011from datetime import datetime
1112from enum import Enum
12- from typing import Any , Generic , Sequence , TypeVar
13+ from typing import Any , Generic , Protocol , Sequence , TypeVar , cast
1314
1415import grpc
1516import grpc .aio
17+ from google .protobuf import wrappers_pb2
1618
1719import durabletask .history as history
1820from durabletask .entities import EntityInstanceId
@@ -64,8 +66,8 @@ class OrchestrationStatus(Enum):
6466 PENDING = pb .ORCHESTRATION_STATUS_PENDING
6567 SUSPENDED = pb .ORCHESTRATION_STATUS_SUSPENDED
6668
67- def __str__ (self ):
68- return helpers .get_orchestration_status_str (self .value )
69+ def __str__ (self ) -> str :
70+ return cast ( str , helpers .get_orchestration_status_str (self .value ) )
6971
7072
7173@dataclass
@@ -173,6 +175,128 @@ def parse_orchestration_state(state: pb.OrchestrationState) -> OrchestrationStat
173175_RETIRED_CHANNEL_CLOSE_DELAY_SECONDS = 30.0
174176
175177
178+ class _SyncTaskHubSidecarServiceStub (Protocol ):
179+ def StartInstance (self , request : pb .CreateInstanceRequest ) -> pb .CreateInstanceResponse :
180+ ...
181+
182+ def GetInstance (self , request : pb .GetInstanceRequest ) -> pb .GetInstanceResponse :
183+ ...
184+
185+ def StreamInstanceHistory (self , request : pb .StreamInstanceHistoryRequest ) -> Iterable [pb .HistoryChunk ]:
186+ ...
187+
188+ def ListInstanceIds (self , request : pb .ListInstanceIdsRequest ) -> pb .ListInstanceIdsResponse :
189+ ...
190+
191+ def QueryInstances (self , request : pb .QueryInstancesRequest ) -> pb .QueryInstancesResponse :
192+ ...
193+
194+ def WaitForInstanceStart (
195+ self ,
196+ request : pb .GetInstanceRequest ,
197+ * ,
198+ timeout : float | None = None ) -> pb .GetInstanceResponse :
199+ ...
200+
201+ def WaitForInstanceCompletion (
202+ self ,
203+ request : pb .GetInstanceRequest ,
204+ * ,
205+ timeout : float | None = None ) -> pb .GetInstanceResponse :
206+ ...
207+
208+ def RaiseEvent (self , request : pb .RaiseEventRequest ) -> pb .RaiseEventResponse :
209+ ...
210+
211+ def TerminateInstance (self , request : pb .TerminateRequest ) -> pb .TerminateResponse :
212+ ...
213+
214+ def SuspendInstance (self , request : pb .SuspendRequest ) -> pb .SuspendResponse :
215+ ...
216+
217+ def ResumeInstance (self , request : pb .ResumeRequest ) -> pb .ResumeResponse :
218+ ...
219+
220+ def RestartInstance (self , request : pb .RestartInstanceRequest ) -> pb .RestartInstanceResponse :
221+ ...
222+
223+ def PurgeInstances (self , request : pb .PurgeInstancesRequest ) -> pb .PurgeInstancesResponse :
224+ ...
225+
226+ def SignalEntity (self , request : pb .SignalEntityRequest ) -> pb .SignalEntityResponse :
227+ ...
228+
229+ def GetEntity (self , request : pb .GetEntityRequest ) -> pb .GetEntityResponse :
230+ ...
231+
232+ def QueryEntities (self , request : pb .QueryEntitiesRequest ) -> pb .QueryEntitiesResponse :
233+ ...
234+
235+ def CleanEntityStorage (self , request : pb .CleanEntityStorageRequest ) -> pb .CleanEntityStorageResponse :
236+ ...
237+
238+
239+ class _AsyncTaskHubSidecarServiceStub (Protocol ):
240+ async def StartInstance (self , request : pb .CreateInstanceRequest ) -> pb .CreateInstanceResponse :
241+ ...
242+
243+ async def GetInstance (self , request : pb .GetInstanceRequest ) -> pb .GetInstanceResponse :
244+ ...
245+
246+ def StreamInstanceHistory (self , request : pb .StreamInstanceHistoryRequest ) -> AsyncIterable [pb .HistoryChunk ]:
247+ ...
248+
249+ async def ListInstanceIds (self , request : pb .ListInstanceIdsRequest ) -> pb .ListInstanceIdsResponse :
250+ ...
251+
252+ async def QueryInstances (self , request : pb .QueryInstancesRequest ) -> pb .QueryInstancesResponse :
253+ ...
254+
255+ async def WaitForInstanceStart (
256+ self ,
257+ request : pb .GetInstanceRequest ,
258+ * ,
259+ timeout : float | None = None ) -> pb .GetInstanceResponse :
260+ ...
261+
262+ async def WaitForInstanceCompletion (
263+ self ,
264+ request : pb .GetInstanceRequest ,
265+ * ,
266+ timeout : float | None = None ) -> pb .GetInstanceResponse :
267+ ...
268+
269+ async def RaiseEvent (self , request : pb .RaiseEventRequest ) -> pb .RaiseEventResponse :
270+ ...
271+
272+ async def TerminateInstance (self , request : pb .TerminateRequest ) -> pb .TerminateResponse :
273+ ...
274+
275+ async def SuspendInstance (self , request : pb .SuspendRequest ) -> pb .SuspendResponse :
276+ ...
277+
278+ async def ResumeInstance (self , request : pb .ResumeRequest ) -> pb .ResumeResponse :
279+ ...
280+
281+ async def RestartInstance (self , request : pb .RestartInstanceRequest ) -> pb .RestartInstanceResponse :
282+ ...
283+
284+ async def PurgeInstances (self , request : pb .PurgeInstancesRequest ) -> pb .PurgeInstancesResponse :
285+ ...
286+
287+ async def SignalEntity (self , request : pb .SignalEntityRequest ) -> pb .SignalEntityResponse :
288+ ...
289+
290+ async def GetEntity (self , request : pb .GetEntityRequest ) -> pb .GetEntityResponse :
291+ ...
292+
293+ async def QueryEntities (self , request : pb .QueryEntitiesRequest ) -> pb .QueryEntitiesResponse :
294+ ...
295+
296+ async def CleanEntityStorage (self , request : pb .CleanEntityStorageRequest ) -> pb .CleanEntityStorageResponse :
297+ ...
298+
299+
176300class TaskHubGrpcClient :
177301 def __init__ (self , * ,
178302 host_address : str | None = None ,
@@ -245,7 +369,7 @@ def __init__(self, *,
245369 # observable effect. Callers wanting resiliency on a custom channel
246370 # can prepend the interceptor themselves via grpc.intercept_channel.
247371 self ._channel = channel
248- self ._stub = stubs .TaskHubSidecarServiceStub (channel )
372+ self ._stub = cast ( _SyncTaskHubSidecarServiceStub , stubs .TaskHubSidecarServiceStub (channel ) )
249373 self ._logger = shared .get_logger ("client" , log_handler , log_formatter )
250374 self .default_version = default_version
251375 self ._payload_store = payload_store
@@ -322,7 +446,7 @@ def _maybe_recreate_channel(self) -> None:
322446 interceptors = self ._interceptors ,
323447 channel_options = self ._channel_options ,
324448 )
325- self ._stub = stubs .TaskHubSidecarServiceStub (self ._channel )
449+ self ._stub = cast ( _SyncTaskHubSidecarServiceStub , stubs .TaskHubSidecarServiceStub (self ._channel ) )
326450 self ._last_recreate_time = now
327451 self ._client_failure_tracker .record_success ()
328452 close_timer = threading .Timer (
@@ -459,11 +583,11 @@ def get_all_orchestration_states(self,
459583 ) -> list [OrchestrationState ]:
460584 if orchestration_query is None :
461585 orchestration_query = OrchestrationQuery ()
462- _continuation_token = None
586+ _continuation_token : wrappers_pb2 . StringValue | None = None
463587
464588 self ._logger .info (f"Querying orchestration instances with query: { orchestration_query } " )
465589
466- states = []
590+ states : list [ OrchestrationState ] = []
467591
468592 while True :
469593 req = build_query_instances_req (orchestration_query , _continuation_token )
@@ -621,11 +745,11 @@ def get_all_entities(self,
621745 entity_query : EntityQuery | None = None ) -> list [EntityMetadata ]:
622746 if entity_query is None :
623747 entity_query = EntityQuery ()
624- _continuation_token = None
748+ _continuation_token : wrappers_pb2 . StringValue | None = None
625749
626750 self ._logger .info (f"Retrieving entities by filter: { entity_query } " )
627751
628- entities = []
752+ entities : list [ EntityMetadata ] = []
629753
630754 while True :
631755 query_request = build_query_entities_req (entity_query , _continuation_token )
@@ -647,7 +771,7 @@ def clean_entity_storage(self,
647771
648772 empty_entities_removed = 0
649773 orphaned_locks_released = 0
650- _continuation_token = None
774+ _continuation_token : wrappers_pb2 . StringValue | None = None
651775
652776 while True :
653777 req = pb .CleanEntityStorageRequest (
@@ -741,7 +865,7 @@ def __init__(self, *,
741865 # leave the failure-tracking opt-out implicit: callers wanting full
742866 # resiliency should let us create the channel.
743867 self ._channel = channel
744- self ._stub = stubs .TaskHubSidecarServiceStub (channel )
868+ self ._stub = cast ( _AsyncTaskHubSidecarServiceStub , stubs .TaskHubSidecarServiceStub (channel ) )
745869 self ._logger = shared .get_logger ("async_client" , log_handler , log_formatter )
746870 self .default_version = default_version
747871 self ._payload_store = payload_store
@@ -839,7 +963,7 @@ async def _maybe_recreate_channel(self) -> None:
839963 interceptors = self ._interceptors ,
840964 channel_options = self ._channel_options ,
841965 )
842- self ._stub = stubs .TaskHubSidecarServiceStub (self ._channel )
966+ self ._stub = cast ( _AsyncTaskHubSidecarServiceStub , stubs .TaskHubSidecarServiceStub (self ._channel ) )
843967 self ._last_recreate_time = now
844968 self ._client_failure_tracker .record_success ()
845969 self ._retired_channels .append (old_channel )
@@ -940,11 +1064,11 @@ async def get_all_orchestration_states(self,
9401064 ) -> list [OrchestrationState ]:
9411065 if orchestration_query is None :
9421066 orchestration_query = OrchestrationQuery ()
943- _continuation_token = None
1067+ _continuation_token : wrappers_pb2 . StringValue | None = None
9441068
9451069 self ._logger .info (f"Querying orchestration instances with query: { orchestration_query } " )
9461070
947- states = []
1071+ states : list [ OrchestrationState ] = []
9481072
9491073 while True :
9501074 req = build_query_instances_req (orchestration_query , _continuation_token )
@@ -1101,11 +1225,11 @@ async def get_all_entities(self,
11011225 entity_query : EntityQuery | None = None ) -> list [EntityMetadata ]:
11021226 if entity_query is None :
11031227 entity_query = EntityQuery ()
1104- _continuation_token = None
1228+ _continuation_token : wrappers_pb2 . StringValue | None = None
11051229
11061230 self ._logger .info (f"Retrieving entities by filter: { entity_query } " )
11071231
1108- entities = []
1232+ entities : list [ EntityMetadata ] = []
11091233
11101234 while True :
11111235 query_request = build_query_entities_req (entity_query , _continuation_token )
@@ -1127,7 +1251,7 @@ async def clean_entity_storage(self,
11271251
11281252 empty_entities_removed = 0
11291253 orphaned_locks_released = 0
1130- _continuation_token = None
1254+ _continuation_token : wrappers_pb2 . StringValue | None = None
11311255
11321256 while True :
11331257 req = pb .CleanEntityStorageRequest (
0 commit comments