Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

ADDED

- Added context-manager support (`__enter__` / `__exit__`) to
`TaskHubGrpcClient` so it can be used with `with` statements, mirroring the
existing `AsyncTaskHubGrpcClient` async-context-manager support and the
`TaskHubGrpcWorker` pattern. `DurableTaskSchedulerClient` inherits this
behavior automatically. `__exit__` delegates to `close()`, so the
resiliency-aware teardown introduced in v1.5.0 (in-flight recreate
thread join, retired-channel timer cancellation, and SDK-owned channel
cleanup) runs unchanged through the new `with` path.

## v1.5.0

ADDED
Expand Down
10 changes: 8 additions & 2 deletions durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,12 @@ def close(self) -> None:
retired_channel.close()
current_channel.close()

def __enter__(self) -> "TaskHubGrpcClient":
return self

def __exit__(self, *args: object) -> None:
self.close()

def schedule_new_orchestration(self, orchestrator: task.Orchestrator[TInput, TOutput] | str, *,
input: TInput | None = None,
instance_id: str | None = None,
Expand Down Expand Up @@ -783,10 +789,10 @@ async def close(self) -> None:
await retired_channel.close()
await self._channel.close()

async def __aenter__(self):
async def __aenter__(self) -> "AsyncTaskHubGrpcClient":
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
async def __aexit__(self, *args: object) -> None:
await self.close()

def _schedule_recreate(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions durabletask/entities/entity_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class EntityLock:
def __init__(self, context: OrchestrationContext):
self._context = context

def __enter__(self):
def __enter__(self) -> EntityLock:
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, *args: object) -> None:
self._context._exit_critical_section()
4 changes: 2 additions & 2 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,10 +581,10 @@ def maximum_timer_interval(self) -> timedelta | None:
"""Get the configured maximum timer interval for long timer chunking."""
return self._maximum_timer_interval

def __enter__(self):
def __enter__(self) -> "TaskHubGrpcWorker":
return self

def __exit__(self, type, value, traceback):
def __exit__(self, *args: object) -> None:
self.stop()

def _classify_stream_outcome(
Expand Down
50 changes: 25 additions & 25 deletions examples/human_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,31 +79,31 @@ def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order):
w.add_activity(place_order)
w.start()

c = client.TaskHubGrpcClient()

# Start a purchase order workflow using the user input
order = Order(args.cost, "MyProduct", 1)
instance_id = c.schedule_new_orchestration(purchase_order_workflow, input=order)

def prompt_for_approval():
input("Press [ENTER] to approve the order...\n")
approval_event = namedtuple("Approval", ["approver"])(args.approver)
c.raise_orchestration_event(instance_id, "approval_received", data=approval_event)

# Prompt the user for approval on a background thread
threading.Thread(target=prompt_for_approval, daemon=True).start()

# Wait for the orchestration to complete
try:
state = c.wait_for_orchestration_completion(instance_id, timeout=args.timeout + 2)
if not state:
print("Workflow not found!") # not expected
elif state.runtime_status == client.OrchestrationStatus.COMPLETED:
print(f'Orchestration completed! Result: {state.serialized_output}')
else:
state.raise_if_failed() # raises an exception
except TimeoutError:
print("*** Orchestration timed out!")
with client.TaskHubGrpcClient() as c:

# Start a purchase order workflow using the user input
order = Order(args.cost, "MyProduct", 1)
instance_id = c.schedule_new_orchestration(purchase_order_workflow, input=order)

def prompt_for_approval():
input("Press [ENTER] to approve the order...\n")
approval_event = namedtuple("Approval", ["approver"])(args.approver)
c.raise_orchestration_event(instance_id, "approval_received", data=approval_event)

# Prompt the user for approval on a background thread
threading.Thread(target=prompt_for_approval, daemon=True).start()

# Wait for the orchestration to complete
try:
state = c.wait_for_orchestration_completion(instance_id, timeout=args.timeout + 2)
if not state:
print("Workflow not found!") # not expected
elif state.runtime_status == client.OrchestrationStatus.COMPLETED:
print(f'Orchestration completed! Result: {state.serialized_output}')
else:
state.raise_if_failed() # raises an exception
except TimeoutError:
print("*** Orchestration timed out!")
else:
# Use DurableTaskScheduler
# Use environment variables if provided, otherwise use default emulator values
Expand Down
58 changes: 29 additions & 29 deletions examples/in_memory_backend_example/test/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def test_single_item_order(self):

with _create_worker() as w:
w.start()
c = client.TaskHubGrpcClient(host_address=HOST)
instance_id = c.schedule_new_orchestration(process_order, input=order)
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
with client.TaskHubGrpcClient(host_address=HOST) as c:
instance_id = c.schedule_new_orchestration(process_order, input=order)
state = c.wait_for_orchestration_completion(instance_id, timeout=30)

assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
Expand All @@ -98,9 +98,9 @@ def test_multi_item_order(self):

with _create_worker() as w:
w.start()
c = client.TaskHubGrpcClient(host_address=HOST)
instance_id = c.schedule_new_orchestration(process_order, input=order)
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
with client.TaskHubGrpcClient(host_address=HOST) as c:
instance_id = c.schedule_new_orchestration(process_order, input=order)
state = c.wait_for_orchestration_completion(instance_id, timeout=30)

assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
Expand All @@ -118,9 +118,9 @@ def test_empty_order_fails(self):

with _create_worker() as w:
w.start()
c = client.TaskHubGrpcClient(host_address=HOST)
instance_id = c.schedule_new_orchestration(process_order, input=order)
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
with client.TaskHubGrpcClient(host_address=HOST) as c:
instance_id = c.schedule_new_orchestration(process_order, input=order)
state = c.wait_for_orchestration_completion(instance_id, timeout=30)

assert state is not None
assert state.runtime_status == client.OrchestrationStatus.FAILED
Expand All @@ -136,9 +136,9 @@ def test_invalid_quantity_fails(self):

with _create_worker() as w:
w.start()
c = client.TaskHubGrpcClient(host_address=HOST)
instance_id = c.schedule_new_orchestration(process_order, input=order)
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
with client.TaskHubGrpcClient(host_address=HOST) as c:
instance_id = c.schedule_new_orchestration(process_order, input=order)
state = c.wait_for_orchestration_completion(instance_id, timeout=30)

assert state is not None
assert state.runtime_status == client.OrchestrationStatus.FAILED
Expand All @@ -163,9 +163,9 @@ def test_low_value_auto_approved(self):

with _create_worker() as w:
w.start()
c = client.TaskHubGrpcClient(host_address=HOST)
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
with client.TaskHubGrpcClient(host_address=HOST) as c:
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
state = c.wait_for_orchestration_completion(instance_id, timeout=30)

assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
Expand All @@ -184,12 +184,12 @@ def test_high_value_approved(self):

with _create_worker() as w:
w.start()
c = client.TaskHubGrpcClient(host_address=HOST)
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
with client.TaskHubGrpcClient(host_address=HOST) as c:
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)

# Raise the approval event
c.raise_orchestration_event(instance_id, "approval", data=True)
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
# Raise the approval event
c.raise_orchestration_event(instance_id, "approval", data=True)
state = c.wait_for_orchestration_completion(instance_id, timeout=30)

assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
Expand All @@ -208,12 +208,12 @@ def test_high_value_rejected(self):

with _create_worker() as w:
w.start()
c = client.TaskHubGrpcClient(host_address=HOST)
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
with client.TaskHubGrpcClient(host_address=HOST) as c:
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)

# Reject the order
c.raise_orchestration_event(instance_id, "approval", data=False)
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
# Reject the order
c.raise_orchestration_event(instance_id, "approval", data=False)
state = c.wait_for_orchestration_completion(instance_id, timeout=30)

assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
Expand All @@ -232,11 +232,11 @@ def test_high_value_timeout(self):

with _create_worker() as w:
w.start()
c = client.TaskHubGrpcClient(host_address=HOST)
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)
with client.TaskHubGrpcClient(host_address=HOST) as c:
instance_id = c.schedule_new_orchestration(order_with_approval, input=order)

# Don't raise any event — let the timer fire
state = c.wait_for_orchestration_completion(instance_id, timeout=60)
# Don't raise any event — let the timer fire
state = c.wait_for_orchestration_completion(instance_id, timeout=60)

assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
Expand Down
38 changes: 19 additions & 19 deletions tests/durabletask/entities/test_class_based_entities_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ def do_nothing(self, _):
w.add_entity(EmptyEntity, name="EntityNameCustom")
w.start()

c = client.TaskHubGrpcClient(host_address=HOST)
entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity")
c.signal_entity(entity_id, "do_nothing")
time.sleep(2) # wait for the signal to be processed
with client.TaskHubGrpcClient(host_address=HOST) as c:
entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity")
c.signal_entity(entity_id, "do_nothing")
time.sleep(2) # wait for the signal to be processed

assert invoked

Expand All @@ -59,14 +59,14 @@ def do_nothing(self, _):
w.add_entity(EmptyEntity)
w.start()

c = client.TaskHubGrpcClient(host_address=HOST)
entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity")
c.signal_entity(entity_id, "do_nothing")
time.sleep(2) # wait for the signal to be processed
state = c.get_entity(entity_id, include_state=True)
assert state is not None
assert state.id == entity_id
assert state.get_state(int) == 1
with client.TaskHubGrpcClient(host_address=HOST) as c:
entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity")
c.signal_entity(entity_id, "do_nothing")
time.sleep(2) # wait for the signal to be processed
state = c.get_entity(entity_id, include_state=True)
assert state is not None
assert state.id == entity_id
assert state.get_state(int) == 1

assert invoked

Expand All @@ -89,10 +89,10 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _):
w.add_entity(EmptyEntity, name="EntityNameCustom")
w.start()

c = client.TaskHubGrpcClient(host_address=HOST)
id = c.schedule_new_orchestration(empty_orchestrator)
state = c.wait_for_orchestration_completion(id, timeout=30)
time.sleep(2) # wait for the signal to be processed
with client.TaskHubGrpcClient(host_address=HOST) as c:
id = c.schedule_new_orchestration(empty_orchestrator)
state = c.wait_for_orchestration_completion(id, timeout=30)
time.sleep(2) # wait for the signal to be processed

assert invoked
assert state is not None
Expand Down Expand Up @@ -123,9 +123,9 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _):
w.add_entity(EmptyEntity)
w.start()

c = client.TaskHubGrpcClient(host_address=HOST)
id = c.schedule_new_orchestration(empty_orchestrator)
state = c.wait_for_orchestration_completion(id, timeout=30)
with client.TaskHubGrpcClient(host_address=HOST) as c:
id = c.schedule_new_orchestration(empty_orchestrator)
state = c.wait_for_orchestration_completion(id, timeout=30)

assert invoked
assert state is not None
Expand Down
30 changes: 15 additions & 15 deletions tests/durabletask/entities/test_entity_failure_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _):
w.add_entity(FailingEntity)
w.start()

c = client.TaskHubGrpcClient(host_address=HOST)
id = c.schedule_new_orchestration(test_orchestrator)
state = c.wait_for_orchestration_completion(id, timeout=30)
with client.TaskHubGrpcClient(host_address=HOST) as c:
id = c.schedule_new_orchestration(test_orchestrator)
state = c.wait_for_orchestration_completion(id, timeout=30)

assert state is not None
assert state.name == task.get_name(test_orchestrator)
Expand All @@ -66,9 +66,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _):
w.add_entity(failing_entity)
w.start()

c = client.TaskHubGrpcClient(host_address=HOST)
id = c.schedule_new_orchestration(test_orchestrator)
state = c.wait_for_orchestration_completion(id, timeout=30)
with client.TaskHubGrpcClient(host_address=HOST) as c:
id = c.schedule_new_orchestration(test_orchestrator)
state = c.wait_for_orchestration_completion(id, timeout=30)

assert state is not None
assert state.name == task.get_name(test_orchestrator)
Expand Down Expand Up @@ -97,9 +97,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _):
w.add_entity(FailingEntity)
w.start()

c = client.TaskHubGrpcClient(host_address=HOST)
id = c.schedule_new_orchestration(test_orchestrator)
state = c.wait_for_orchestration_completion(id, timeout=30)
with client.TaskHubGrpcClient(host_address=HOST) as c:
id = c.schedule_new_orchestration(test_orchestrator)
state = c.wait_for_orchestration_completion(id, timeout=30)

assert state is not None
assert state.name == task.get_name(test_orchestrator)
Expand Down Expand Up @@ -129,9 +129,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _):
w.add_entity(failing_entity)
w.start()

c = client.TaskHubGrpcClient(host_address=HOST)
id = c.schedule_new_orchestration(test_orchestrator)
state = c.wait_for_orchestration_completion(id, timeout=30)
with client.TaskHubGrpcClient(host_address=HOST) as c:
id = c.schedule_new_orchestration(test_orchestrator)
state = c.wait_for_orchestration_completion(id, timeout=30)

assert state is not None
assert state.name == task.get_name(test_orchestrator)
Expand Down Expand Up @@ -168,9 +168,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _):
w.add_entity(failing_entity)
w.start()

c = client.TaskHubGrpcClient(host_address=HOST)
id = c.schedule_new_orchestration(test_orchestrator)
state = c.wait_for_orchestration_completion(id, timeout=30)
with client.TaskHubGrpcClient(host_address=HOST) as c:
id = c.schedule_new_orchestration(test_orchestrator)
state = c.wait_for_orchestration_completion(id, timeout=30)

assert state is not None
assert state.name == task.get_name(test_orchestrator)
Expand Down
Loading
Loading