Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,15 @@ def pending_tasks(self) -> int:

def on_child_completed(self, task: Task[T]):
if self.is_complete:
raise ValueError('The task has already completed.')
# Already completed (e.g. a previous child failed), ignore late arrivals
return
self._completed_tasks += 1
if task.is_failed and self._exception is None:
self._exception = task.get_exception()
self._is_complete = True
if self._completed_tasks == len(self._tasks):
if self._parent is not None:
self._parent.on_child_completed(self)
elif self._completed_tasks == len(self._tasks):
# The order of the result MUST match the order of the tasks provided to the constructor.
self._result = [task.get_result() for task in self._tasks]
self._is_complete = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,107 @@ def orchestrator(ctx: task.OrchestrationContext, _):
assert str(ex) in complete_action.failureDetails.errorMessage


def test_when_all_failure_after_success_bubbles_to_orchestrator():
"""Tests that when_all correctly surfaces a failure to the orchestrator
even when succeeding tasks complete before the failing one.

This is a regression test: previously the exception from the failed task
was swallowed when another task had already completed successfully.
"""

def dummy_activity(ctx, _):
pass

def orchestrator(ctx: task.OrchestrationContext, _):
t1 = ctx.call_activity(dummy_activity, input='will-succeed')
t2 = ctx.call_activity(dummy_activity, input='will-fail')
try:
yield task.when_all([t1, t2])
except task.TaskFailedError:
return 'caught'
return 'not caught'

registry = worker._Registry()
orchestrator_name = registry.add_orchestrator(orchestrator)
activity_name = registry.add_activity(dummy_activity)

old_events = [
helpers.new_workflow_started_event(),
helpers.new_execution_started_event(
orchestrator_name, TEST_INSTANCE_ID, encoded_input=None
),
helpers.new_task_scheduled_event(1, activity_name),
helpers.new_task_scheduled_event(2, activity_name),
]

# t1 succeeds FIRST, then t2 fails — this is the order that triggered the bug
ex = Exception('activity error')
new_events = [
helpers.new_task_completed_event(1, encoded_output=json.dumps('ok')),
helpers.new_task_failed_event(2, ex),
]

executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
result = executor.execute(TEST_INSTANCE_ID, old_events, new_events)
actions = result.actions

complete_action = get_and_validate_single_complete_workflow_action(actions)
# The orchestrator should have caught the exception and returned 'caught'
assert complete_action.workflowStatus == pb.ORCHESTRATION_STATUS_COMPLETED
assert complete_action.result.value == json.dumps('caught')


def test_when_all_success_after_failure_does_not_crash():
"""Tests that task completions arriving after when_all already failed
do not crash the orchestration.

This is a regression test: previously a ValueError was raised when
a successful task completed after the WhenAllTask was already marked
complete due to a prior child failure.
"""

def dummy_activity(ctx, _):
pass

def orchestrator(ctx: task.OrchestrationContext, _):
t1 = ctx.call_activity(dummy_activity, input='will-fail')
t2 = ctx.call_activity(dummy_activity, input='will-succeed')
try:
yield task.when_all([t1, t2])
except task.TaskFailedError:
return 'caught'
return 'not caught'

registry = worker._Registry()
orchestrator_name = registry.add_orchestrator(orchestrator)
activity_name = registry.add_activity(dummy_activity)

old_events = [
helpers.new_workflow_started_event(),
helpers.new_execution_started_event(
orchestrator_name, TEST_INSTANCE_ID, encoded_input=None
),
helpers.new_task_scheduled_event(1, activity_name),
helpers.new_task_scheduled_event(2, activity_name),
]

# t1 fails FIRST, then t2 succeeds — this would previously raise ValueError
ex = Exception('activity error')
new_events = [
helpers.new_task_failed_event(1, ex),
helpers.new_task_completed_event(2, encoded_output=json.dumps('ok')),
]

executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
result = executor.execute(TEST_INSTANCE_ID, old_events, new_events)
actions = result.actions

complete_action = get_and_validate_single_complete_workflow_action(actions)
# The orchestrator should have caught the exception and returned 'caught'
assert complete_action.workflowStatus == pb.ORCHESTRATION_STATUS_COMPLETED
assert complete_action.result.value == json.dumps('caught')


def test_when_any():
"""Tests that a when_any pattern works correctly"""

Expand Down
74 changes: 74 additions & 0 deletions ext/dapr-ext-workflow/tests/durabletask/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,16 @@

"""Unit tests for durabletask.task primitives."""

import dapr.ext.workflow._durabletask.internal.helpers as pbh
import pytest
from dapr.ext.workflow._durabletask import task


def _make_failure_details(message: str = 'test error', error_type: str = 'TestError'):
"""Create a TaskFailureDetails proto for testing."""
return pbh.new_failure_details(Exception(message))


def test_when_all_empty_returns_successfully():
"""task.when_all([]) should complete immediately and return an empty list."""
when_all_task = task.when_all([])
Expand Down Expand Up @@ -121,3 +128,70 @@ def test_when_any_happy_path_returns_winner_task_and_completes_on_first():
a.complete('A')

assert any_task.get_result() is b


def test_when_all_failure_after_success_still_reports_failure():
"""When a child fails after another child has already succeeded,
the WhenAllTask must still complete with the failure — not swallow it."""
c1 = task.CompletableTask()
c2 = task.CompletableTask()

all_task = task.when_all([c1, c2])

# c1 succeeds first
c1.complete('one')
assert not all_task.is_complete

# c2 fails second — this is the order that used to swallow the exception
c2.fail('activity failed', _make_failure_details('activity failed'))

assert all_task.is_complete
assert all_task.is_failed
with pytest.raises(task.TaskFailedError):
all_task.get_result()


def test_when_all_failure_before_success_still_reports_failure():
"""When a child fails before the other children succeed,
the WhenAllTask must complete with the failure immediately."""
c1 = task.CompletableTask()
c2 = task.CompletableTask()

all_task = task.when_all([c1, c2])

# c1 fails first
c1.fail('activity failed', _make_failure_details('activity failed'))

assert all_task.is_complete
assert all_task.is_failed
with pytest.raises(task.TaskFailedError):
all_task.get_result()

# c2 succeeds after — must not raise ValueError
c2.complete('two')

# WhenAllTask should still be in the same failed state
assert all_task.is_complete
assert all_task.is_failed
with pytest.raises(task.TaskFailedError):
all_task.get_result()


def test_when_all_failure_propagates_to_parent():
"""When a WhenAllTask fails due to a child failure,
it should notify its parent composite task."""
c1 = task.CompletableTask()
c2 = task.CompletableTask()

all_task = task.when_all([c1, c2])
any_task = task.when_any([all_task])

assert not any_task.is_complete

c1.fail('activity failed', _make_failure_details('activity failed'))

assert all_task.is_complete
assert all_task.is_failed
# The parent WhenAnyTask should also have completed
assert any_task.is_complete
assert any_task.get_result() is all_task
Loading