Skip to content

Commit 752b383

Browse files
Bernd VerstCopilot
andcommitted
Add context manager support to TaskHubGrpcClient (#134)
Add `__enter__`/`__exit__` to the sync `TaskHubGrpcClient` so callers can use it with a `with` statement, mirroring the existing `AsyncTaskHubGrpcClient` async-context-manager support and the `TaskHubGrpcWorker` pattern. `DurableTaskSchedulerClient` inherits this behavior automatically. `__exit__` delegates to `close()`, so the resiliency-aware teardown introduced in #135 (in-flight recreate thread join, retired-channel timer cancellation, SDK-owned channel cleanup) runs unchanged through the new `with` path. Caller-owned channels remain untouched. Migrate every test and example callsite that previously instantiated `TaskHubGrpcClient(...)` and never closed it to the `with` form so the gRPC channel is deterministically released. Unit tests in `test_client.py` that intentionally test construction (with mocked stubs) are left unchanged. Add focused unit tests for the new context-manager behavior, including a regression test that exits a `with` block while a fire-and-forget channel recreate is pending and asserts the #135 resiliency invariants (retired-channel timers cancelled, recreate thread joined) still hold. Fixes #134 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 1232472 commit 752b383

13 files changed

Lines changed: 794 additions & 665 deletions

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,17 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88
## Unreleased
99

10+
ADDED
11+
12+
- Added context-manager support (`__enter__` / `__exit__`) to
13+
`TaskHubGrpcClient` so it can be used with `with` statements, mirroring the
14+
existing `AsyncTaskHubGrpcClient` async-context-manager support and the
15+
`TaskHubGrpcWorker` pattern. `DurableTaskSchedulerClient` inherits this
16+
behavior automatically. `__exit__` delegates to `close()`, so the
17+
resiliency-aware teardown introduced in v1.5.0 (in-flight recreate
18+
thread join, retired-channel timer cancellation, and SDK-owned channel
19+
cleanup) runs unchanged through the new `with` path.
20+
1021
## v1.5.0
1122

1223
ADDED

durabletask/client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,12 @@ def close(self) -> None:
366366
retired_channel.close()
367367
current_channel.close()
368368

369+
def __enter__(self) -> "TaskHubGrpcClient":
370+
return self
371+
372+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
373+
self.close()
374+
369375
def schedule_new_orchestration(self, orchestrator: task.Orchestrator[TInput, TOutput] | str, *,
370376
input: TInput | None = None,
371377
instance_id: str | None = None,

examples/human_interaction.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -79,31 +79,31 @@ def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order):
7979
w.add_activity(place_order)
8080
w.start()
8181

82-
c = client.TaskHubGrpcClient()
83-
84-
# Start a purchase order workflow using the user input
85-
order = Order(args.cost, "MyProduct", 1)
86-
instance_id = c.schedule_new_orchestration(purchase_order_workflow, input=order)
87-
88-
def prompt_for_approval():
89-
input("Press [ENTER] to approve the order...\n")
90-
approval_event = namedtuple("Approval", ["approver"])(args.approver)
91-
c.raise_orchestration_event(instance_id, "approval_received", data=approval_event)
92-
93-
# Prompt the user for approval on a background thread
94-
threading.Thread(target=prompt_for_approval, daemon=True).start()
95-
96-
# Wait for the orchestration to complete
97-
try:
98-
state = c.wait_for_orchestration_completion(instance_id, timeout=args.timeout + 2)
99-
if not state:
100-
print("Workflow not found!") # not expected
101-
elif state.runtime_status == client.OrchestrationStatus.COMPLETED:
102-
print(f'Orchestration completed! Result: {state.serialized_output}')
103-
else:
104-
state.raise_if_failed() # raises an exception
105-
except TimeoutError:
106-
print("*** Orchestration timed out!")
82+
with client.TaskHubGrpcClient() as c:
83+
84+
# Start a purchase order workflow using the user input
85+
order = Order(args.cost, "MyProduct", 1)
86+
instance_id = c.schedule_new_orchestration(purchase_order_workflow, input=order)
87+
88+
def prompt_for_approval():
89+
input("Press [ENTER] to approve the order...\n")
90+
approval_event = namedtuple("Approval", ["approver"])(args.approver)
91+
c.raise_orchestration_event(instance_id, "approval_received", data=approval_event)
92+
93+
# Prompt the user for approval on a background thread
94+
threading.Thread(target=prompt_for_approval, daemon=True).start()
95+
96+
# Wait for the orchestration to complete
97+
try:
98+
state = c.wait_for_orchestration_completion(instance_id, timeout=args.timeout + 2)
99+
if not state:
100+
print("Workflow not found!") # not expected
101+
elif state.runtime_status == client.OrchestrationStatus.COMPLETED:
102+
print(f'Orchestration completed! Result: {state.serialized_output}')
103+
else:
104+
state.raise_if_failed() # raises an exception
105+
except TimeoutError:
106+
print("*** Orchestration timed out!")
107107
else:
108108
# Use DurableTaskScheduler
109109
# Use environment variables if provided, otherwise use default emulator values

examples/in_memory_backend_example/test/test_workflows.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ def test_single_item_order(self):
7272

7373
with _create_worker() as w:
7474
w.start()
75-
c = client.TaskHubGrpcClient(host_address=HOST)
76-
instance_id = c.schedule_new_orchestration(process_order, input=order)
77-
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
75+
with client.TaskHubGrpcClient(host_address=HOST) as c:
76+
instance_id = c.schedule_new_orchestration(process_order, input=order)
77+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
7878

7979
assert state is not None
8080
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
@@ -98,9 +98,9 @@ def test_multi_item_order(self):
9898

9999
with _create_worker() as w:
100100
w.start()
101-
c = client.TaskHubGrpcClient(host_address=HOST)
102-
instance_id = c.schedule_new_orchestration(process_order, input=order)
103-
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
101+
with client.TaskHubGrpcClient(host_address=HOST) as c:
102+
instance_id = c.schedule_new_orchestration(process_order, input=order)
103+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
104104

105105
assert state is not None
106106
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
@@ -118,9 +118,9 @@ def test_empty_order_fails(self):
118118

119119
with _create_worker() as w:
120120
w.start()
121-
c = client.TaskHubGrpcClient(host_address=HOST)
122-
instance_id = c.schedule_new_orchestration(process_order, input=order)
123-
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
121+
with client.TaskHubGrpcClient(host_address=HOST) as c:
122+
instance_id = c.schedule_new_orchestration(process_order, input=order)
123+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
124124

125125
assert state is not None
126126
assert state.runtime_status == client.OrchestrationStatus.FAILED
@@ -136,9 +136,9 @@ def test_invalid_quantity_fails(self):
136136

137137
with _create_worker() as w:
138138
w.start()
139-
c = client.TaskHubGrpcClient(host_address=HOST)
140-
instance_id = c.schedule_new_orchestration(process_order, input=order)
141-
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
139+
with client.TaskHubGrpcClient(host_address=HOST) as c:
140+
instance_id = c.schedule_new_orchestration(process_order, input=order)
141+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
142142

143143
assert state is not None
144144
assert state.runtime_status == client.OrchestrationStatus.FAILED
@@ -163,9 +163,9 @@ def test_low_value_auto_approved(self):
163163

164164
with _create_worker() as w:
165165
w.start()
166-
c = client.TaskHubGrpcClient(host_address=HOST)
167-
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
168-
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
166+
with client.TaskHubGrpcClient(host_address=HOST) as c:
167+
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
168+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
169169

170170
assert state is not None
171171
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
@@ -184,12 +184,12 @@ def test_high_value_approved(self):
184184

185185
with _create_worker() as w:
186186
w.start()
187-
c = client.TaskHubGrpcClient(host_address=HOST)
188-
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
187+
with client.TaskHubGrpcClient(host_address=HOST) as c:
188+
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
189189

190-
# Raise the approval event
191-
c.raise_orchestration_event(instance_id, "approval", data=True)
192-
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
190+
# Raise the approval event
191+
c.raise_orchestration_event(instance_id, "approval", data=True)
192+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
193193

194194
assert state is not None
195195
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
@@ -208,12 +208,12 @@ def test_high_value_rejected(self):
208208

209209
with _create_worker() as w:
210210
w.start()
211-
c = client.TaskHubGrpcClient(host_address=HOST)
212-
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
211+
with client.TaskHubGrpcClient(host_address=HOST) as c:
212+
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
213213

214-
# Reject the order
215-
c.raise_orchestration_event(instance_id, "approval", data=False)
216-
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
214+
# Reject the order
215+
c.raise_orchestration_event(instance_id, "approval", data=False)
216+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
217217

218218
assert state is not None
219219
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
@@ -232,11 +232,11 @@ def test_high_value_timeout(self):
232232

233233
with _create_worker() as w:
234234
w.start()
235-
c = client.TaskHubGrpcClient(host_address=HOST)
236-
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
235+
with client.TaskHubGrpcClient(host_address=HOST) as c:
236+
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
237237

238-
# Don't raise any event — let the timer fire
239-
state = c.wait_for_orchestration_completion(instance_id, timeout=60)
238+
# Don't raise any event — let the timer fire
239+
state = c.wait_for_orchestration_completion(instance_id, timeout=60)
240240

241241
assert state is not None
242242
assert state.runtime_status == client.OrchestrationStatus.COMPLETED

tests/durabletask/entities/test_class_based_entities_e2e.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def do_nothing(self, _):
3737
w.add_entity(EmptyEntity, name="EntityNameCustom")
3838
w.start()
3939

40-
c = client.TaskHubGrpcClient(host_address=HOST)
41-
entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity")
42-
c.signal_entity(entity_id, "do_nothing")
43-
time.sleep(2) # wait for the signal to be processed
40+
with client.TaskHubGrpcClient(host_address=HOST) as c:
41+
entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity")
42+
c.signal_entity(entity_id, "do_nothing")
43+
time.sleep(2) # wait for the signal to be processed
4444

4545
assert invoked
4646

@@ -59,14 +59,14 @@ def do_nothing(self, _):
5959
w.add_entity(EmptyEntity)
6060
w.start()
6161

62-
c = client.TaskHubGrpcClient(host_address=HOST)
63-
entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity")
64-
c.signal_entity(entity_id, "do_nothing")
65-
time.sleep(2) # wait for the signal to be processed
66-
state = c.get_entity(entity_id, include_state=True)
67-
assert state is not None
68-
assert state.id == entity_id
69-
assert state.get_state(int) == 1
62+
with client.TaskHubGrpcClient(host_address=HOST) as c:
63+
entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity")
64+
c.signal_entity(entity_id, "do_nothing")
65+
time.sleep(2) # wait for the signal to be processed
66+
state = c.get_entity(entity_id, include_state=True)
67+
assert state is not None
68+
assert state.id == entity_id
69+
assert state.get_state(int) == 1
7070

7171
assert invoked
7272

@@ -89,10 +89,10 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _):
8989
w.add_entity(EmptyEntity, name="EntityNameCustom")
9090
w.start()
9191

92-
c = client.TaskHubGrpcClient(host_address=HOST)
93-
id = c.schedule_new_orchestration(empty_orchestrator)
94-
state = c.wait_for_orchestration_completion(id, timeout=30)
95-
time.sleep(2) # wait for the signal to be processed
92+
with client.TaskHubGrpcClient(host_address=HOST) as c:
93+
id = c.schedule_new_orchestration(empty_orchestrator)
94+
state = c.wait_for_orchestration_completion(id, timeout=30)
95+
time.sleep(2) # wait for the signal to be processed
9696

9797
assert invoked
9898
assert state is not None
@@ -123,9 +123,9 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _):
123123
w.add_entity(EmptyEntity)
124124
w.start()
125125

126-
c = client.TaskHubGrpcClient(host_address=HOST)
127-
id = c.schedule_new_orchestration(empty_orchestrator)
128-
state = c.wait_for_orchestration_completion(id, timeout=30)
126+
with client.TaskHubGrpcClient(host_address=HOST) as c:
127+
id = c.schedule_new_orchestration(empty_orchestrator)
128+
state = c.wait_for_orchestration_completion(id, timeout=30)
129129

130130
assert invoked
131131
assert state is not None

tests/durabletask/entities/test_entity_failure_handling.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _):
3939
w.add_entity(FailingEntity)
4040
w.start()
4141

42-
c = client.TaskHubGrpcClient(host_address=HOST)
43-
id = c.schedule_new_orchestration(test_orchestrator)
44-
state = c.wait_for_orchestration_completion(id, timeout=30)
42+
with client.TaskHubGrpcClient(host_address=HOST) as c:
43+
id = c.schedule_new_orchestration(test_orchestrator)
44+
state = c.wait_for_orchestration_completion(id, timeout=30)
4545

4646
assert state is not None
4747
assert state.name == task.get_name(test_orchestrator)
@@ -66,9 +66,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _):
6666
w.add_entity(failing_entity)
6767
w.start()
6868

69-
c = client.TaskHubGrpcClient(host_address=HOST)
70-
id = c.schedule_new_orchestration(test_orchestrator)
71-
state = c.wait_for_orchestration_completion(id, timeout=30)
69+
with client.TaskHubGrpcClient(host_address=HOST) as c:
70+
id = c.schedule_new_orchestration(test_orchestrator)
71+
state = c.wait_for_orchestration_completion(id, timeout=30)
7272

7373
assert state is not None
7474
assert state.name == task.get_name(test_orchestrator)
@@ -97,9 +97,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _):
9797
w.add_entity(FailingEntity)
9898
w.start()
9999

100-
c = client.TaskHubGrpcClient(host_address=HOST)
101-
id = c.schedule_new_orchestration(test_orchestrator)
102-
state = c.wait_for_orchestration_completion(id, timeout=30)
100+
with client.TaskHubGrpcClient(host_address=HOST) as c:
101+
id = c.schedule_new_orchestration(test_orchestrator)
102+
state = c.wait_for_orchestration_completion(id, timeout=30)
103103

104104
assert state is not None
105105
assert state.name == task.get_name(test_orchestrator)
@@ -129,9 +129,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _):
129129
w.add_entity(failing_entity)
130130
w.start()
131131

132-
c = client.TaskHubGrpcClient(host_address=HOST)
133-
id = c.schedule_new_orchestration(test_orchestrator)
134-
state = c.wait_for_orchestration_completion(id, timeout=30)
132+
with client.TaskHubGrpcClient(host_address=HOST) as c:
133+
id = c.schedule_new_orchestration(test_orchestrator)
134+
state = c.wait_for_orchestration_completion(id, timeout=30)
135135

136136
assert state is not None
137137
assert state.name == task.get_name(test_orchestrator)
@@ -168,9 +168,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _):
168168
w.add_entity(failing_entity)
169169
w.start()
170170

171-
c = client.TaskHubGrpcClient(host_address=HOST)
172-
id = c.schedule_new_orchestration(test_orchestrator)
173-
state = c.wait_for_orchestration_completion(id, timeout=30)
171+
with client.TaskHubGrpcClient(host_address=HOST) as c:
172+
id = c.schedule_new_orchestration(test_orchestrator)
173+
state = c.wait_for_orchestration_completion(id, timeout=30)
174174

175175
assert state is not None
176176
assert state.name == task.get_name(test_orchestrator)

0 commit comments

Comments
 (0)