Skip to content

Commit 5b8e0c8

Browse files
Bernd VerstCopilot
andcommitted
Add pyright strict type-check CI workflow
- New .github/workflows/typecheck.yml runs pyright in strict mode on Python 3.10 (lowest supported) across durabletask and durabletask-azuremanaged for PRs and pushes to main. - Add pyrightconfig.json at repo root (strict, Python 3.10, excludes generated protobuf/gRPC files). - Add pyright to dev-requirements.txt. - Clean up 1598 strict-mode type errors across the SDK while preserving runtime behavior. Changes are purely additive type annotations, casts, and targeted `# pyright: ignore` comments scoped to specific rules. - Address related typing issues: - #93: OrchestrationContext.create_timer now returns TimerTask (was CancellableTask). - #94: WhenAnyTask is now generic; when_any(tasks: Sequence[Task[T]]) returns WhenAnyTask[T], so the completing child Task[T] is statically typed. - #92: Broad improvements to generic type-safety hints. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 1cb30b0 commit 5b8e0c8

27 files changed

Lines changed: 827 additions & 459 deletions

.github/workflows/typecheck.yml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
name: Type Check (pyright)
2+
3+
on:
4+
push:
5+
branches:
6+
- "main"
7+
tags:
8+
- "v*"
9+
- "azuremanaged-v*"
10+
pull_request:
11+
branches:
12+
- "main"
13+
14+
permissions:
15+
contents: read
16+
17+
jobs:
18+
pyright:
19+
runs-on: ubuntu-latest
20+
steps:
21+
- name: Checkout repository
22+
uses: actions/checkout@v4
23+
24+
- name: Set up Python 3.10 (lowest supported)
25+
uses: actions/setup-python@v5
26+
with:
27+
python-version: "3.10"
28+
29+
- name: Install packages and dependencies
30+
run: |
31+
python -m pip install --upgrade pip
32+
pip install -r requirements.txt
33+
pip install -e ".[azure-blob-payloads,opentelemetry]"
34+
pip install -e ./durabletask-azuremanaged
35+
pip install pyright
36+
37+
- name: Run pyright (strict, Python 3.10)
38+
run: pyright

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,17 @@ ADDED
1717
resiliency-aware teardown introduced in v1.5.0 (in-flight recreate
1818
thread join, retired-channel timer cancellation, and SDK-owned channel
1919
cleanup) runs unchanged through the new `with` path.
20+
- Added a pyright type-check CI workflow that runs on pull requests and pushes
21+
to `main`, using strict mode against the lowest supported Python version
22+
(3.10) across both `durabletask` and `durabletask-azuremanaged` packages.
23+
- Improved type coverage across the public API. `OrchestrationContext.create_timer`
24+
now returns the specific `TimerTask` type (previously `CancellableTask`)
25+
([#93](https://github.com/microsoft/durabletask-python/issues/93)), and
26+
`WhenAnyTask` is now generic with `when_any(tasks: Sequence[Task[T]]) -> WhenAnyTask[T]`
27+
for better static type inference of the completing child task
28+
([#94](https://github.com/microsoft/durabletask-python/issues/94)).
29+
These changes also broadly improve generic type-safety hints throughout the
30+
SDK ([#92](https://github.com/microsoft/durabletask-python/issues/92)).
2031

2132
## v1.5.0
2233

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
grpcio-tools
22
pymarkdownlnt
3+
pyright

durabletask-azuremanaged/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88
## Unreleased
99

10+
- Improved type coverage benefits Azure Managed users: `create_timer` now
11+
returns the specific `TimerTask` type and `when_any` is generic so the
12+
completing child task is type-checked through `DurableTaskSchedulerClient`,
13+
`AsyncDurableTaskSchedulerClient`, and `DurableTaskSchedulerWorker` derived
14+
orchestrations.
15+
1016
## v1.5.0
1117

1218
- Updates base dependency to durabletask v1.5.0

durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
from durabletask.internal.grpc_interceptor import (
1515
DefaultAsyncClientInterceptorImpl,
1616
DefaultClientInterceptorImpl,
17-
_AsyncClientCallDetails,
18-
_ClientCallDetails,
1917
)
2018

2119

@@ -62,7 +60,7 @@ def _upsert_authorization_header(self, token: str) -> None:
6260
self._metadata.append(("authorization", f"Bearer {token}"))
6361

6462
def _intercept_call(
65-
self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
63+
self, client_call_details: grpc.ClientCallDetails) -> grpc.ClientCallDetails:
6664
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
6765
call details."""
6866
# Refresh the auth token if a credential was provided. The call to
@@ -114,7 +112,7 @@ def _upsert_authorization_header(self, token: str) -> None:
114112
self._metadata.append(("authorization", f"Bearer {token}"))
115113

116114
async def _intercept_call(
117-
self, client_call_details: _AsyncClientCallDetails) -> grpc.aio.ClientCallDetails:
115+
self, client_call_details: grpc.aio.ClientCallDetails) -> grpc.aio.ClientCallDetails:
118116
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
119117
call details."""
120118
# Refresh the auth token if a credential was provided. The call to

durabletask/client.py

Lines changed: 141 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
import threading
77
import time
88
import uuid
9+
from collections.abc import AsyncIterable, Iterable
910
from dataclasses import dataclass
1011
from datetime import datetime
1112
from enum import Enum
12-
from typing import Any, Generic, Sequence, TypeVar
13+
from typing import Any, Generic, Protocol, Sequence, TypeVar, cast
1314

1415
import grpc
1516
import grpc.aio
17+
from google.protobuf import wrappers_pb2
1618

1719
import durabletask.history as history
1820
from durabletask.entities import EntityInstanceId
@@ -64,8 +66,8 @@ class OrchestrationStatus(Enum):
6466
PENDING = pb.ORCHESTRATION_STATUS_PENDING
6567
SUSPENDED = pb.ORCHESTRATION_STATUS_SUSPENDED
6668

67-
def __str__(self):
68-
return helpers.get_orchestration_status_str(self.value)
69+
def __str__(self) -> str:
70+
return cast(str, helpers.get_orchestration_status_str(self.value))
6971

7072

7173
@dataclass
@@ -173,6 +175,128 @@ def parse_orchestration_state(state: pb.OrchestrationState) -> OrchestrationStat
173175
_RETIRED_CHANNEL_CLOSE_DELAY_SECONDS = 30.0
174176

175177

178+
class _SyncTaskHubSidecarServiceStub(Protocol):
179+
def StartInstance(self, request: pb.CreateInstanceRequest) -> pb.CreateInstanceResponse:
180+
...
181+
182+
def GetInstance(self, request: pb.GetInstanceRequest) -> pb.GetInstanceResponse:
183+
...
184+
185+
def StreamInstanceHistory(self, request: pb.StreamInstanceHistoryRequest) -> Iterable[pb.HistoryChunk]:
186+
...
187+
188+
def ListInstanceIds(self, request: pb.ListInstanceIdsRequest) -> pb.ListInstanceIdsResponse:
189+
...
190+
191+
def QueryInstances(self, request: pb.QueryInstancesRequest) -> pb.QueryInstancesResponse:
192+
...
193+
194+
def WaitForInstanceStart(
195+
self,
196+
request: pb.GetInstanceRequest,
197+
*,
198+
timeout: float | None = None) -> pb.GetInstanceResponse:
199+
...
200+
201+
def WaitForInstanceCompletion(
202+
self,
203+
request: pb.GetInstanceRequest,
204+
*,
205+
timeout: float | None = None) -> pb.GetInstanceResponse:
206+
...
207+
208+
def RaiseEvent(self, request: pb.RaiseEventRequest) -> pb.RaiseEventResponse:
209+
...
210+
211+
def TerminateInstance(self, request: pb.TerminateRequest) -> pb.TerminateResponse:
212+
...
213+
214+
def SuspendInstance(self, request: pb.SuspendRequest) -> pb.SuspendResponse:
215+
...
216+
217+
def ResumeInstance(self, request: pb.ResumeRequest) -> pb.ResumeResponse:
218+
...
219+
220+
def RestartInstance(self, request: pb.RestartInstanceRequest) -> pb.RestartInstanceResponse:
221+
...
222+
223+
def PurgeInstances(self, request: pb.PurgeInstancesRequest) -> pb.PurgeInstancesResponse:
224+
...
225+
226+
def SignalEntity(self, request: pb.SignalEntityRequest) -> pb.SignalEntityResponse:
227+
...
228+
229+
def GetEntity(self, request: pb.GetEntityRequest) -> pb.GetEntityResponse:
230+
...
231+
232+
def QueryEntities(self, request: pb.QueryEntitiesRequest) -> pb.QueryEntitiesResponse:
233+
...
234+
235+
def CleanEntityStorage(self, request: pb.CleanEntityStorageRequest) -> pb.CleanEntityStorageResponse:
236+
...
237+
238+
239+
class _AsyncTaskHubSidecarServiceStub(Protocol):
240+
async def StartInstance(self, request: pb.CreateInstanceRequest) -> pb.CreateInstanceResponse:
241+
...
242+
243+
async def GetInstance(self, request: pb.GetInstanceRequest) -> pb.GetInstanceResponse:
244+
...
245+
246+
def StreamInstanceHistory(self, request: pb.StreamInstanceHistoryRequest) -> AsyncIterable[pb.HistoryChunk]:
247+
...
248+
249+
async def ListInstanceIds(self, request: pb.ListInstanceIdsRequest) -> pb.ListInstanceIdsResponse:
250+
...
251+
252+
async def QueryInstances(self, request: pb.QueryInstancesRequest) -> pb.QueryInstancesResponse:
253+
...
254+
255+
async def WaitForInstanceStart(
256+
self,
257+
request: pb.GetInstanceRequest,
258+
*,
259+
timeout: float | None = None) -> pb.GetInstanceResponse:
260+
...
261+
262+
async def WaitForInstanceCompletion(
263+
self,
264+
request: pb.GetInstanceRequest,
265+
*,
266+
timeout: float | None = None) -> pb.GetInstanceResponse:
267+
...
268+
269+
async def RaiseEvent(self, request: pb.RaiseEventRequest) -> pb.RaiseEventResponse:
270+
...
271+
272+
async def TerminateInstance(self, request: pb.TerminateRequest) -> pb.TerminateResponse:
273+
...
274+
275+
async def SuspendInstance(self, request: pb.SuspendRequest) -> pb.SuspendResponse:
276+
...
277+
278+
async def ResumeInstance(self, request: pb.ResumeRequest) -> pb.ResumeResponse:
279+
...
280+
281+
async def RestartInstance(self, request: pb.RestartInstanceRequest) -> pb.RestartInstanceResponse:
282+
...
283+
284+
async def PurgeInstances(self, request: pb.PurgeInstancesRequest) -> pb.PurgeInstancesResponse:
285+
...
286+
287+
async def SignalEntity(self, request: pb.SignalEntityRequest) -> pb.SignalEntityResponse:
288+
...
289+
290+
async def GetEntity(self, request: pb.GetEntityRequest) -> pb.GetEntityResponse:
291+
...
292+
293+
async def QueryEntities(self, request: pb.QueryEntitiesRequest) -> pb.QueryEntitiesResponse:
294+
...
295+
296+
async def CleanEntityStorage(self, request: pb.CleanEntityStorageRequest) -> pb.CleanEntityStorageResponse:
297+
...
298+
299+
176300
class TaskHubGrpcClient:
177301
def __init__(self, *,
178302
host_address: str | None = None,
@@ -245,7 +369,7 @@ def __init__(self, *,
245369
# observable effect. Callers wanting resiliency on a custom channel
246370
# can prepend the interceptor themselves via grpc.intercept_channel.
247371
self._channel = channel
248-
self._stub = stubs.TaskHubSidecarServiceStub(channel)
372+
self._stub = cast(_SyncTaskHubSidecarServiceStub, stubs.TaskHubSidecarServiceStub(channel))
249373
self._logger = shared.get_logger("client", log_handler, log_formatter)
250374
self.default_version = default_version
251375
self._payload_store = payload_store
@@ -322,7 +446,7 @@ def _maybe_recreate_channel(self) -> None:
322446
interceptors=self._interceptors,
323447
channel_options=self._channel_options,
324448
)
325-
self._stub = stubs.TaskHubSidecarServiceStub(self._channel)
449+
self._stub = cast(_SyncTaskHubSidecarServiceStub, stubs.TaskHubSidecarServiceStub(self._channel))
326450
self._last_recreate_time = now
327451
self._client_failure_tracker.record_success()
328452
close_timer = threading.Timer(
@@ -459,11 +583,11 @@ def get_all_orchestration_states(self,
459583
) -> list[OrchestrationState]:
460584
if orchestration_query is None:
461585
orchestration_query = OrchestrationQuery()
462-
_continuation_token = None
586+
_continuation_token: wrappers_pb2.StringValue | None = None
463587

464588
self._logger.info(f"Querying orchestration instances with query: {orchestration_query}")
465589

466-
states = []
590+
states: list[OrchestrationState] = []
467591

468592
while True:
469593
req = build_query_instances_req(orchestration_query, _continuation_token)
@@ -621,11 +745,11 @@ def get_all_entities(self,
621745
entity_query: EntityQuery | None = None) -> list[EntityMetadata]:
622746
if entity_query is None:
623747
entity_query = EntityQuery()
624-
_continuation_token = None
748+
_continuation_token: wrappers_pb2.StringValue | None = None
625749

626750
self._logger.info(f"Retrieving entities by filter: {entity_query}")
627751

628-
entities = []
752+
entities: list[EntityMetadata] = []
629753

630754
while True:
631755
query_request = build_query_entities_req(entity_query, _continuation_token)
@@ -647,7 +771,7 @@ def clean_entity_storage(self,
647771

648772
empty_entities_removed = 0
649773
orphaned_locks_released = 0
650-
_continuation_token = None
774+
_continuation_token: wrappers_pb2.StringValue | None = None
651775

652776
while True:
653777
req = pb.CleanEntityStorageRequest(
@@ -741,7 +865,7 @@ def __init__(self, *,
741865
# leave the failure-tracking opt-out implicit: callers wanting full
742866
# resiliency should let us create the channel.
743867
self._channel = channel
744-
self._stub = stubs.TaskHubSidecarServiceStub(channel)
868+
self._stub = cast(_AsyncTaskHubSidecarServiceStub, stubs.TaskHubSidecarServiceStub(channel))
745869
self._logger = shared.get_logger("async_client", log_handler, log_formatter)
746870
self.default_version = default_version
747871
self._payload_store = payload_store
@@ -839,7 +963,7 @@ async def _maybe_recreate_channel(self) -> None:
839963
interceptors=self._interceptors,
840964
channel_options=self._channel_options,
841965
)
842-
self._stub = stubs.TaskHubSidecarServiceStub(self._channel)
966+
self._stub = cast(_AsyncTaskHubSidecarServiceStub, stubs.TaskHubSidecarServiceStub(self._channel))
843967
self._last_recreate_time = now
844968
self._client_failure_tracker.record_success()
845969
self._retired_channels.append(old_channel)
@@ -940,11 +1064,11 @@ async def get_all_orchestration_states(self,
9401064
) -> list[OrchestrationState]:
9411065
if orchestration_query is None:
9421066
orchestration_query = OrchestrationQuery()
943-
_continuation_token = None
1067+
_continuation_token: wrappers_pb2.StringValue | None = None
9441068

9451069
self._logger.info(f"Querying orchestration instances with query: {orchestration_query}")
9461070

947-
states = []
1071+
states: list[OrchestrationState] = []
9481072

9491073
while True:
9501074
req = build_query_instances_req(orchestration_query, _continuation_token)
@@ -1101,11 +1225,11 @@ async def get_all_entities(self,
11011225
entity_query: EntityQuery | None = None) -> list[EntityMetadata]:
11021226
if entity_query is None:
11031227
entity_query = EntityQuery()
1104-
_continuation_token = None
1228+
_continuation_token: wrappers_pb2.StringValue | None = None
11051229

11061230
self._logger.info(f"Retrieving entities by filter: {entity_query}")
11071231

1108-
entities = []
1232+
entities: list[EntityMetadata] = []
11091233

11101234
while True:
11111235
query_request = build_query_entities_req(entity_query, _continuation_token)
@@ -1127,7 +1251,7 @@ async def clean_entity_storage(self,
11271251

11281252
empty_entities_removed = 0
11291253
orphaned_locks_released = 0
1130-
_continuation_token = None
1254+
_continuation_token: wrappers_pb2.StringValue | None = None
11311255

11321256
while True:
11331257
req = pb.CleanEntityStorageRequest(

durabletask/entities/entity_context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def get_state(self, intended_type: type[TState] | None = None, default: TState |
7171
"""
7272
return self._state.get_state(intended_type, default)
7373

74-
def set_state(self, new_state: Any):
74+
def set_state(self, new_state: Any) -> None:
7575
"""Set the state of the entity to a new value.
7676
7777
Parameters
@@ -93,7 +93,7 @@ def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, in
9393
input : Any, optional
9494
The input to provide to the entity for the operation.
9595
"""
96-
encoded_input = shared.to_json(input) if input is not None else None
96+
encoded_input: str | None = shared.to_json(input) if input is not None else None
9797
self._state.add_operation_action(
9898
pb.OperationAction(
9999
sendSignal=pb.SendSignalAction(
@@ -124,7 +124,7 @@ def schedule_new_orchestration(self, orchestration_name: str, input: Any | None
124124
str
125125
The instance ID of the scheduled orchestration.
126126
"""
127-
encoded_input = shared.to_json(input) if input is not None else None
127+
encoded_input: str | None = shared.to_json(input) if input is not None else None
128128
if not instance_id:
129129
instance_id = uuid.uuid4().hex
130130
self._state.add_operation_action(

0 commit comments

Comments
 (0)