Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def __init__(
self._current_channel: Optional[grpc.Channel] = None # Store channel reference for cleanup
self._channel_cleanup_threads: list[threading.Thread] = [] # Deferred channel close threads
self._stream_ready = threading.Event()
self._runLoop: Optional[Thread] = None
Comment thread
JoshVanL marked this conversation as resolved.
# Use provided concurrency options or create default ones
self._concurrency_options = (
concurrency_options if concurrency_options is not None else ConcurrencyOptions()
Expand Down Expand Up @@ -387,8 +388,13 @@ def run_loop():
self._logger.info(f'Starting gRPC worker that connects to {self._host_address}')
self._runLoop = Thread(target=run_loop, name='WorkerRunLoop')
self._runLoop.start()
if not self._stream_ready.wait(timeout=10):
raise RuntimeError('Failed to establish work item stream connection within 10 seconds')
while not self._stream_ready.wait(timeout=1):
if self._shutdown.is_set():
raise RuntimeError('Worker was stopped before the work item stream was established')
if not self._runLoop.is_alive():
raise RuntimeError(
'Worker run loop exited before the work item stream was established'
)
Comment thread
JoshVanL marked this conversation as resolved.
Comment thread
JoshVanL marked this conversation as resolved.
self._is_running = True
Comment thread
JoshVanL marked this conversation as resolved.

async def _keepalive_loop(self, stub):
Expand Down Expand Up @@ -801,7 +807,9 @@ def _deferred_close():

def stop(self):
"""Stops the worker and waits for any pending work items to complete."""
if not self._is_running:
# Guards on _runLoop rather than _is_running so stop() can unblock a start()
# that is still waiting for the work item stream to be established.
if self._runLoop is None:
return
Comment thread
JoshVanL marked this conversation as resolved.

self._logger.info('Stopping gRPC worker...')
Expand Down Expand Up @@ -833,6 +841,7 @@ def stop(self):
self._async_worker_manager.shutdown()
self._logger.info('Worker shutdown completed')
self._is_running = False
self._runLoop = None

# TODO: This should be removed in the future as we do handle grpc errs
def _handle_grpc_execution_error(self, rpc_error: grpc.RpcError, request_type: str):
Expand Down
66 changes: 66 additions & 0 deletions ext/dapr-ext-workflow/tests/durabletask/test_worker_stop.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import asyncio
import threading
import time
from unittest.mock import MagicMock, patch

import grpc
import pytest
from dapr.ext.workflow._durabletask.worker import TaskHubGrpcWorker


Expand Down Expand Up @@ -146,3 +150,65 @@ def test_deferred_close_prunes_finished_threads():
worker._channel_cleanup_threads[-1].join(timeout=2)
# Only the still-alive (or just-finished ch2) thread remains; ch1's was pruned
assert len(worker._channel_cleanup_threads) <= 1


def test_stop_before_start_is_noop():
"""stop() is safe to call before start() — _runLoop is None, no AttributeError."""
worker = TaskHubGrpcWorker()
with patch.object(worker._shutdown, 'set') as shutdown_set:
worker.stop()
shutdown_set.assert_not_called()


def test_stop_is_idempotent():
"""A second stop() returns early because _runLoop was cleared by the first."""
worker = _make_running_worker()
worker._current_channel = MagicMock()
worker.stop()
assert worker._runLoop is None
with patch.object(worker._shutdown, 'set') as shutdown_set:
worker.stop()
shutdown_set.assert_not_called()


def test_start_raises_when_run_loop_exits_early():
"""start() raises RuntimeError if the run loop thread exits before _stream_ready is set."""
worker = TaskHubGrpcWorker()

async def fast_exit():
return

with patch.object(worker, '_async_run_loop', side_effect=fast_exit):
with pytest.raises(RuntimeError, match='Worker run loop exited'):
worker.start()


def test_start_raises_when_stopped_during_startup():
"""stop() unblocks a start() that is waiting for _stream_ready; start() raises."""
worker = TaskHubGrpcWorker()

async def wait_for_shutdown():
# Block without setting _stream_ready so start() stays in its wait loop.
while not worker._shutdown.is_set():
await asyncio.sleep(0.05)

errors = []

def _start():
try:
worker.start()
except Exception as e: # noqa: BLE001
errors.append(e)

with patch.object(worker, '_async_run_loop', side_effect=wait_for_shutdown):
t = threading.Thread(target=_start)
t.start()
# Let start() enter its wait loop (timeout=1 per iteration).
time.sleep(1.2)
worker.stop()
t.join(timeout=5)

assert not t.is_alive(), 'start() did not return after stop()'
assert len(errors) == 1, f'Expected exactly one error, got: {errors}'
assert isinstance(errors[0], RuntimeError)
assert 'Worker was stopped' in str(errors[0])
Loading