diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a75913e..7a7daf38 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,17 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased +ADDED + +- Added context-manager support (`__enter__` / `__exit__`) to + `TaskHubGrpcClient` so it can be used with `with` statements, mirroring the + existing `AsyncTaskHubGrpcClient` async-context-manager support and the + `TaskHubGrpcWorker` pattern. `DurableTaskSchedulerClient` inherits this + behavior automatically. `__exit__` delegates to `close()`, so the + resiliency-aware teardown introduced in v1.5.0 (in-flight recreate + thread join, retired-channel timer cancellation, and SDK-owned channel + cleanup) runs unchanged through the new `with` path. + ## v1.5.0 ADDED diff --git a/durabletask/client.py b/durabletask/client.py index 56d4b05a..7c85f9f1 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -366,6 +366,12 @@ def close(self) -> None: retired_channel.close() current_channel.close() + def __enter__(self) -> "TaskHubGrpcClient": + return self + + def __exit__(self, *args: object) -> None: + self.close() + def schedule_new_orchestration(self, orchestrator: task.Orchestrator[TInput, TOutput] | str, *, input: TInput | None = None, instance_id: str | None = None, @@ -783,10 +789,10 @@ async def close(self) -> None: await retired_channel.close() await self._channel.close() - async def __aenter__(self): + async def __aenter__(self) -> "AsyncTaskHubGrpcClient": return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, *args: object) -> None: await self.close() def _schedule_recreate(self) -> None: diff --git a/durabletask/entities/entity_lock.py b/durabletask/entities/entity_lock.py index becc30e1..33c52d87 100644 --- a/durabletask/entities/entity_lock.py +++ b/durabletask/entities/entity_lock.py @@ -12,8 +12,8 @@ class EntityLock: def __init__(self, context: OrchestrationContext): self._context = context - def __enter__(self): + def __enter__(self) -> EntityLock: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, *args: object) -> None: self._context._exit_critical_section() diff --git a/durabletask/worker.py b/durabletask/worker.py index ad081ee4..a7b42f2f 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -581,10 +581,10 @@ def maximum_timer_interval(self) -> timedelta | None: """Get the configured maximum timer interval for long timer chunking.""" return self._maximum_timer_interval - def __enter__(self): + def __enter__(self) -> "TaskHubGrpcWorker": return self - def __exit__(self, type, value, traceback): + def __exit__(self, *args: object) -> None: self.stop() def _classify_stream_outcome( diff --git a/examples/human_interaction.py b/examples/human_interaction.py index b43336dd..69131006 100644 --- a/examples/human_interaction.py +++ b/examples/human_interaction.py @@ -79,31 +79,31 @@ def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order): w.add_activity(place_order) w.start() - c = client.TaskHubGrpcClient() - - # Start a purchase order workflow using the user input - order = Order(args.cost, "MyProduct", 1) - instance_id = c.schedule_new_orchestration(purchase_order_workflow, input=order) - - def prompt_for_approval(): - input("Press [ENTER] to approve the order...\n") - approval_event = namedtuple("Approval", ["approver"])(args.approver) - c.raise_orchestration_event(instance_id, "approval_received", data=approval_event) - - # Prompt the user for approval on a background thread - threading.Thread(target=prompt_for_approval, daemon=True).start() - - # Wait for the orchestration to complete - try: - state = c.wait_for_orchestration_completion(instance_id, timeout=args.timeout + 2) - if not state: - print("Workflow not found!") # not expected - elif state.runtime_status == client.OrchestrationStatus.COMPLETED: - print(f'Orchestration completed! Result: {state.serialized_output}') - else: - state.raise_if_failed() # raises an exception - except TimeoutError: - print("*** Orchestration timed out!") + with client.TaskHubGrpcClient() as c: + + # Start a purchase order workflow using the user input + order = Order(args.cost, "MyProduct", 1) + instance_id = c.schedule_new_orchestration(purchase_order_workflow, input=order) + + def prompt_for_approval(): + input("Press [ENTER] to approve the order...\n") + approval_event = namedtuple("Approval", ["approver"])(args.approver) + c.raise_orchestration_event(instance_id, "approval_received", data=approval_event) + + # Prompt the user for approval on a background thread + threading.Thread(target=prompt_for_approval, daemon=True).start() + + # Wait for the orchestration to complete + try: + state = c.wait_for_orchestration_completion(instance_id, timeout=args.timeout + 2) + if not state: + print("Workflow not found!") # not expected + elif state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f'Orchestration completed! Result: {state.serialized_output}') + else: + state.raise_if_failed() # raises an exception + except TimeoutError: + print("*** Orchestration timed out!") else: # Use DurableTaskScheduler # Use environment variables if provided, otherwise use default emulator values diff --git a/examples/in_memory_backend_example/test/test_workflows.py b/examples/in_memory_backend_example/test/test_workflows.py index 80b3c01a..ff9f3a31 100644 --- a/examples/in_memory_backend_example/test/test_workflows.py +++ b/examples/in_memory_backend_example/test/test_workflows.py @@ -72,9 +72,9 @@ def test_single_item_order(self): with _create_worker() as w: w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - instance_id = c.schedule_new_orchestration(process_order, input=order) - state = c.wait_for_orchestration_completion(instance_id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + instance_id = c.schedule_new_orchestration(process_order, input=order) + state = c.wait_for_orchestration_completion(instance_id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -98,9 +98,9 @@ def test_multi_item_order(self): with _create_worker() as w: w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - instance_id = c.schedule_new_orchestration(process_order, input=order) - state = c.wait_for_orchestration_completion(instance_id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + instance_id = c.schedule_new_orchestration(process_order, input=order) + state = c.wait_for_orchestration_completion(instance_id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -118,9 +118,9 @@ def test_empty_order_fails(self): with _create_worker() as w: w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - instance_id = c.schedule_new_orchestration(process_order, input=order) - state = c.wait_for_orchestration_completion(instance_id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + instance_id = c.schedule_new_orchestration(process_order, input=order) + state = c.wait_for_orchestration_completion(instance_id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.FAILED @@ -136,9 +136,9 @@ def test_invalid_quantity_fails(self): with _create_worker() as w: w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - instance_id = c.schedule_new_orchestration(process_order, input=order) - state = c.wait_for_orchestration_completion(instance_id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + instance_id = c.schedule_new_orchestration(process_order, input=order) + state = c.wait_for_orchestration_completion(instance_id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.FAILED @@ -163,9 +163,9 @@ def test_low_value_auto_approved(self): with _create_worker() as w: w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - instance_id = c.schedule_new_orchestration(order_with_approval, input=order) - state = c.wait_for_orchestration_completion(instance_id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + instance_id = c.schedule_new_orchestration(order_with_approval, input=order) + state = c.wait_for_orchestration_completion(instance_id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -184,12 +184,12 @@ def test_high_value_approved(self): with _create_worker() as w: w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - instance_id = c.schedule_new_orchestration(order_with_approval, input=order) + with client.TaskHubGrpcClient(host_address=HOST) as c: + instance_id = c.schedule_new_orchestration(order_with_approval, input=order) - # Raise the approval event - c.raise_orchestration_event(instance_id, "approval", data=True) - state = c.wait_for_orchestration_completion(instance_id, timeout=30) + # Raise the approval event + c.raise_orchestration_event(instance_id, "approval", data=True) + state = c.wait_for_orchestration_completion(instance_id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -208,12 +208,12 @@ def test_high_value_rejected(self): with _create_worker() as w: w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - instance_id = c.schedule_new_orchestration(order_with_approval, input=order) + with client.TaskHubGrpcClient(host_address=HOST) as c: + instance_id = c.schedule_new_orchestration(order_with_approval, input=order) - # Reject the order - c.raise_orchestration_event(instance_id, "approval", data=False) - state = c.wait_for_orchestration_completion(instance_id, timeout=30) + # Reject the order + c.raise_orchestration_event(instance_id, "approval", data=False) + state = c.wait_for_orchestration_completion(instance_id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -232,11 +232,11 @@ def test_high_value_timeout(self): with _create_worker() as w: w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - instance_id = c.schedule_new_orchestration(order_with_approval, input=order) + with client.TaskHubGrpcClient(host_address=HOST) as c: + instance_id = c.schedule_new_orchestration(order_with_approval, input=order) - # Don't raise any event — let the timer fire - state = c.wait_for_orchestration_completion(instance_id, timeout=60) + # Don't raise any event — let the timer fire + state = c.wait_for_orchestration_completion(instance_id, timeout=60) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED diff --git a/tests/durabletask/entities/test_class_based_entities_e2e.py b/tests/durabletask/entities/test_class_based_entities_e2e.py index ae8b31b2..ab5c6074 100644 --- a/tests/durabletask/entities/test_class_based_entities_e2e.py +++ b/tests/durabletask/entities/test_class_based_entities_e2e.py @@ -37,10 +37,10 @@ def do_nothing(self, _): w.add_entity(EmptyEntity, name="EntityNameCustom") w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity") - c.signal_entity(entity_id, "do_nothing") - time.sleep(2) # wait for the signal to be processed + with client.TaskHubGrpcClient(host_address=HOST) as c: + entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity") + c.signal_entity(entity_id, "do_nothing") + time.sleep(2) # wait for the signal to be processed assert invoked @@ -59,14 +59,14 @@ def do_nothing(self, _): w.add_entity(EmptyEntity) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity") - c.signal_entity(entity_id, "do_nothing") - time.sleep(2) # wait for the signal to be processed - state = c.get_entity(entity_id, include_state=True) - assert state is not None - assert state.id == entity_id - assert state.get_state(int) == 1 + with client.TaskHubGrpcClient(host_address=HOST) as c: + entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity") + c.signal_entity(entity_id, "do_nothing") + time.sleep(2) # wait for the signal to be processed + state = c.get_entity(entity_id, include_state=True) + assert state is not None + assert state.id == entity_id + assert state.get_state(int) == 1 assert invoked @@ -89,10 +89,10 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): w.add_entity(EmptyEntity, name="EntityNameCustom") w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(empty_orchestrator) - state = c.wait_for_orchestration_completion(id, timeout=30) - time.sleep(2) # wait for the signal to be processed + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + time.sleep(2) # wait for the signal to be processed assert invoked assert state is not None @@ -123,9 +123,9 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): w.add_entity(EmptyEntity) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(empty_orchestrator) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) assert invoked assert state is not None diff --git a/tests/durabletask/entities/test_entity_failure_handling.py b/tests/durabletask/entities/test_entity_failure_handling.py index 2db398b6..92ad2351 100644 --- a/tests/durabletask/entities/test_entity_failure_handling.py +++ b/tests/durabletask/entities/test_entity_failure_handling.py @@ -39,9 +39,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _): w.add_entity(FailingEntity) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(test_orchestrator) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(test_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.name == task.get_name(test_orchestrator) @@ -66,9 +66,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _): w.add_entity(failing_entity) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(test_orchestrator) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(test_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.name == task.get_name(test_orchestrator) @@ -97,9 +97,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _): w.add_entity(FailingEntity) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(test_orchestrator) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(test_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.name == task.get_name(test_orchestrator) @@ -129,9 +129,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _): w.add_entity(failing_entity) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(test_orchestrator) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(test_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.name == task.get_name(test_orchestrator) @@ -168,9 +168,9 @@ def test_orchestrator(ctx: task.OrchestrationContext, _): w.add_entity(failing_entity) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(test_orchestrator) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(test_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.name == task.get_name(test_orchestrator) diff --git a/tests/durabletask/entities/test_function_based_entities_e2e.py b/tests/durabletask/entities/test_function_based_entities_e2e.py index e19a6c4e..9b1a301e 100644 --- a/tests/durabletask/entities/test_function_based_entities_e2e.py +++ b/tests/durabletask/entities/test_function_based_entities_e2e.py @@ -37,10 +37,10 @@ def empty_entity(ctx: entities.EntityContext, _): w.add_entity(empty_entity, name="EntityNameCustom") w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity") - c.signal_entity(entity_id, "do_nothing") - time.sleep(2) # wait for the signal to be processed + with client.TaskHubGrpcClient(host_address=HOST) as c: + entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity") + c.signal_entity(entity_id, "do_nothing") + time.sleep(2) # wait for the signal to be processed assert invoked @@ -59,14 +59,14 @@ def empty_entity(ctx: entities.EntityContext, _): w.add_entity(empty_entity) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - entity_id = entities.EntityInstanceId("empty_entity", "testEntity") - c.signal_entity(entity_id, "do_nothing") - time.sleep(2) # wait for the signal to be processed - state = c.get_entity(entity_id, include_state=True) - assert state is not None - assert state.id == entity_id - assert state.get_state(int) == 1 + with client.TaskHubGrpcClient(host_address=HOST) as c: + entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + c.signal_entity(entity_id, "do_nothing") + time.sleep(2) # wait for the signal to be processed + state = c.get_entity(entity_id, include_state=True) + assert state is not None + assert state.id == entity_id + assert state.get_state(int) == 1 assert invoked @@ -90,10 +90,10 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): w.add_entity(empty_entity, name="EntityNameCustom") w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(empty_orchestrator) - state = c.wait_for_orchestration_completion(id, timeout=30) - time.sleep(2) # wait for the signal to be processed + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + time.sleep(2) # wait for the signal to be processed assert invoked assert state is not None @@ -125,9 +125,9 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): w.add_entity(empty_entity) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(empty_orchestrator) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) assert invoked assert state is not None @@ -160,14 +160,14 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): w.add_entity(empty_entity) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(empty_orchestrator) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) - # Call a second time to ensure the entity is still responsive - # after being locked and unlocked - id_2 = c.schedule_new_orchestration(empty_orchestrator) - state_2 = c.wait_for_orchestration_completion(id_2, timeout=30) + # Call a second time to ensure the entity is still responsive + # after being locked and unlocked + id_2 = c.schedule_new_orchestration(empty_orchestrator) + state_2 = c.wait_for_orchestration_completion(id_2, timeout=30) assert invoked assert state is not None @@ -213,10 +213,10 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): w.add_entity(empty_entity) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(empty_orchestrator) - state = c.wait_for_orchestration_completion(id, timeout=30) - time.sleep(2) # wait for the entity-to-entity signal to be processed + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + time.sleep(2) # wait for the entity-to-entity signal to be processed assert invoked assert state is not None @@ -246,11 +246,11 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): w.add_entity(empty_entity) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - c.signal_entity( - entities.EntityInstanceId("empty_entity", "testEntity"), - "start_orchestration") - time.sleep(3) # wait for the signal and orchestration to be processed + with client.TaskHubGrpcClient(host_address=HOST) as c: + c.signal_entity( + entities.EntityInstanceId("empty_entity", "testEntity"), + "start_orchestration") + time.sleep(3) # wait for the signal and orchestration to be processed assert invoked @@ -276,9 +276,9 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): w.add_entity(empty_entity) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(empty_orchestrator) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.name == task.get_name(empty_orchestrator) @@ -310,12 +310,12 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): w.add_entity(empty_entity) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - time.sleep(2) # wait for initial setup - id = c.schedule_new_orchestration(empty_orchestrator) - c.wait_for_orchestration_completion(id, timeout=30) - id = c.schedule_new_orchestration(empty_orchestrator) - c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + time.sleep(2) # wait for initial setup + id = c.schedule_new_orchestration(empty_orchestrator) + c.wait_for_orchestration_completion(id, timeout=30) + id = c.schedule_new_orchestration(empty_orchestrator) + c.wait_for_orchestration_completion(id, timeout=30) assert invoke_count == 2 @@ -339,19 +339,19 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): w.add_entity(empty_entity) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - time.sleep(2) # wait for initial setup - id = c.schedule_new_orchestration(empty_orchestrator) - c.wait_for_orchestration_completion(id, timeout=30) - id = c.schedule_new_orchestration(empty_orchestrator) - c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + time.sleep(2) # wait for initial setup + id = c.schedule_new_orchestration(empty_orchestrator) + c.wait_for_orchestration_completion(id, timeout=30) + id = c.schedule_new_orchestration(empty_orchestrator) + c.wait_for_orchestration_completion(id, timeout=30) assert invoke_count == 2 def test_get_entity_not_found(): """Test that get_entity returns None for a non-existent entity.""" - c = client.TaskHubGrpcClient(host_address=HOST) - entity_id = entities.EntityInstanceId("counter", "nonexistent") - metadata = c.get_entity(entity_id, include_state=True) - assert metadata is None + with client.TaskHubGrpcClient(host_address=HOST) as c: + entity_id = entities.EntityInstanceId("counter", "nonexistent") + metadata = c.get_entity(entity_id, include_state=True) + assert metadata is None diff --git a/tests/durabletask/test_batch_actions.py b/tests/durabletask/test_batch_actions.py index b0778503..60d313a3 100644 --- a/tests/durabletask/test_batch_actions.py +++ b/tests/durabletask/test_batch_actions.py @@ -38,20 +38,20 @@ def failing_orchestrator(ctx: task.OrchestrationContext, _): def test_get_all_orchestration_states(backend): worker = TaskHubGrpcWorker(host_address=HOST) - c = TaskHubGrpcClient(host_address=HOST) worker.add_orchestrator(empty_orchestrator) worker.start() try: - id = c.schedule_new_orchestration(empty_orchestrator, input="Hello") - c.wait_for_orchestration_completion(id, timeout=30) - - all_orchestrations = c.get_all_orchestration_states() - query = client.OrchestrationQuery() - query.fetch_inputs_and_outputs = True - all_orchestrations_with_state = c.get_all_orchestration_states(query) - this_orch = c.get_orchestration_state(id) + with TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(empty_orchestrator, input="Hello") + c.wait_for_orchestration_completion(id, timeout=30) + + all_orchestrations = c.get_all_orchestration_states() + query = client.OrchestrationQuery() + query.fetch_inputs_and_outputs = True + all_orchestrations_with_state = c.get_all_orchestration_states(query) + this_orch = c.get_orchestration_state(id) finally: worker.stop() @@ -79,35 +79,35 @@ def test_get_all_orchestration_states(backend): def test_get_orchestration_state_by_status(backend): worker = TaskHubGrpcWorker(host_address=HOST) - c = TaskHubGrpcClient(host_address=HOST) worker.add_orchestrator(empty_orchestrator) worker.add_orchestrator(failing_orchestrator) worker.start() try: - # Schedule completed orchestration - completed_id = c.schedule_new_orchestration(empty_orchestrator, input="Hello") - c.wait_for_orchestration_completion(completed_id, timeout=30) - - # Schedule failed orchestration - failed_id = c.schedule_new_orchestration(failing_orchestrator) - try: - c.wait_for_orchestration_completion(failed_id, timeout=30) - except client.OrchestrationFailedError: - pass # Expected failure - - # Query by completed status - query = client.OrchestrationQuery() - query.runtime_status = [client.OrchestrationStatus.COMPLETED] - query.fetch_inputs_and_outputs = True - completed_orchestrations = c.get_all_orchestration_states(query) - - # Query by failed status - query = client.OrchestrationQuery() - query.runtime_status = [client.OrchestrationStatus.FAILED] - query.fetch_inputs_and_outputs = True - failed_orchestrations = c.get_all_orchestration_states(query) + with TaskHubGrpcClient(host_address=HOST) as c: + # Schedule completed orchestration + completed_id = c.schedule_new_orchestration(empty_orchestrator, input="Hello") + c.wait_for_orchestration_completion(completed_id, timeout=30) + + # Schedule failed orchestration + failed_id = c.schedule_new_orchestration(failing_orchestrator) + try: + c.wait_for_orchestration_completion(failed_id, timeout=30) + except client.OrchestrationFailedError: + pass # Expected failure + + # Query by completed status + query = client.OrchestrationQuery() + query.runtime_status = [client.OrchestrationStatus.COMPLETED] + query.fetch_inputs_and_outputs = True + completed_orchestrations = c.get_all_orchestration_states(query) + + # Query by failed status + query = client.OrchestrationQuery() + query.runtime_status = [client.OrchestrationStatus.FAILED] + query.fetch_inputs_and_outputs = True + failed_orchestrations = c.get_all_orchestration_states(query) finally: worker.stop() @@ -124,36 +124,36 @@ def test_get_orchestration_state_by_status(backend): def test_get_orchestration_state_by_time_range(backend): worker = TaskHubGrpcWorker(host_address=HOST) - c = TaskHubGrpcClient(host_address=HOST) worker.add_orchestrator(empty_orchestrator) worker.start() try: - # Get current time - before_creation = datetime.now(timezone.utc) - timedelta(seconds=5) - - # Schedule orchestration - id = c.schedule_new_orchestration(empty_orchestrator, input="TimeTest") - c.wait_for_orchestration_completion(id, timeout=30) - - after_creation = datetime.now(timezone.utc) + timedelta(seconds=5) - - # Query by time range - query = client.OrchestrationQuery( - created_time_from=before_creation, - created_time_to=after_creation, - fetch_inputs_and_outputs=True - ) - orchestrations_in_range = c.get_all_orchestration_states(query) - - # Query outside time range - query = client.OrchestrationQuery( - created_time_from=after_creation, - created_time_to=after_creation + timedelta(hours=1), - fetch_inputs_and_outputs=True - ) - orchestrations_outside_range = c.get_all_orchestration_states(query) + with TaskHubGrpcClient(host_address=HOST) as c: + # Get current time + before_creation = datetime.now(timezone.utc) - timedelta(seconds=5) + + # Schedule orchestration + id = c.schedule_new_orchestration(empty_orchestrator, input="TimeTest") + c.wait_for_orchestration_completion(id, timeout=30) + + after_creation = datetime.now(timezone.utc) + timedelta(seconds=5) + + # Query by time range + query = client.OrchestrationQuery( + created_time_from=before_creation, + created_time_to=after_creation, + fetch_inputs_and_outputs=True + ) + orchestrations_in_range = c.get_all_orchestration_states(query) + + # Query outside time range + query = client.OrchestrationQuery( + created_time_from=after_creation, + created_time_to=after_creation + timedelta(hours=1), + fetch_inputs_and_outputs=True + ) + orchestrations_outside_range = c.get_all_orchestration_states(query) finally: worker.stop() @@ -172,25 +172,25 @@ def emit(self, record): handler = ListHandler() worker = TaskHubGrpcWorker(host_address=HOST) - c = TaskHubGrpcClient(host_address=HOST, log_handler=handler) worker.add_orchestrator(empty_orchestrator) worker.start() try: - # Create at least 3 orchestrations to test the limit - ids = [] - for i in range(3): - id = c.schedule_new_orchestration(empty_orchestrator, input=f"Test{i}") - ids.append(id) - - # Wait for all to complete - for id in ids: - c.wait_for_orchestration_completion(id, timeout=30) - - # Query with max_instance_count=2 - query = client.OrchestrationQuery(max_instance_count=2) - orchestrations = c.get_all_orchestration_states(query) + with TaskHubGrpcClient(host_address=HOST, log_handler=handler) as c: + # Create at least 3 orchestrations to test the limit + ids = [] + for i in range(3): + id = c.schedule_new_orchestration(empty_orchestrator, input=f"Test{i}") + ids.append(id) + + # Wait for all to complete + for id in ids: + c.wait_for_orchestration_completion(id, timeout=30) + + # Query with max_instance_count=2 + query = client.OrchestrationQuery(max_instance_count=2) + orchestrations = c.get_all_orchestration_states(query) finally: worker.stop() @@ -203,135 +203,135 @@ def emit(self, record): def test_purge_orchestration(backend): worker = TaskHubGrpcWorker(host_address=HOST) - c = TaskHubGrpcClient(host_address=HOST) worker.add_orchestrator(empty_orchestrator) worker.start() try: - # Schedule and complete orchestration - id = c.schedule_new_orchestration(empty_orchestrator, input="ToPurge") - c.wait_for_orchestration_completion(id, timeout=30) + with TaskHubGrpcClient(host_address=HOST) as c: + # Schedule and complete orchestration + id = c.schedule_new_orchestration(empty_orchestrator, input="ToPurge") + c.wait_for_orchestration_completion(id, timeout=30) - # Verify it exists - state_before = c.get_orchestration_state(id) - assert state_before is not None + # Verify it exists + state_before = c.get_orchestration_state(id) + assert state_before is not None - # Purge the orchestration - result = c.purge_orchestration(id, recursive=True) + # Purge the orchestration + result = c.purge_orchestration(id, recursive=True) - # Verify purge result - assert result.deleted_instance_count >= 1 + # Verify purge result + assert result.deleted_instance_count >= 1 - # Verify it no longer exists - state_after = c.get_orchestration_state(id) - assert state_after is None + # Verify it no longer exists + state_after = c.get_orchestration_state(id) + assert state_after is None finally: worker.stop() def test_purge_orchestrations_by_status(backend): worker = TaskHubGrpcWorker(host_address=HOST) - c = TaskHubGrpcClient(host_address=HOST) worker.add_orchestrator(failing_orchestrator) worker.start() try: - # Schedule and let it fail - failed_id = c.schedule_new_orchestration(failing_orchestrator) - try: - c.wait_for_orchestration_completion(failed_id, timeout=30) - except client.OrchestrationFailedError: - pass # Expected failure - - # Verify it exists and is failed - state_before = c.get_orchestration_state(failed_id) - assert state_before is not None - assert state_before.runtime_status == client.OrchestrationStatus.FAILED - - # Purge failed orchestrations - result = c.purge_orchestrations_by( - runtime_status=[client.OrchestrationStatus.FAILED], - recursive=True - ) - - # Verify purge result - assert result.deleted_instance_count >= 1 - - # Verify the failed orchestration no longer exists - state_after = c.get_orchestration_state(failed_id) - assert state_after is None + with TaskHubGrpcClient(host_address=HOST) as c: + # Schedule and let it fail + failed_id = c.schedule_new_orchestration(failing_orchestrator) + try: + c.wait_for_orchestration_completion(failed_id, timeout=30) + except client.OrchestrationFailedError: + pass # Expected failure + + # Verify it exists and is failed + state_before = c.get_orchestration_state(failed_id) + assert state_before is not None + assert state_before.runtime_status == client.OrchestrationStatus.FAILED + + # Purge failed orchestrations + result = c.purge_orchestrations_by( + runtime_status=[client.OrchestrationStatus.FAILED], + recursive=True + ) + + # Verify purge result + assert result.deleted_instance_count >= 1 + + # Verify the failed orchestration no longer exists + state_after = c.get_orchestration_state(failed_id) + assert state_after is None finally: worker.stop() def test_purge_orchestrations_by_time_range(backend): worker = TaskHubGrpcWorker(host_address=HOST) - c = TaskHubGrpcClient(host_address=HOST) worker.add_orchestrator(empty_orchestrator) worker.start() try: - # Get current time - before_creation = datetime.now(timezone.utc) - timedelta(seconds=5) + with TaskHubGrpcClient(host_address=HOST) as c: + # Get current time + before_creation = datetime.now(timezone.utc) - timedelta(seconds=5) - # Schedule orchestration - id = c.schedule_new_orchestration(empty_orchestrator, input="ToPurgeByTime") - c.wait_for_orchestration_completion(id, timeout=30) + # Schedule orchestration + id = c.schedule_new_orchestration(empty_orchestrator, input="ToPurgeByTime") + c.wait_for_orchestration_completion(id, timeout=30) - after_creation = datetime.now(timezone.utc) + timedelta(seconds=5) + after_creation = datetime.now(timezone.utc) + timedelta(seconds=5) - # Verify it exists - state_before = c.get_orchestration_state(id) - assert state_before is not None + # Verify it exists + state_before = c.get_orchestration_state(id) + assert state_before is not None - # Purge by time range - result = c.purge_orchestrations_by( - created_time_from=before_creation, - created_time_to=after_creation, - runtime_status=[client.OrchestrationStatus.COMPLETED], - recursive=True - ) + # Purge by time range + result = c.purge_orchestrations_by( + created_time_from=before_creation, + created_time_to=after_creation, + runtime_status=[client.OrchestrationStatus.COMPLETED], + recursive=True + ) - # Verify purge result - assert result.deleted_instance_count >= 1 + # Verify purge result + assert result.deleted_instance_count >= 1 - # Verify it no longer exists - state_after = c.get_orchestration_state(id) - assert state_after is None + # Verify it no longer exists + state_after = c.get_orchestration_state(id) + assert state_after is None finally: worker.stop() def test_list_instance_ids_paginates_terminal_instances(backend): worker = TaskHubGrpcWorker(host_address=HOST) - c = TaskHubGrpcClient(host_address=HOST) worker.add_orchestrator(empty_orchestrator) worker.add_orchestrator(failing_orchestrator) worker.start() try: - completed_id = c.schedule_new_orchestration(empty_orchestrator, input='done') - c.wait_for_orchestration_completion(completed_id, timeout=30) - - failed_id = c.schedule_new_orchestration(failing_orchestrator) - failed_state = c.wait_for_orchestration_completion(failed_id, timeout=30) - - window_start = datetime.now(timezone.utc) - timedelta(minutes=1) - first_page = c.list_instance_ids( - runtime_status=[client.OrchestrationStatus.COMPLETED, client.OrchestrationStatus.FAILED], - completed_time_from=window_start, - page_size=1, - ) - second_page = c.list_instance_ids( - runtime_status=[client.OrchestrationStatus.COMPLETED, client.OrchestrationStatus.FAILED], - completed_time_from=window_start, - page_size=1, - continuation_token=first_page.continuation_token, - ) + with TaskHubGrpcClient(host_address=HOST) as c: + completed_id = c.schedule_new_orchestration(empty_orchestrator, input='done') + c.wait_for_orchestration_completion(completed_id, timeout=30) + + failed_id = c.schedule_new_orchestration(failing_orchestrator) + failed_state = c.wait_for_orchestration_completion(failed_id, timeout=30) + + window_start = datetime.now(timezone.utc) - timedelta(minutes=1) + first_page = c.list_instance_ids( + runtime_status=[client.OrchestrationStatus.COMPLETED, client.OrchestrationStatus.FAILED], + completed_time_from=window_start, + page_size=1, + ) + second_page = c.list_instance_ids( + runtime_status=[client.OrchestrationStatus.COMPLETED, client.OrchestrationStatus.FAILED], + completed_time_from=window_start, + page_size=1, + continuation_token=first_page.continuation_token, + ) finally: worker.stop() @@ -357,30 +357,30 @@ def counter_entity(ctx: entities.EntityContext, input): return ctx.get_state(int, 0) worker = TaskHubGrpcWorker(host_address=HOST) - c = TaskHubGrpcClient(host_address=HOST) worker.add_entity(counter_entity) worker.start() try: - # Create entity - entity_id = entities.EntityInstanceId("counter_entity", "testCounter1") - c.signal_entity(entity_id, "add", input=5) - time.sleep(3) # Wait for signal to be processed - - # Get all entities without state - query = client.EntityQuery(include_state=False) - all_entities = c.get_all_entities(query) - assert len([e for e in all_entities if e.id == entity_id]) == 1 - entity_without_state = [e for e in all_entities if e.id == entity_id][0] - assert entity_without_state.get_state(int) is None - - # Get all entities with state - query = client.EntityQuery(include_state=True) - all_entities_with_state = c.get_all_entities(query) - assert len([e for e in all_entities_with_state if e.id == entity_id]) == 1 - entity_with_state = [e for e in all_entities_with_state if e.id == entity_id][0] - assert entity_with_state.get_state(int) == 5 + with TaskHubGrpcClient(host_address=HOST) as c: + # Create entity + entity_id = entities.EntityInstanceId("counter_entity", "testCounter1") + c.signal_entity(entity_id, "add", input=5) + time.sleep(3) # Wait for signal to be processed + + # Get all entities without state + query = client.EntityQuery(include_state=False) + all_entities = c.get_all_entities(query) + assert len([e for e in all_entities if e.id == entity_id]) == 1 + entity_without_state = [e for e in all_entities if e.id == entity_id][0] + assert entity_without_state.get_state(int) is None + + # Get all entities with state + query = client.EntityQuery(include_state=True) + all_entities_with_state = c.get_all_entities(query) + assert len([e for e in all_entities_with_state if e.id == entity_id]) == 1 + entity_with_state = [e for e in all_entities_with_state if e.id == entity_id][0] + assert entity_with_state.get_state(int) == 5 finally: worker.stop() @@ -391,32 +391,32 @@ def counter_entity(ctx: entities.EntityContext, input): ctx.set_state(input) worker = TaskHubGrpcWorker(host_address=HOST) - c = TaskHubGrpcClient(host_address=HOST) worker.add_entity(counter_entity) worker.start() try: - # Create entities with different prefixes - entity_id_1 = entities.EntityInstanceId("counter_entity", "prefix1_counter") - entity_id_2 = entities.EntityInstanceId("counter_entity", "prefix2_counter") - - c.signal_entity(entity_id_1, "set", input=10) - c.signal_entity(entity_id_2, "set", input=20) - time.sleep(3) # Wait for signals to be processed - - # Query by prefix - query = client.EntityQuery( - instance_id_starts_with="@counter_entity@prefix1", - include_state=True - ) - entities_prefix1 = c.get_all_entities(query) - - query = client.EntityQuery( - instance_id_starts_with="@counter_entity@prefix2", - include_state=True - ) - entities_prefix2 = c.get_all_entities(query) + with TaskHubGrpcClient(host_address=HOST) as c: + # Create entities with different prefixes + entity_id_1 = entities.EntityInstanceId("counter_entity", "prefix1_counter") + entity_id_2 = entities.EntityInstanceId("counter_entity", "prefix2_counter") + + c.signal_entity(entity_id_1, "set", input=10) + c.signal_entity(entity_id_2, "set", input=20) + time.sleep(3) # Wait for signals to be processed + + # Query by prefix + query = client.EntityQuery( + instance_id_starts_with="@counter_entity@prefix1", + include_state=True + ) + entities_prefix1 = c.get_all_entities(query) + + query = client.EntityQuery( + instance_id_starts_with="@counter_entity@prefix2", + include_state=True + ) + entities_prefix2 = c.get_all_entities(query) finally: worker.stop() @@ -433,36 +433,36 @@ def simple_entity(ctx: entities.EntityContext, input): ctx.set_state(input) worker = TaskHubGrpcWorker(host_address=HOST) - c = TaskHubGrpcClient(host_address=HOST) worker.add_entity(simple_entity) worker.start() try: - # Get current time - before_creation = datetime.now(timezone.utc) - timedelta(seconds=5) - - # Create entity - entity_id = entities.EntityInstanceId("simple_entity", "timeTestEntity") - c.signal_entity(entity_id, "set", input="test_value") - time.sleep(3) # Wait for signal to be processed - - after_creation = datetime.now(timezone.utc) + timedelta(seconds=5) - - # Query by time range - query = client.EntityQuery( - last_modified_from=before_creation, - last_modified_to=after_creation, - include_state=True - ) - entities_in_range = c.get_all_entities(query) - - # Query outside time range - query = client.EntityQuery( - last_modified_from=after_creation, - last_modified_to=after_creation + timedelta(hours=1) - ) - entities_outside_range = c.get_all_entities(query) + with TaskHubGrpcClient(host_address=HOST) as c: + # Get current time + before_creation = datetime.now(timezone.utc) - timedelta(seconds=5) + + # Create entity + entity_id = entities.EntityInstanceId("simple_entity", "timeTestEntity") + c.signal_entity(entity_id, "set", input="test_value") + time.sleep(3) # Wait for signal to be processed + + after_creation = datetime.now(timezone.utc) + timedelta(seconds=5) + + # Query by time range + query = client.EntityQuery( + last_modified_from=before_creation, + last_modified_to=after_creation, + include_state=True + ) + entities_in_range = c.get_all_entities(query) + + # Query outside time range + query = client.EntityQuery( + last_modified_from=after_creation, + last_modified_to=after_creation + timedelta(hours=1) + ) + entities_outside_range = c.get_all_entities(query) finally: worker.stop() @@ -475,22 +475,22 @@ class EmptyEntity(entities.DurableEntity): pass worker = TaskHubGrpcWorker(host_address=HOST) - c = TaskHubGrpcClient(host_address=HOST) worker.add_entity(EmptyEntity) worker.start() try: - # Create an entity and then delete its state to make it empty - entity_id = entities.EntityInstanceId("EmptyEntity", "toClean") - c.signal_entity(entity_id, "delete") - time.sleep(3) # Wait for signal to be processed - - # Clean entity storage - result = c.clean_entity_storage( - remove_empty_entities=True, - release_orphaned_locks=True - ) + with TaskHubGrpcClient(host_address=HOST) as c: + # Create an entity and then delete its state to make it empty + entity_id = entities.EntityInstanceId("EmptyEntity", "toClean") + c.signal_entity(entity_id, "delete") + time.sleep(3) # Wait for signal to be processed + + # Clean entity storage + result = c.clean_entity_storage( + remove_empty_entities=True, + release_orphaned_locks=True + ) finally: worker.stop() diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index b219e550..2fe48517 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -744,6 +744,130 @@ def test_sync_client_does_not_recreate_caller_owned_channel(): provided_channel.close.assert_not_called() +def test_sync_client_context_manager_returns_self_and_calls_close(): + with ( + patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()), + patch("durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=MagicMock()), + ): + client = TaskHubGrpcClient(host_address=HOST_ADDRESS) + with patch.object(client, "close", wraps=client.close) as spy_close: + with client as entered: + assert entered is client + spy_close.assert_called_once_with() + + +def test_sync_client_context_manager_closes_sdk_owned_channel(): + channel = MagicMock(name="sdk-owned-channel") + with ( + patch("durabletask.client.shared.get_grpc_channel", return_value=channel), + patch("durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=MagicMock()), + ): + with TaskHubGrpcClient(host_address=HOST_ADDRESS) as client: + assert client._closing is False + assert client._owns_channel is True + + assert client._closing is True + channel.close.assert_called_once_with() + + +def test_sync_client_context_manager_preserves_caller_owned_channel(): + provided_channel = MagicMock(name="caller-owned-channel") + with patch("durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=MagicMock()): + with TaskHubGrpcClient( + channel=provided_channel, host_address=HOST_ADDRESS) as client: + assert client._channel is provided_channel + assert client._owns_channel is False + + # close() is a no-op for caller-owned channels: the caller retains + # ownership and is responsible for closing the channel themselves. + provided_channel.close.assert_not_called() + assert client._closing is False + + +def test_sync_client_context_manager_propagates_exception_and_calls_close(): + channel = MagicMock(name="sdk-owned-channel") + with ( + patch("durabletask.client.shared.get_grpc_channel", return_value=channel), + patch("durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=MagicMock()), + ): + client = TaskHubGrpcClient(host_address=HOST_ADDRESS) + raised = False + try: + with client: + raise RuntimeError("boom") + except RuntimeError as exc: + raised = True + assert str(exc) == "boom" + assert raised, "RuntimeError raised inside the with block must propagate" + + assert client._closing is True + channel.close.assert_called_once_with() + + +def test_sync_client_context_manager_cleans_up_resiliency_state(): + """Regression: exiting the ``with`` block tears down resiliency state + introduced by PR #135 (retired channels + recreate thread). + """ + first_channel = MagicMock(name="first-channel") + second_channel = MagicMock(name="second-channel") + first_stub = MagicMock() + first_stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE) + second_stub = MagicMock() + second_stub.GetInstance.return_value = MagicMock(exists=False) + close_timer = MagicMock(name="close-timer") + + with ( + patch( + "durabletask.client.shared.get_grpc_channel", + side_effect=[first_channel, second_channel], + ), + patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", + side_effect=[first_stub, second_stub], + ), + patch("threading.Timer", return_value=close_timer), + ): + with TaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=0.0, + )) as client: + install_resilient_test_stubs(client) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + # Wait for the fire-and-forget recreate to finish so the retired + # channel + timer are registered before the context manager exits. + assert client._recreate_done_event.wait(timeout=5.0) + + assert client._closing is True + close_timer.cancel.assert_called_once_with() + first_channel.close.assert_called_once_with() + second_channel.close.assert_called_once_with() + assert client._retired_channels == {} + # The recreate thread (if any) must have been joined during __exit__. + recreate_thread = client._recreate_thread + if recreate_thread is not None: + assert not recreate_thread.is_alive() + + +def test_sync_client_supports_context_manager_reentry_after_use(): + """``with`` is idempotent against repeated ``close()`` calls.""" + channel = MagicMock(name="sdk-owned-channel") + with ( + patch("durabletask.client.shared.get_grpc_channel", return_value=channel), + patch("durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=MagicMock()), + ): + client = TaskHubGrpcClient(host_address=HOST_ADDRESS) + with client: + pass + # Calling close() again after the context manager has already torn + # down the channel must not raise (mirrors how ``AsyncTaskHubGrpcClient`` + # tolerates explicit close-after-aexit during shutdown sequences). + client.close() + + assert channel.close.call_count >= 1 + + def test_sync_client_recreate_cooldown_prevents_immediate_repeated_recreation(): first_channel = MagicMock(name="first-channel") second_channel = MagicMock(name="second-channel") diff --git a/tests/durabletask/test_large_payload_e2e.py b/tests/durabletask/test_large_payload_e2e.py index 3ff7af85..832d4f26 100644 --- a/tests/durabletask/test_large_payload_e2e.py +++ b/tests/durabletask/test_large_payload_e2e.py @@ -138,9 +138,9 @@ def echo(ctx: task.OrchestrationContext, inp: str): w.add_orchestrator(echo) w.start() - c = client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) - inst_id = c.schedule_new_orchestration(echo, input=large_input) - state = c.wait_for_orchestration_completion(inst_id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) as c: + inst_id = c.schedule_new_orchestration(echo, input=large_input) + state = c.wait_for_orchestration_completion(inst_id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -161,9 +161,9 @@ def orchestrator(ctx: task.OrchestrationContext, size_kb: int): w.add_activity(produce_large) w.start() - c = client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) - inst_id = c.schedule_new_orchestration(orchestrator, input=10) # 10 KB - state = c.wait_for_orchestration_completion(inst_id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) as c: + inst_id = c.schedule_new_orchestration(orchestrator, input=10) # 10 KB + state = c.wait_for_orchestration_completion(inst_id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -182,9 +182,9 @@ def transform(ctx: task.OrchestrationContext, inp: dict): w.add_orchestrator(transform) w.start() - c = client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) - inst_id = c.schedule_new_orchestration(transform, input=large_input) - state = c.wait_for_orchestration_completion(inst_id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) as c: + inst_id = c.schedule_new_orchestration(transform, input=large_input) + state = c.wait_for_orchestration_completion(inst_id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -208,12 +208,12 @@ def wait_for_event(ctx: task.OrchestrationContext, _): w.add_orchestrator(wait_for_event) w.start() - c = client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) - inst_id = c.schedule_new_orchestration(wait_for_event) - c.wait_for_orchestration_start(inst_id, timeout=10) + with client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) as c: + inst_id = c.schedule_new_orchestration(wait_for_event) + c.wait_for_orchestration_start(inst_id, timeout=10) - c.raise_orchestration_event(inst_id, "big_event", data=large_event) - state = c.wait_for_orchestration_completion(inst_id, timeout=30) + c.raise_orchestration_event(inst_id, "big_event", data=large_event) + state = c.wait_for_orchestration_completion(inst_id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -235,12 +235,12 @@ def long_running(ctx: task.OrchestrationContext, _): w.add_orchestrator(long_running) w.start() - c = client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) - inst_id = c.schedule_new_orchestration(long_running) - c.wait_for_orchestration_start(inst_id, timeout=10) + with client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) as c: + inst_id = c.schedule_new_orchestration(long_running) + c.wait_for_orchestration_start(inst_id, timeout=10) - c.terminate_orchestration(inst_id, output=large_output) - state = c.wait_for_orchestration_completion(inst_id, timeout=30) + c.terminate_orchestration(inst_id, output=large_output) + state = c.wait_for_orchestration_completion(inst_id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.TERMINATED @@ -264,9 +264,9 @@ def fan_out(ctx: task.OrchestrationContext, count: int): w.add_activity(make_large) w.start() - c = client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) - inst_id = c.schedule_new_orchestration(fan_out, input=5) - state = c.wait_for_orchestration_completion(inst_id, timeout=60) + with client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) as c: + inst_id = c.schedule_new_orchestration(fan_out, input=5) + state = c.wait_for_orchestration_completion(inst_id, timeout=60) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -291,9 +291,9 @@ def echo(ctx: task.OrchestrationContext, inp: str): w.add_orchestrator(echo) w.start() - c = client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) - inst_id = c.schedule_new_orchestration(echo, input=large_input) - c.wait_for_orchestration_completion(inst_id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) as c: + inst_id = c.schedule_new_orchestration(echo, input=large_input) + c.wait_for_orchestration_completion(inst_id, timeout=30) # Verify blobs were actually created in the Azurite container svc = azure_blob.BlobServiceClient.from_connection_string( @@ -333,9 +333,9 @@ def echo(ctx: task.OrchestrationContext, inp: str): w.add_orchestrator(echo) w.start() - c = client.TaskHubGrpcClient(host_address=HOST, payload_store=store) - inst_id = c.schedule_new_orchestration(echo, input=small_input) - state = c.wait_for_orchestration_completion(inst_id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST, payload_store=store) as c: + inst_id = c.schedule_new_orchestration(echo, input=small_input) + state = c.wait_for_orchestration_completion(inst_id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index 065a8b11..a8d47274 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -37,9 +37,9 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): w.add_orchestrator(empty_orchestrator) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(empty_orchestrator, tags={'Tagged': 'true'}) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(empty_orchestrator, tags={'Tagged': 'true'}) + state = c.wait_for_orchestration_completion(id, timeout=30) assert invoked assert state is not None @@ -70,10 +70,10 @@ def sequence(ctx: task.OrchestrationContext, start_val: int): w.add_activity(plus_one) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(sequence, input=1, tags={'Orchestration': 'Sequence'}) - state = task_hub_client.wait_for_orchestration_completion( - id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(sequence, input=1, tags={'Orchestration': 'Sequence'}) + state = task_hub_client.wait_for_orchestration_completion( + id, timeout=30) assert state is not None assert state.name == task.get_name(sequence) @@ -115,9 +115,9 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): w.add_activity(increment_counter) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(orchestrator, input=1) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator, input=1) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.name == task.get_name(orchestrator) @@ -143,13 +143,10 @@ def simple(ctx: task.OrchestrationContext, value: int): w.add_activity(plus_one) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - try: + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: instance_id = task_hub_client.schedule_new_orchestration(simple, input=1) state = task_hub_client.wait_for_orchestration_completion(instance_id, timeout=30) events = task_hub_client.get_orchestration_history(instance_id) - finally: - task_hub_client.close() assert state is not None assert len(events) > 0 @@ -186,9 +183,9 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): w.add_orchestrator(parent_orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=10) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=10) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -211,9 +208,9 @@ def parent_orchestrator(ctx: task.OrchestrationContext, _): w.add_orchestrator(parent_orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=None) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=None) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -233,12 +230,12 @@ def orchestrator(ctx: task.OrchestrationContext, _): w.start() # Start the orchestration and immediately raise events to it. - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(orchestrator) - task_hub_client.raise_orchestration_event(id, 'A', data='a') - task_hub_client.raise_orchestration_event(id, 'B', data='b') - task_hub_client.raise_orchestration_event(id, 'C', data='c') - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator) + task_hub_client.raise_orchestration_event(id, 'A', data='a') + task_hub_client.raise_orchestration_event(id, 'B', data='b') + task_hub_client.raise_orchestration_event(id, 'C', data='c') + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -261,11 +258,11 @@ def orchestrator(ctx: task.OrchestrationContext, _): w.start() # Start the orchestration and immediately raise events to it. - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(orchestrator) - if raise_event: - task_hub_client.raise_orchestration_event(id, 'Approval') - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator) + if raise_event: + task_hub_client.raise_orchestration_event(id, 'Approval') + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -284,35 +281,35 @@ def orchestrator(ctx: task.OrchestrationContext, _): w.add_orchestrator(orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(orchestrator) - state = task_hub_client.wait_for_orchestration_start(id, timeout=30) - assert state is not None - - # Suspend the orchestration and wait for it to go into the SUSPENDED state - task_hub_client.suspend_orchestration(id) - deadline = time.time() + 10 - while state.runtime_status == client.OrchestrationStatus.RUNNING: - assert time.time() < deadline, "Timed out waiting for SUSPENDED status" - time.sleep(0.1) - state = task_hub_client.get_orchestration_state(id) + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator) + state = task_hub_client.wait_for_orchestration_start(id, timeout=30) assert state is not None - assert state.runtime_status == client.OrchestrationStatus.SUSPENDED - # Raise an event to the orchestration and confirm that it does NOT complete - task_hub_client.raise_orchestration_event(id, "my_event", data=42) - try: - state = task_hub_client.wait_for_orchestration_completion(id, timeout=3) - assert False, "Orchestration should not have completed" - except TimeoutError: - pass - - # Resume the orchestration and wait for it to complete - task_hub_client.resume_orchestration(id) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED - assert state.serialized_output == json.dumps(42) + # Suspend the orchestration and wait for it to go into the SUSPENDED state + task_hub_client.suspend_orchestration(id) + deadline = time.time() + 10 + while state.runtime_status == client.OrchestrationStatus.RUNNING: + assert time.time() < deadline, "Timed out waiting for SUSPENDED status" + time.sleep(0.1) + state = task_hub_client.get_orchestration_state(id) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.SUSPENDED + + # Raise an event to the orchestration and confirm that it does NOT complete + task_hub_client.raise_orchestration_event(id, "my_event", data=42) + try: + task_hub_client.wait_for_orchestration_completion(id, timeout=3) + assert False, "Orchestration should not have completed" + except TimeoutError: + pass + + # Resume the orchestration and wait for it to complete + task_hub_client.resume_orchestration(id) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps(42) def test_terminate(): @@ -324,17 +321,17 @@ def orchestrator(ctx: task.OrchestrationContext, _): w.add_orchestrator(orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(orchestrator) - state = task_hub_client.wait_for_orchestration_start(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.RUNNING + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator) + state = task_hub_client.wait_for_orchestration_start(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.RUNNING - task_hub_client.terminate_orchestration(id, output="some reason for termination") - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.TERMINATED - assert state.serialized_output == json.dumps("some reason for termination") + task_hub_client.terminate_orchestration(id, output="some reason for termination") + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.TERMINATED + assert state.serialized_output == json.dumps("some reason for termination") def test_terminate_recursive(): @@ -351,26 +348,26 @@ def child(ctx: task.OrchestrationContext, _): w.add_orchestrator(child) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(root) - state = task_hub_client.wait_for_orchestration_start(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.RUNNING + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(root) + state = task_hub_client.wait_for_orchestration_start(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.RUNNING - # Terminate root orchestration(recursive set to True by default) - task_hub_client.terminate_orchestration(id, output="some reason for termination") - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.TERMINATED + # Terminate root orchestration(recursive set to True by default) + task_hub_client.terminate_orchestration(id, output="some reason for termination") + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.TERMINATED - # Verify that child orchestration is also terminated - task_hub_client.wait_for_orchestration_completion(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.TERMINATED + # Verify that child orchestration is also terminated + task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.TERMINATED - task_hub_client.purge_orchestration(id) - state = task_hub_client.get_orchestration_state(id) - assert state is None + task_hub_client.purge_orchestration(id) + state = task_hub_client.get_orchestration_state(id) + assert state is None def test_restart_with_same_instance_id(): @@ -387,21 +384,21 @@ def say_hello(ctx: task.ActivityContext, input: str): w.add_activity(say_hello) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(orchestrator) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED - assert state.serialized_output == json.dumps("Hello, World!") + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("Hello, World!") - # Restart the orchestration with the same instance ID - restarted_id = task_hub_client.restart_orchestration(id) - assert restarted_id == id + # Restart the orchestration with the same instance ID + restarted_id = task_hub_client.restart_orchestration(id) + assert restarted_id == id - state = task_hub_client.wait_for_orchestration_completion(restarted_id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED - assert state.serialized_output == json.dumps("Hello, World!") + state = task_hub_client.wait_for_orchestration_completion(restarted_id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("Hello, World!") def test_restart_with_new_instance_id(): @@ -418,20 +415,20 @@ def say_hello(ctx: task.ActivityContext, input: str): w.add_activity(say_hello) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(orchestrator) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED - # Restart the orchestration with a new instance ID - restarted_id = task_hub_client.restart_orchestration(id, restart_with_new_instance_id=True) - assert restarted_id != id + # Restart the orchestration with a new instance ID + restarted_id = task_hub_client.restart_orchestration(id, restart_with_new_instance_id=True) + assert restarted_id != id - state = task_hub_client.wait_for_orchestration_completion(restarted_id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED - assert state.serialized_output == json.dumps("Hello, World!") + state = task_hub_client.wait_for_orchestration_completion(restarted_id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("Hello, World!") def test_continue_as_new(): @@ -453,20 +450,20 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): w.add_orchestrator(orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(orchestrator, input=0) - task_hub_client.raise_orchestration_event(id, "my_event", data=1) - task_hub_client.raise_orchestration_event(id, "my_event", data=2) - task_hub_client.raise_orchestration_event(id, "my_event", data=3) - task_hub_client.raise_orchestration_event(id, "my_event", data=4) - task_hub_client.raise_orchestration_event(id, "my_event", data=5) + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator, input=0) + task_hub_client.raise_orchestration_event(id, "my_event", data=1) + task_hub_client.raise_orchestration_event(id, "my_event", data=2) + task_hub_client.raise_orchestration_event(id, "my_event", data=3) + task_hub_client.raise_orchestration_event(id, "my_event", data=4) + task_hub_client.raise_orchestration_event(id, "my_event", data=5) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED - assert state.serialized_output == json.dumps(all_results) - assert state.serialized_input == json.dumps(4) - assert all_results == [1, 2, 3, 4, 5] + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps(all_results) + assert state.serialized_input == json.dumps(4) + assert all_results == [1, 2, 3, 4, 5] def test_retry_policies(): @@ -510,18 +507,18 @@ def throw_activity_with_retry(ctx: task.ActivityContext, _): w.add_activity(throw_activity_with_retry) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(parent_orchestrator_with_retry) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.FAILED - assert state.failure_details is not None - assert state.failure_details.error_type == "TaskFailedError" - assert state.failure_details.message.startswith("Sub-orchestration task #1 failed:") - assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!") - assert state.failure_details.stack_trace is not None - assert throw_activity_counter == 9 - assert child_orch_counter == 3 + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(parent_orchestrator_with_retry) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.FAILED + assert state.failure_details is not None + assert state.failure_details.error_type == "TaskFailedError" + assert state.failure_details.message.startswith("Sub-orchestration task #1 failed:") + assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!") + assert state.failure_details.stack_trace is not None + assert throw_activity_counter == 9 + assert child_orch_counter == 3 def test_retry_timeout(): @@ -550,16 +547,16 @@ def throw_activity(ctx: task.ActivityContext, _): w.add_activity(throw_activity) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(mock_orchestrator) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) - assert state is not None - assert state.runtime_status == client.OrchestrationStatus.FAILED - assert state.failure_details is not None - assert state.failure_details.error_type == "TaskFailedError" - assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!") - assert state.failure_details.stack_trace is not None - assert throw_activity_counter == 4 + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(mock_orchestrator) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.FAILED + assert state.failure_details is not None + assert state.failure_details.error_type == "TaskFailedError" + assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!") + assert state.failure_details.stack_trace is not None + assert throw_activity_counter == 4 def test_custom_status(): @@ -571,9 +568,9 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): w.add_orchestrator(empty_orchestrator) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(empty_orchestrator) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.name == task.get_name(empty_orchestrator) @@ -602,9 +599,9 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): w.add_activity(noop) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(empty_orchestrator) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.name == task.get_name(empty_orchestrator) @@ -638,11 +635,11 @@ def orchestrator(ctx: task.OrchestrationContext, _): w.add_orchestrator(orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(orchestrator) - if raise_event: - task_hub_client.raise_orchestration_event(id, 'Approval') - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator) + if raise_event: + task_hub_client.raise_orchestration_event(id, 'Approval') + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -674,10 +671,10 @@ def orchestrator(ctx: task.OrchestrationContext, _): w.add_orchestrator(orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(orchestrator) - task_hub_client.raise_orchestration_event(id, winning_event) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator) + task_hub_client.raise_orchestration_event(id, winning_event) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -714,9 +711,9 @@ def orchestrator(ctx: task.OrchestrationContext, _): w.add_orchestrator(orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration(orchestrator) - state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration(orchestrator) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED diff --git a/tests/durabletask/test_orchestration_versioning_e2e.py b/tests/durabletask/test_orchestration_versioning_e2e.py index d9ba719d..60575bfd 100644 --- a/tests/durabletask/test_orchestration_versioning_e2e.py +++ b/tests/durabletask/test_orchestration_versioning_e2e.py @@ -36,11 +36,11 @@ def return_version(ctx: task.OrchestrationContext, _: None): )) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration( - return_version, version="2.5.0") - state = task_hub_client.wait_for_orchestration_completion( - id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration( + return_version, version="2.5.0") + state = task_hub_client.wait_for_orchestration_completion( + id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -64,11 +64,11 @@ def simple(ctx: task.OrchestrationContext, _: None): )) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration( - simple, version="2.0.0") - state = task_hub_client.wait_for_orchestration_completion( - id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration( + simple, version="2.0.0") + state = task_hub_client.wait_for_orchestration_completion( + id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.FAILED @@ -92,11 +92,11 @@ def simple(ctx: task.OrchestrationContext, _: None): )) w.start() - task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - id = task_hub_client.schedule_new_orchestration( - simple, version="1.1.0") - state = task_hub_client.wait_for_orchestration_completion( - id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as task_hub_client: + id = task_hub_client.schedule_new_orchestration( + simple, version="1.1.0") + state = task_hub_client.wait_for_orchestration_completion( + id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.FAILED diff --git a/tests/durabletask/test_work_item_filters_e2e.py b/tests/durabletask/test_work_item_filters_e2e.py index 5eec8fe4..97370c72 100644 --- a/tests/durabletask/test_work_item_filters_e2e.py +++ b/tests/durabletask/test_work_item_filters_e2e.py @@ -63,9 +63,9 @@ def test_auto_filters_processes_matching_work_items(): w.use_work_item_filters() w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(_orchestrator_with_activity, input=5) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(_orchestrator_with_activity, input=5) + state = c.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -95,9 +95,9 @@ def test_explicit_filters_matching(): w.use_work_item_filters(custom_filters) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(_orchestrator_with_activity, input=10) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(_orchestrator_with_activity, input=10) + state = c.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -116,9 +116,9 @@ def test_no_filters_processes_all(): # Intentionally do NOT call use_work_item_filters() w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(_orchestrator_with_activity, input=7) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(_orchestrator_with_activity, input=7) + state = c.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -138,9 +138,9 @@ def test_cleared_filters_processes_all(): w.use_work_item_filters(None) # then clear w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration(_orchestrator_with_activity, input=3) - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration(_orchestrator_with_activity, input=3) + state = c.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -168,12 +168,12 @@ def add(self, amount: int): )) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - entity_id = entities.EntityInstanceId("counter", "myKey") - c.signal_entity(entity_id, "add", input=10) - time.sleep(2) # wait for the signal to be processed + with client.TaskHubGrpcClient(host_address=HOST) as c: + entity_id = entities.EntityInstanceId("counter", "myKey") + c.signal_entity(entity_id, "add", input=10) + time.sleep(2) # wait for the signal to be processed - state = c.get_entity(entity_id, include_state=True) + state = c.get_entity(entity_id, include_state=True) assert invoked assert state is not None @@ -203,26 +203,25 @@ def test_non_matching_orchestrator_not_processed(): )) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) + with client.TaskHubGrpcClient(host_address=HOST) as c: + # Schedule the non-matching orchestration — should NOT be processed + non_match_id = c.schedule_new_orchestration( + _orchestrator_with_activity, input=1) - # Schedule the non-matching orchestration — should NOT be processed - non_match_id = c.schedule_new_orchestration( - _orchestrator_with_activity, input=1) + # Schedule the matching orchestration — should complete + match_id = c.schedule_new_orchestration(_other_orchestrator) + match_state = c.wait_for_orchestration_completion( + match_id, timeout=30) - # Schedule the matching orchestration — should complete - match_id = c.schedule_new_orchestration(_other_orchestrator) - match_state = c.wait_for_orchestration_completion( - match_id, timeout=30) + # The matching orchestration completes normally + assert match_state is not None + assert match_state.runtime_status == client.OrchestrationStatus.COMPLETED + assert match_state.serialized_output == '"other"' - # The matching orchestration completes normally - assert match_state is not None - assert match_state.runtime_status == client.OrchestrationStatus.COMPLETED - assert match_state.serialized_output == '"other"' - - # The non-matching orchestration should still be pending - non_match_state = c.get_orchestration_state(non_match_id) - assert non_match_state is not None - assert non_match_state.runtime_status == client.OrchestrationStatus.PENDING + # The non-matching orchestration should still be pending + non_match_state = c.get_orchestration_state(non_match_id) + assert non_match_state is not None + assert non_match_state.runtime_status == client.OrchestrationStatus.PENDING def test_non_matching_entity_not_processed(): @@ -249,12 +248,12 @@ def ping(self, _): )) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - c.signal_entity( - entities.EntityInstanceId("allowedentity", "k1"), "ping") - c.signal_entity( - entities.EntityInstanceId("blockedentity", "k1"), "ping") - time.sleep(3) # wait for processing + with client.TaskHubGrpcClient(host_address=HOST) as c: + c.signal_entity( + entities.EntityInstanceId("allowedentity", "k1"), "ping") + c.signal_entity( + entities.EntityInstanceId("blockedentity", "k1"), "ping") + time.sleep(3) # wait for processing assert matched_invoked assert not unmatched_invoked @@ -280,10 +279,10 @@ def test_strict_version_matching_orchestration_completes(): w.use_work_item_filters() # auto-generate with version w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - id = c.schedule_new_orchestration( - _simple_v2_orchestrator, input=10, version="2.0") - state = c.wait_for_orchestration_completion(id, timeout=30) + with client.TaskHubGrpcClient(host_address=HOST) as c: + id = c.schedule_new_orchestration( + _simple_v2_orchestrator, input=10, version="2.0") + state = c.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED @@ -301,24 +300,23 @@ def test_strict_version_incompatible_orchestration_stays_pending(): w.use_work_item_filters() w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - - # Schedule with version "1.0" — incompatible with the worker's "2.0" - bad_id = c.schedule_new_orchestration( - _simple_v2_orchestrator, input=5, version="1.0") + with client.TaskHubGrpcClient(host_address=HOST) as c: + # Schedule with version "1.0" — incompatible with the worker's "2.0" + bad_id = c.schedule_new_orchestration( + _simple_v2_orchestrator, input=5, version="1.0") - # Schedule a compatible one so we can confirm the worker is active - good_id = c.schedule_new_orchestration( - _simple_v2_orchestrator, input=5, version="2.0") - good_state = c.wait_for_orchestration_completion(good_id, timeout=30) + # Schedule a compatible one so we can confirm the worker is active + good_id = c.schedule_new_orchestration( + _simple_v2_orchestrator, input=5, version="2.0") + good_state = c.wait_for_orchestration_completion(good_id, timeout=30) - assert good_state is not None - assert good_state.runtime_status == client.OrchestrationStatus.COMPLETED + assert good_state is not None + assert good_state.runtime_status == client.OrchestrationStatus.COMPLETED - # The incompatible orchestration must remain pending (not failed) - bad_state = c.get_orchestration_state(bad_id) - assert bad_state is not None - assert bad_state.runtime_status == client.OrchestrationStatus.PENDING + # The incompatible orchestration must remain pending (not failed) + bad_state = c.get_orchestration_state(bad_id) + assert bad_state is not None + assert bad_state.runtime_status == client.OrchestrationStatus.PENDING def test_strict_version_no_version_orchestration_stays_pending(): @@ -332,23 +330,22 @@ def test_strict_version_no_version_orchestration_stays_pending(): w.use_work_item_filters() w.start() - c = client.TaskHubGrpcClient(host_address=HOST) + with client.TaskHubGrpcClient(host_address=HOST) as c: + # Schedule without any version + no_ver_id = c.schedule_new_orchestration( + _simple_v2_orchestrator, input=1) - # Schedule without any version - no_ver_id = c.schedule_new_orchestration( - _simple_v2_orchestrator, input=1) + # Schedule a compatible one to prove the worker is running + good_id = c.schedule_new_orchestration( + _simple_v2_orchestrator, input=1, version="2.0") + good_state = c.wait_for_orchestration_completion(good_id, timeout=30) + assert good_state is not None + assert good_state.runtime_status == client.OrchestrationStatus.COMPLETED - # Schedule a compatible one to prove the worker is running - good_id = c.schedule_new_orchestration( - _simple_v2_orchestrator, input=1, version="2.0") - good_state = c.wait_for_orchestration_completion(good_id, timeout=30) - assert good_state is not None - assert good_state.runtime_status == client.OrchestrationStatus.COMPLETED - - # The unversioned orchestration must remain pending - no_ver_state = c.get_orchestration_state(no_ver_id) - assert no_ver_state is not None - assert no_ver_state.runtime_status == client.OrchestrationStatus.PENDING + # The unversioned orchestration must remain pending + no_ver_state = c.get_orchestration_state(no_ver_id) + assert no_ver_state is not None + assert no_ver_state.runtime_status == client.OrchestrationStatus.PENDING def test_strict_version_explicit_filters_with_versions(): @@ -367,22 +364,21 @@ def test_strict_version_explicit_filters_with_versions(): w.use_work_item_filters(custom_filters) w.start() - c = client.TaskHubGrpcClient(host_address=HOST) - - # Version "2.0" does not match the filter's "3.0" - bad_id = c.schedule_new_orchestration( - _simple_v2_orchestrator, input=1, version="2.0") + with client.TaskHubGrpcClient(host_address=HOST) as c: + # Version "2.0" does not match the filter's "3.0" + bad_id = c.schedule_new_orchestration( + _simple_v2_orchestrator, input=1, version="2.0") - # Version "3.0" should match - good_id = c.schedule_new_orchestration( - _simple_v2_orchestrator, input=1, version="3.0") - good_state = c.wait_for_orchestration_completion(good_id, timeout=30) + # Version "3.0" should match + good_id = c.schedule_new_orchestration( + _simple_v2_orchestrator, input=1, version="3.0") + good_state = c.wait_for_orchestration_completion(good_id, timeout=30) - assert good_state is not None - assert good_state.runtime_status == client.OrchestrationStatus.COMPLETED - assert good_state.serialized_output == "2" + assert good_state is not None + assert good_state.runtime_status == client.OrchestrationStatus.COMPLETED + assert good_state.serialized_output == "2" - # Mismatched version must remain pending - bad_state = c.get_orchestration_state(bad_id) - assert bad_state is not None - assert bad_state.runtime_status == client.OrchestrationStatus.PENDING + # Mismatched version must remain pending + bad_state = c.get_orchestration_state(bad_id) + assert bad_state is not None + assert bad_state.runtime_status == client.OrchestrationStatus.PENDING