diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/worker.py b/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/worker.py
index b79f7eb8..b76d07d6 100644
--- a/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/worker.py
+++ b/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/worker.py
@@ -331,6 +331,7 @@ def __init__(
         self._current_channel: Optional[grpc.Channel] = None  # Store channel reference for cleanup
         self._channel_cleanup_threads: list[threading.Thread] = []  # Deferred channel close threads
         self._stream_ready = threading.Event()
+        self._runLoop: Optional[Thread] = None
         # Use provided concurrency options or create default ones
         self._concurrency_options = (
             concurrency_options if concurrency_options is not None else ConcurrencyOptions()
@@ -387,8 +388,16 @@ def run_loop():
         self._logger.info(f'Starting gRPC worker that connects to {self._host_address}')
-        self._runLoop = Thread(target=run_loop, name='WorkerRunLoop')
-        self._runLoop.start()
-        if not self._stream_ready.wait(timeout=10):
-            raise RuntimeError('Failed to establish work item stream connection within 10 seconds')
+        # Hold the thread in a local: stop() resets self._runLoop to None, so re-reading
+        # the attribute inside the wait loop could race with a concurrent stop().
+        run_loop_thread = Thread(target=run_loop, name='WorkerRunLoop')
+        self._runLoop = run_loop_thread
+        run_loop_thread.start()
+        while not self._stream_ready.wait(timeout=1):
+            if self._shutdown.is_set():
+                raise RuntimeError('Worker was stopped before the work item stream was established')
+            if not run_loop_thread.is_alive():
+                raise RuntimeError(
+                    'Worker run loop exited before the work item stream was established'
+                )
         self._is_running = True
 
     async def _keepalive_loop(self, stub):
@@ -801,7 +810,9 @@ def _deferred_close():
 
     def stop(self):
         """Stops the worker and waits for any pending work items to complete."""
-        if not self._is_running:
+        # Guards on _runLoop rather than _is_running so stop() can unblock a start()
+        # that is still waiting for the work item stream to be established.
+        if self._runLoop is None:
             return
 
         self._logger.info('Stopping gRPC worker...')
@@ -833,6 +844,7 @@ def stop(self):
         self._async_worker_manager.shutdown()
         self._logger.info('Worker shutdown completed')
         self._is_running = False
+        self._runLoop = None
 
     # TODO: This should be removed in the future as we do handle grpc errs
     def _handle_grpc_execution_error(self, rpc_error: grpc.RpcError, request_type: str):
diff --git a/ext/dapr-ext-workflow/tests/durabletask/test_worker_stop.py b/ext/dapr-ext-workflow/tests/durabletask/test_worker_stop.py
index 30789c23..0fef6ab8 100644
--- a/ext/dapr-ext-workflow/tests/durabletask/test_worker_stop.py
+++ b/ext/dapr-ext-workflow/tests/durabletask/test_worker_stop.py
@@ -1,6 +1,10 @@
+import asyncio
+import threading
+import time
 from unittest.mock import MagicMock, patch
 
 import grpc
+import pytest
 
 from dapr.ext.workflow._durabletask.worker import TaskHubGrpcWorker
 
@@ -146,3 +150,65 @@ def test_deferred_close_prunes_finished_threads():
     worker._channel_cleanup_threads[-1].join(timeout=2)
     # Only the still-alive (or just-finished ch2) thread remains; ch1's was pruned
     assert len(worker._channel_cleanup_threads) <= 1
+
+
+def test_stop_before_start_is_noop():
+    """stop() is safe to call before start() — _runLoop is None, no AttributeError."""
+    worker = TaskHubGrpcWorker()
+    with patch.object(worker._shutdown, 'set') as shutdown_set:
+        worker.stop()
+        shutdown_set.assert_not_called()
+
+
+def test_stop_is_idempotent():
+    """A second stop() returns early because _runLoop was cleared by the first."""
+    worker = _make_running_worker()
+    worker._current_channel = MagicMock()
+    worker.stop()
+    assert worker._runLoop is None
+    with patch.object(worker._shutdown, 'set') as shutdown_set:
+        worker.stop()
+        shutdown_set.assert_not_called()
+
+
+def test_start_raises_when_run_loop_exits_early():
+    """start() raises RuntimeError if the run loop thread exits before _stream_ready is set."""
+    worker = TaskHubGrpcWorker()
+
+    async def fast_exit():
+        return
+
+    with patch.object(worker, '_async_run_loop', side_effect=fast_exit):
+        with pytest.raises(RuntimeError, match='Worker run loop exited'):
+            worker.start()
+
+
+def test_start_raises_when_stopped_during_startup():
+    """stop() unblocks a start() that is waiting for _stream_ready; start() raises."""
+    worker = TaskHubGrpcWorker()
+
+    async def wait_for_shutdown():
+        # Block without setting _stream_ready so start() stays in its wait loop.
+        while not worker._shutdown.is_set():
+            await asyncio.sleep(0.05)
+
+    errors = []
+
+    def _start():
+        try:
+            worker.start()
+        except Exception as e:  # noqa: BLE001
+            errors.append(e)
+
+    with patch.object(worker, '_async_run_loop', side_effect=wait_for_shutdown):
+        t = threading.Thread(target=_start)
+        t.start()
+        # Let start() enter its wait loop (timeout=1 per iteration).
+        time.sleep(1.2)
+        worker.stop()
+        t.join(timeout=5)
+
+    assert not t.is_alive(), 'start() did not return after stop()'
+    assert len(errors) == 1, f'Expected exactly one error, got: {errors}'
+    assert isinstance(errors[0], RuntimeError)
+    assert 'Worker was stopped' in str(errors[0])