diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/task.py b/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/task.py index d1a211dcb..e52831a07 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/task.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/task.py @@ -360,12 +360,15 @@ def pending_tasks(self) -> int: def on_child_completed(self, task: Task[T]): if self.is_complete: - raise ValueError('The task has already completed.') + # Already completed (e.g. a previous child failed), ignore late arrivals + return self._completed_tasks += 1 if task.is_failed and self._exception is None: self._exception = task.get_exception() self._is_complete = True - if self._completed_tasks == len(self._tasks): + if self._parent is not None: + self._parent.on_child_completed(self) + elif self._completed_tasks == len(self._tasks): # The order of the result MUST match the order of the tasks provided to the constructor. self._result = [task.get_result() for task in self._tasks] self._is_complete = True diff --git a/ext/dapr-ext-workflow/tests/durabletask/test_orchestration_executor.py b/ext/dapr-ext-workflow/tests/durabletask/test_orchestration_executor.py index 0bc5c9981..37b44d150 100644 --- a/ext/dapr-ext-workflow/tests/durabletask/test_orchestration_executor.py +++ b/ext/dapr-ext-workflow/tests/durabletask/test_orchestration_executor.py @@ -1212,6 +1212,107 @@ def orchestrator(ctx: task.OrchestrationContext, _): assert str(ex) in complete_action.failureDetails.errorMessage +def test_when_all_failure_after_success_bubbles_to_orchestrator(): + """Tests that when_all correctly surfaces a failure to the orchestrator + even when succeeding tasks complete before the failing one. + + This is a regression test: previously the exception from the failed task + was swallowed when another task had already completed successfully. + """ + + def dummy_activity(ctx, _): + pass + + def orchestrator(ctx: task.OrchestrationContext, _): + t1 = ctx.call_activity(dummy_activity, input='will-succeed') + t2 = ctx.call_activity(dummy_activity, input='will-fail') + try: + yield task.when_all([t1, t2]) + except task.TaskFailedError: + return 'caught' + return 'not caught' + + registry = worker._Registry() + orchestrator_name = registry.add_orchestrator(orchestrator) + activity_name = registry.add_activity(dummy_activity) + + old_events = [ + helpers.new_workflow_started_event(), + helpers.new_execution_started_event( + orchestrator_name, TEST_INSTANCE_ID, encoded_input=None + ), + helpers.new_task_scheduled_event(1, activity_name), + helpers.new_task_scheduled_event(2, activity_name), + ] + + # t1 succeeds FIRST, then t2 fails — this is the order that triggered the bug + ex = Exception('activity error') + new_events = [ + helpers.new_task_completed_event(1, encoded_output=json.dumps('ok')), + helpers.new_task_failed_event(2, ex), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + + complete_action = get_and_validate_single_complete_workflow_action(actions) + # The orchestrator should have caught the exception and returned 'caught' + assert complete_action.workflowStatus == pb.ORCHESTRATION_STATUS_COMPLETED + assert complete_action.result.value == json.dumps('caught') + + +def test_when_all_success_after_failure_does_not_crash(): + """Tests that task completions arriving after when_all already failed + do not crash the orchestration. + + This is a regression test: previously a ValueError was raised when + a successful task completed after the WhenAllTask was already marked + complete due to a prior child failure. + """ + + def dummy_activity(ctx, _): + pass + + def orchestrator(ctx: task.OrchestrationContext, _): + t1 = ctx.call_activity(dummy_activity, input='will-fail') + t2 = ctx.call_activity(dummy_activity, input='will-succeed') + try: + yield task.when_all([t1, t2]) + except task.TaskFailedError: + return 'caught' + return 'not caught' + + registry = worker._Registry() + orchestrator_name = registry.add_orchestrator(orchestrator) + activity_name = registry.add_activity(dummy_activity) + + old_events = [ + helpers.new_workflow_started_event(), + helpers.new_execution_started_event( + orchestrator_name, TEST_INSTANCE_ID, encoded_input=None + ), + helpers.new_task_scheduled_event(1, activity_name), + helpers.new_task_scheduled_event(2, activity_name), + ] + + # t1 fails FIRST, then t2 succeeds — this would previously raise ValueError + ex = Exception('activity error') + new_events = [ + helpers.new_task_failed_event(1, ex), + helpers.new_task_completed_event(2, encoded_output=json.dumps('ok')), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + + complete_action = get_and_validate_single_complete_workflow_action(actions) + # The orchestrator should have caught the exception and returned 'caught' + assert complete_action.workflowStatus == pb.ORCHESTRATION_STATUS_COMPLETED + assert complete_action.result.value == json.dumps('caught') + + def test_when_any(): """Tests that a when_any pattern works correctly""" diff --git a/ext/dapr-ext-workflow/tests/durabletask/test_task.py b/ext/dapr-ext-workflow/tests/durabletask/test_task.py index 02c2b6458..a381ab3d1 100644 --- a/ext/dapr-ext-workflow/tests/durabletask/test_task.py +++ b/ext/dapr-ext-workflow/tests/durabletask/test_task.py @@ -11,9 +11,16 @@ """Unit tests for durabletask.task primitives.""" +import dapr.ext.workflow._durabletask.internal.helpers as pbh +import pytest from dapr.ext.workflow._durabletask import task +def _make_failure_details(message: str = 'test error', error_type: str = 'TestError'): + """Create a TaskFailureDetails proto for testing.""" + return pbh.new_failure_details(Exception(message)) + + def test_when_all_empty_returns_successfully(): """task.when_all([]) should complete immediately and return an empty list.""" when_all_task = task.when_all([]) @@ -121,3 +128,70 @@ def test_when_any_happy_path_returns_winner_task_and_completes_on_first(): a.complete('A') assert any_task.get_result() is b + + +def test_when_all_failure_after_success_still_reports_failure(): + """When a child fails after another child has already succeeded, + the WhenAllTask must still complete with the failure — not swallow it.""" + c1 = task.CompletableTask() + c2 = task.CompletableTask() + + all_task = task.when_all([c1, c2]) + + # c1 succeeds first + c1.complete('one') + assert not all_task.is_complete + + # c2 fails second — this is the order that used to swallow the exception + c2.fail('activity failed', _make_failure_details('activity failed')) + + assert all_task.is_complete + assert all_task.is_failed + with pytest.raises(task.TaskFailedError): + all_task.get_result() + + +def test_when_all_failure_before_success_still_reports_failure(): + """When a child fails before the other children succeed, + the WhenAllTask must complete with the failure immediately.""" + c1 = task.CompletableTask() + c2 = task.CompletableTask() + + all_task = task.when_all([c1, c2]) + + # c1 fails first + c1.fail('activity failed', _make_failure_details('activity failed')) + + assert all_task.is_complete + assert all_task.is_failed + with pytest.raises(task.TaskFailedError): + all_task.get_result() + + # c2 succeeds after — must not raise ValueError + c2.complete('two') + + # WhenAllTask should still be in the same failed state + assert all_task.is_complete + assert all_task.is_failed + with pytest.raises(task.TaskFailedError): + all_task.get_result() + + +def test_when_all_failure_propagates_to_parent(): + """When a WhenAllTask fails due to a child failure, + it should notify its parent composite task.""" + c1 = task.CompletableTask() + c2 = task.CompletableTask() + + all_task = task.when_all([c1, c2]) + any_task = task.when_any([all_task]) + + assert not any_task.is_complete + + c1.fail('activity failed', _make_failure_details('activity failed')) + + assert all_task.is_complete + assert all_task.is_failed + # The parent WhenAnyTask should also have completed + assert any_task.is_complete + assert any_task.get_result() is all_task