Skip to content

Commit 1cb30b0

Browse files
berndverstBernd VerstCopilot
authored
Add context manager support to TaskHubGrpcClient (#145)
* Add context manager support to TaskHubGrpcClient (#134) Add `__enter__`/`__exit__` to the sync `TaskHubGrpcClient` so callers can use it with a `with` statement, mirroring the existing `AsyncTaskHubGrpcClient` async-context-manager support and the `TaskHubGrpcWorker` pattern. `DurableTaskSchedulerClient` inherits this behavior automatically. `__exit__` delegates to `close()`, so the resiliency-aware teardown introduced in #135 (in-flight recreate thread join, retired-channel timer cancellation, SDK-owned channel cleanup) runs unchanged through the new `with` path. Caller-owned channels remain untouched. Migrate every test and example callsite that previously instantiated `TaskHubGrpcClient(...)` and never closed it to the `with` form so the gRPC channel is deterministically released. Unit tests in `test_client.py` that intentionally test construction (with mocked stubs) are left unchanged. Add focused unit tests for the new context-manager behavior, including a regression test that exits a `with` block while a fire-and-forget channel recreate is pending and asserts the #135 resiliency invariants (retired-channel timers cancelled, recreate thread joined) still hold. Fixes #134 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address auto-generated code-quality review comments on PR #145 1. tests/durabletask/test_orchestration_e2e.py (test_suspend_and_resume): Drop the unused `state =` assignment around the expected-timeout `wait_for_orchestration_completion` call. The value is never read because the next line asserts False and the only non-failing path raises TimeoutError; `state` is reassigned a few lines down. Silences the "variable defined multiple times" warning that CodeQL flagged because this previously-untouched line was pulled into the diff by the indent change. 2. tests/durabletask/test_client.py (test_sync_client_context_manager_propagates_exception_and_calls_close): Replace the nested `with pytest.raises(...): with client: raise ...` pattern with an explicit try/except so CodeQL no longer reports the post-block assertions as unreachable. Test intent (exception propagation + cleanup verification) is preserved. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Align __exit__ typing with repo convention (BlobPayloadStore) Use `*args: object` for the `__exit__` parameters instead of leaving them untyped. This matches the most recent context-manager class in the repo (`durabletask/extensions/azure_blob_payloads/blob_payload_store.py`), is more type-safe under Pylance/mypy, and avoids the parameter shadowing of the builtin `type` that exists in `TaskHubGrpcWorker.__exit__`. Behavior is unchanged. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Align pre-existing context manager methods with repo idiom Apply the same `__enter__`/`__exit__` typing pattern that PR #145 uses for `TaskHubGrpcClient` (and that `BlobPayloadStore` already follows) to three pre-existing context managers that were either untyped or shadowed a builtin: * `AsyncTaskHubGrpcClient.__aenter__` now returns the concrete type and `__aexit__` takes `*args: object` with `-> None`. * `TaskHubGrpcWorker.__enter__/__exit__` get the same treatment, also removing the `type` parameter that shadowed the builtin. * `EntityLock.__enter__/__exit__` get the same treatment; the file already has `from __future__ import annotations` so the return annotation is the bare class name. Behavior is unchanged: each `__exit__` still delegates to its existing teardown method (`close`/`stop`/`_exit_critical_section`), so the gRPC resiliency teardown added in PR #135 continues to flow through `TaskHubGrpcClient.close()` unchanged. No changelog entry: per the repo's contributor guidance, internal-only type-annotation refactors with no externally observable behavior change are excluded from the changelog. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Bernd Verst <beverst@microsoft.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 1232472 commit 1cb30b0

15 files changed

Lines changed: 805 additions & 671 deletions

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,17 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88
## Unreleased
99

10+
ADDED
11+
12+
- Added context-manager support (`__enter__` / `__exit__`) to
13+
`TaskHubGrpcClient` so it can be used with `with` statements, mirroring the
14+
existing `AsyncTaskHubGrpcClient` async-context-manager support and the
15+
`TaskHubGrpcWorker` pattern. `DurableTaskSchedulerClient` inherits this
16+
behavior automatically. `__exit__` delegates to `close()`, so the
17+
resiliency-aware teardown introduced in v1.5.0 (in-flight recreate
18+
thread join, retired-channel timer cancellation, and SDK-owned channel
19+
cleanup) runs unchanged through the new `with` path.
20+
1021
## v1.5.0
1122

1223
ADDED

durabletask/client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,12 @@ def close(self) -> None:
366366
retired_channel.close()
367367
current_channel.close()
368368

369+
def __enter__(self) -> "TaskHubGrpcClient":
370+
return self
371+
372+
def __exit__(self, *args: object) -> None:
373+
self.close()
374+
369375
def schedule_new_orchestration(self, orchestrator: task.Orchestrator[TInput, TOutput] | str, *,
370376
input: TInput | None = None,
371377
instance_id: str | None = None,
@@ -783,10 +789,10 @@ async def close(self) -> None:
783789
await retired_channel.close()
784790
await self._channel.close()
785791

786-
async def __aenter__(self):
792+
async def __aenter__(self) -> "AsyncTaskHubGrpcClient":
787793
return self
788794

789-
async def __aexit__(self, exc_type, exc_val, exc_tb):
795+
async def __aexit__(self, *args: object) -> None:
790796
await self.close()
791797

792798
def _schedule_recreate(self) -> None:

durabletask/entities/entity_lock.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ class EntityLock:
1212
def __init__(self, context: OrchestrationContext):
1313
self._context = context
1414

15-
def __enter__(self):
15+
def __enter__(self) -> EntityLock:
1616
return self
1717

18-
def __exit__(self, exc_type, exc_val, exc_tb):
18+
def __exit__(self, *args: object) -> None:
1919
self._context._exit_critical_section()

durabletask/worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,10 +581,10 @@ def maximum_timer_interval(self) -> timedelta | None:
581581
"""Get the configured maximum timer interval for long timer chunking."""
582582
return self._maximum_timer_interval
583583

584-
def __enter__(self):
584+
def __enter__(self) -> "TaskHubGrpcWorker":
585585
return self
586586

587-
def __exit__(self, type, value, traceback):
587+
def __exit__(self, *args: object) -> None:
588588
self.stop()
589589

590590
def _classify_stream_outcome(

examples/human_interaction.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -79,31 +79,31 @@ def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order):
7979
w.add_activity(place_order)
8080
w.start()
8181

82-
c = client.TaskHubGrpcClient()
83-
84-
# Start a purchase order workflow using the user input
85-
order = Order(args.cost, "MyProduct", 1)
86-
instance_id = c.schedule_new_orchestration(purchase_order_workflow, input=order)
87-
88-
def prompt_for_approval():
89-
input("Press [ENTER] to approve the order...\n")
90-
approval_event = namedtuple("Approval", ["approver"])(args.approver)
91-
c.raise_orchestration_event(instance_id, "approval_received", data=approval_event)
92-
93-
# Prompt the user for approval on a background thread
94-
threading.Thread(target=prompt_for_approval, daemon=True).start()
95-
96-
# Wait for the orchestration to complete
97-
try:
98-
state = c.wait_for_orchestration_completion(instance_id, timeout=args.timeout + 2)
99-
if not state:
100-
print("Workflow not found!") # not expected
101-
elif state.runtime_status == client.OrchestrationStatus.COMPLETED:
102-
print(f'Orchestration completed! Result: {state.serialized_output}')
103-
else:
104-
state.raise_if_failed() # raises an exception
105-
except TimeoutError:
106-
print("*** Orchestration timed out!")
82+
with client.TaskHubGrpcClient() as c:
83+
84+
# Start a purchase order workflow using the user input
85+
order = Order(args.cost, "MyProduct", 1)
86+
instance_id = c.schedule_new_orchestration(purchase_order_workflow, input=order)
87+
88+
def prompt_for_approval():
89+
input("Press [ENTER] to approve the order...\n")
90+
approval_event = namedtuple("Approval", ["approver"])(args.approver)
91+
c.raise_orchestration_event(instance_id, "approval_received", data=approval_event)
92+
93+
# Prompt the user for approval on a background thread
94+
threading.Thread(target=prompt_for_approval, daemon=True).start()
95+
96+
# Wait for the orchestration to complete
97+
try:
98+
state = c.wait_for_orchestration_completion(instance_id, timeout=args.timeout + 2)
99+
if not state:
100+
print("Workflow not found!") # not expected
101+
elif state.runtime_status == client.OrchestrationStatus.COMPLETED:
102+
print(f'Orchestration completed! Result: {state.serialized_output}')
103+
else:
104+
state.raise_if_failed() # raises an exception
105+
except TimeoutError:
106+
print("*** Orchestration timed out!")
107107
else:
108108
# Use DurableTaskScheduler
109109
# Use environment variables if provided, otherwise use default emulator values

examples/in_memory_backend_example/test/test_workflows.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ def test_single_item_order(self):
7272

7373
with _create_worker() as w:
7474
w.start()
75-
c = client.TaskHubGrpcClient(host_address=HOST)
76-
instance_id = c.schedule_new_orchestration(process_order, input=order)
77-
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
75+
with client.TaskHubGrpcClient(host_address=HOST) as c:
76+
instance_id = c.schedule_new_orchestration(process_order, input=order)
77+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
7878

7979
assert state is not None
8080
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
@@ -98,9 +98,9 @@ def test_multi_item_order(self):
9898

9999
with _create_worker() as w:
100100
w.start()
101-
c = client.TaskHubGrpcClient(host_address=HOST)
102-
instance_id = c.schedule_new_orchestration(process_order, input=order)
103-
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
101+
with client.TaskHubGrpcClient(host_address=HOST) as c:
102+
instance_id = c.schedule_new_orchestration(process_order, input=order)
103+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
104104

105105
assert state is not None
106106
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
@@ -118,9 +118,9 @@ def test_empty_order_fails(self):
118118

119119
with _create_worker() as w:
120120
w.start()
121-
c = client.TaskHubGrpcClient(host_address=HOST)
122-
instance_id = c.schedule_new_orchestration(process_order, input=order)
123-
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
121+
with client.TaskHubGrpcClient(host_address=HOST) as c:
122+
instance_id = c.schedule_new_orchestration(process_order, input=order)
123+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
124124

125125
assert state is not None
126126
assert state.runtime_status == client.OrchestrationStatus.FAILED
@@ -136,9 +136,9 @@ def test_invalid_quantity_fails(self):
136136

137137
with _create_worker() as w:
138138
w.start()
139-
c = client.TaskHubGrpcClient(host_address=HOST)
140-
instance_id = c.schedule_new_orchestration(process_order, input=order)
141-
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
139+
with client.TaskHubGrpcClient(host_address=HOST) as c:
140+
instance_id = c.schedule_new_orchestration(process_order, input=order)
141+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
142142

143143
assert state is not None
144144
assert state.runtime_status == client.OrchestrationStatus.FAILED
@@ -163,9 +163,9 @@ def test_low_value_auto_approved(self):
163163

164164
with _create_worker() as w:
165165
w.start()
166-
c = client.TaskHubGrpcClient(host_address=HOST)
167-
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
168-
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
166+
with client.TaskHubGrpcClient(host_address=HOST) as c:
167+
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
168+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
169169

170170
assert state is not None
171171
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
@@ -184,12 +184,12 @@ def test_high_value_approved(self):
184184

185185
with _create_worker() as w:
186186
w.start()
187-
c = client.TaskHubGrpcClient(host_address=HOST)
188-
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
187+
with client.TaskHubGrpcClient(host_address=HOST) as c:
188+
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
189189

190-
# Raise the approval event
191-
c.raise_orchestration_event(instance_id, "approval", data=True)
192-
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
190+
# Raise the approval event
191+
c.raise_orchestration_event(instance_id, "approval", data=True)
192+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
193193

194194
assert state is not None
195195
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
@@ -208,12 +208,12 @@ def test_high_value_rejected(self):
208208

209209
with _create_worker() as w:
210210
w.start()
211-
c = client.TaskHubGrpcClient(host_address=HOST)
212-
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
211+
with client.TaskHubGrpcClient(host_address=HOST) as c:
212+
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
213213

214-
# Reject the order
215-
c.raise_orchestration_event(instance_id, "approval", data=False)
216-
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
214+
# Reject the order
215+
c.raise_orchestration_event(instance_id, "approval", data=False)
216+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
217217

218218
assert state is not None
219219
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
@@ -232,11 +232,11 @@ def test_high_value_timeout(self):
232232

233233
with _create_worker() as w:
234234
w.start()
235-
c = client.TaskHubGrpcClient(host_address=HOST)
236-
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
235+
with client.TaskHubGrpcClient(host_address=HOST) as c:
236+
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
237237

238-
# Don't raise any event — let the timer fire
239-
state = c.wait_for_orchestration_completion(instance_id, timeout=60)
238+
# Don't raise any event — let the timer fire
239+
state = c.wait_for_orchestration_completion(instance_id, timeout=60)
240240

241241
assert state is not None
242242
assert state.runtime_status == client.OrchestrationStatus.COMPLETED

tests/durabletask/entities/test_class_based_entities_e2e.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def do_nothing(self, _):
3737
w.add_entity(EmptyEntity, name="EntityNameCustom")
3838
w.start()
3939

40-
c = client.TaskHubGrpcClient(host_address=HOST)
41-
entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity")
42-
c.signal_entity(entity_id, "do_nothing")
43-
time.sleep(2) # wait for the signal to be processed
40+
with client.TaskHubGrpcClient(host_address=HOST) as c:
41+
entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity")
42+
c.signal_entity(entity_id, "do_nothing")
43+
time.sleep(2) # wait for the signal to be processed
4444

4545
assert invoked
4646

@@ -59,14 +59,14 @@ def do_nothing(self, _):
5959
w.add_entity(EmptyEntity)
6060
w.start()
6161

62-
c = client.TaskHubGrpcClient(host_address=HOST)
63-
entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity")
64-
c.signal_entity(entity_id, "do_nothing")
65-
time.sleep(2) # wait for the signal to be processed
66-
state = c.get_entity(entity_id, include_state=True)
67-
assert state is not None
68-
assert state.id == entity_id
69-
assert state.get_state(int) == 1
62+
with client.TaskHubGrpcClient(host_address=HOST) as c:
63+
entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity")
64+
c.signal_entity(entity_id, "do_nothing")
65+
time.sleep(2) # wait for the signal to be processed
66+
state = c.get_entity(entity_id, include_state=True)
67+
assert state is not None
68+
assert state.id == entity_id
69+
assert state.get_state(int) == 1
7070

7171
assert invoked
7272

@@ -89,10 +89,10 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _):
8989
w.add_entity(EmptyEntity, name="EntityNameCustom")
9090
w.start()
9191

92-
c = client.TaskHubGrpcClient(host_address=HOST)
93-
id = c.schedule_new_orchestration(empty_orchestrator)
94-
state = c.wait_for_orchestration_completion(id, timeout=30)
95-
time.sleep(2) # wait for the signal to be processed
92+
with client.TaskHubGrpcClient(host_address=HOST) as c:
93+
id = c.schedule_new_orchestration(empty_orchestrator)
94+
state = c.wait_for_orchestration_completion(id, timeout=30)
95+
time.sleep(2) # wait for the signal to be processed
9696

9797
assert invoked
9898
assert state is not None
@@ -123,9 +123,9 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _):
123123
w.add_entity(EmptyEntity)
124124
w.start()
125125

126-
c = client.TaskHubGrpcClient(host_address=HOST)
127-
id = c.schedule_new_orchestration(empty_orchestrator)
128-
state = c.wait_for_orchestration_completion(id, timeout=30)
126+
with client.TaskHubGrpcClient(host_address=HOST) as c:
127+
id = c.schedule_new_orchestration(empty_orchestrator)
128+
state = c.wait_for_orchestration_completion(id, timeout=30)
129129

130130
assert invoked
131131
assert state is not None

tests/durabletask/entities/test_entity_failure_handling.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _):
3939
w.add_entity(FailingEntity)
4040
w.start()
4141

42-
c = client.TaskHubGrpcClient(host_address=HOST)
43-
id = c.schedule_new_orchestration(test_orchestrator)
44-
state = c.wait_for_orchestration_completion(id, timeout=30)
42+
with client.TaskHubGrpcClient(host_address=HOST) as c:
43+
id = c.schedule_new_orchestration(test_orchestrator)
44+
state = c.wait_for_orchestration_completion(id, timeout=30)
4545

4646
assert state is not None
4747
assert state.name == task.get_name(test_orchestrator)
@@ -66,9 +66,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _):
6666
w.add_entity(failing_entity)
6767
w.start()
6868

69-
c = client.TaskHubGrpcClient(host_address=HOST)
70-
id = c.schedule_new_orchestration(test_orchestrator)
71-
state = c.wait_for_orchestration_completion(id, timeout=30)
69+
with client.TaskHubGrpcClient(host_address=HOST) as c:
70+
id = c.schedule_new_orchestration(test_orchestrator)
71+
state = c.wait_for_orchestration_completion(id, timeout=30)
7272

7373
assert state is not None
7474
assert state.name == task.get_name(test_orchestrator)
@@ -97,9 +97,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _):
9797
w.add_entity(FailingEntity)
9898
w.start()
9999

100-
c = client.TaskHubGrpcClient(host_address=HOST)
101-
id = c.schedule_new_orchestration(test_orchestrator)
102-
state = c.wait_for_orchestration_completion(id, timeout=30)
100+
with client.TaskHubGrpcClient(host_address=HOST) as c:
101+
id = c.schedule_new_orchestration(test_orchestrator)
102+
state = c.wait_for_orchestration_completion(id, timeout=30)
103103

104104
assert state is not None
105105
assert state.name == task.get_name(test_orchestrator)
@@ -129,9 +129,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _):
129129
w.add_entity(failing_entity)
130130
w.start()
131131

132-
c = client.TaskHubGrpcClient(host_address=HOST)
133-
id = c.schedule_new_orchestration(test_orchestrator)
134-
state = c.wait_for_orchestration_completion(id, timeout=30)
132+
with client.TaskHubGrpcClient(host_address=HOST) as c:
133+
id = c.schedule_new_orchestration(test_orchestrator)
134+
state = c.wait_for_orchestration_completion(id, timeout=30)
135135

136136
assert state is not None
137137
assert state.name == task.get_name(test_orchestrator)
@@ -168,9 +168,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _):
168168
w.add_entity(failing_entity)
169169
w.start()
170170

171-
c = client.TaskHubGrpcClient(host_address=HOST)
172-
id = c.schedule_new_orchestration(test_orchestrator)
173-
state = c.wait_for_orchestration_completion(id, timeout=30)
171+
with client.TaskHubGrpcClient(host_address=HOST) as c:
172+
id = c.schedule_new_orchestration(test_orchestrator)
173+
state = c.wait_for_orchestration_completion(id, timeout=30)
174174

175175
assert state is not None
176176
assert state.name == task.get_name(test_orchestrator)

0 commit comments

Comments
 (0)