Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit ac301e6

Browse files
committed
fix: address pr comments
Signed-off-by: Casper Nielsen <casper@diagrid.io>
1 parent bdb1a44 commit ac301e6

5 files changed

Lines changed: 178 additions & 94 deletions

File tree

durabletask/internal/shared.py

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -58,17 +58,6 @@ def get_default_host_address() -> str:
5858
)
5959

6060

61-
def _merge_grpc_options(
62-
user_options: Optional[Sequence[tuple[str, Any]]],
63-
defaults: Sequence[tuple[str, Any]] = DEFAULT_GRPC_KEEPALIVE_OPTIONS,
64-
) -> list[tuple[str, Any]]:
65-
"""Merge user gRPC options with defaults. User options take precedence."""
66-
merged = dict(defaults)
67-
if user_options:
68-
merged.update(dict(user_options))
69-
return list(merged.items())
70-
71-
7261
def get_grpc_channel(
7362
host_address: Optional[str],
7463
secure_channel: bool = False,
@@ -100,7 +89,10 @@ def get_grpc_channel(
10089
host_address = host_address[len(protocol) :]
10190
break
10291

103-
merged_options = _merge_grpc_options(options)
92+
merged = dict(DEFAULT_GRPC_KEEPALIVE_OPTIONS)
93+
if options:
94+
merged.update(dict(options))
95+
merged_options = list(merged.items())
10496
if secure_channel:
10597
channel = grpc.secure_channel(
10698
host_address, grpc.ssl_channel_credentials(), options=merged_options

durabletask/worker.py

Lines changed: 50 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -307,6 +307,7 @@ def __init__(
307307
self._channel_options = channel_options
308308
self._stop_timeout = stop_timeout
309309
self._current_channel: Optional[grpc.Channel] = None # Store channel reference for cleanup
310+
self._channel_cleanup_threads: list[threading.Thread] = [] # Deferred channel close threads
310311
self._stream_ready = threading.Event()
311312
# Use provided concurrency options or create default ones
312313
self._concurrency_options = (
@@ -388,9 +389,12 @@ async def _async_run_loop(self):
388389

389390
def create_fresh_connection():
390391
nonlocal current_channel, current_stub, conn_retry_count
391-
# Don't call channel.close() — in-flight activity RPCs on the
392-
# old stub may still reference the channel from another thread.
393-
# The old channel is GC'd once all references are released.
392+
# Schedule deferred close of old channel to avoid orphaned TCP
393+
# connections. In-flight RPCs on the old stub may still reference
394+
# the channel from another thread, so we wait a grace period
395+
# before closing instead of closing immediately.
396+
if current_channel is not None:
397+
self._schedule_deferred_channel_close(current_channel)
394398
current_channel = None
395399
current_stub = None
396400
try:
@@ -415,12 +419,12 @@ def create_fresh_connection():
415419

416420
def invalidate_connection():
417421
nonlocal current_channel, current_stub, current_reader_thread
418-
# Null out references so the next iteration creates a fresh connection.
419-
# Do NOT call channel.close() here — in-flight activity RPCs
420-
# (CompleteActivityTask) may still be using the stub on another
421-
# thread. Closing the channel concurrently causes segfaults in the
422-
# gRPC C extension. The old channel is GC'd once all references
423-
# (including captured stub refs in activity threads) are released.
422+
# Schedule deferred close of old channel to avoid orphaned TCP
423+
# connections. In-flight RPCs (e.g. CompleteActivityTask) may still
424+
# be using the stub on another thread, so we defer the close by a
425+
# grace period instead of closing immediately.
426+
if current_channel is not None:
427+
self._schedule_deferred_channel_close(current_channel)
424428
current_channel = None
425429
self._current_channel = None
426430
current_stub = None
@@ -717,6 +721,38 @@ def stream_reader():
717721
except Exception as e:
718722
self._logger.warning(f"Error while waiting for worker task shutdown: {e}")
719723

724+
def _schedule_deferred_channel_close(
725+
self, old_channel: grpc.Channel, grace_timeout: float = 10.0
726+
):
727+
"""Schedule a deferred close of an old gRPC channel.
728+
729+
Waits up to *grace_timeout* seconds for in-flight RPCs to complete
730+
before closing the channel. This prevents orphaned TCP connections
731+
while still allowing in-flight work (e.g. ``CompleteActivityTask``
732+
calls on another thread) to finish gracefully.
733+
734+
During ``stop()``, ``_shutdown`` is already set so the wait returns
735+
immediately and the channel is closed at once.
736+
"""
737+
# Prune already-finished cleanup threads to avoid unbounded growth
738+
self._channel_cleanup_threads = [t for t in self._channel_cleanup_threads if t.is_alive()]
739+
740+
def _deferred_close():
741+
try:
742+
# Normal reconnect: wait grace period for RPCs to drain.
743+
# Shutdown: _shutdown is already set, returns immediately.
744+
self._shutdown.wait(timeout=grace_timeout)
745+
finally:
746+
try:
747+
old_channel.close()
748+
self._logger.debug("Deferred channel close completed")
749+
except Exception as e:
750+
self._logger.debug(f"Error during deferred channel close: {e}")
751+
752+
thread = threading.Thread(target=_deferred_close, daemon=True, name="ChannelCleanup")
753+
thread.start()
754+
self._channel_cleanup_threads.append(thread)
755+
720756
def stop(self):
721757
"""Stops the worker and waits for any pending work items to complete."""
722758
if not self._is_running:
@@ -743,6 +779,11 @@ def stop(self):
743779
else:
744780
self._logger.debug("Worker thread completed successfully")
745781

782+
# Wait for any deferred channel-cleanup threads to finish
783+
for t in self._channel_cleanup_threads:
784+
t.join(timeout=5)
785+
self._channel_cleanup_threads.clear()
786+
746787
self._async_worker_manager.shutdown()
747788
self._logger.info("Worker shutdown completed")
748789
self._is_running = False

tests/durabletask/test_client.py

Lines changed: 21 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,13 @@
22

33
from durabletask import client
44
from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
5-
from durabletask.internal.shared import get_default_host_address, get_grpc_channel
5+
from durabletask.internal.shared import (
6+
DEFAULT_GRPC_KEEPALIVE_OPTIONS,
7+
get_default_host_address,
8+
get_grpc_channel,
9+
)
10+
11+
EXPECTED_DEFAULT_OPTIONS = list(DEFAULT_GRPC_KEEPALIVE_OPTIONS)
612

713
HOST_ADDRESS = "localhost:50051"
814
METADATA = [("key1", "value1"), ("key2", "value2")]
@@ -14,7 +20,7 @@ def test_get_grpc_channel_insecure():
1420
get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS)
1521
args, kwargs = mock_channel.call_args
1622
assert args[0] == HOST_ADDRESS
17-
assert "options" in kwargs and kwargs["options"] is None
23+
assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS
1824

1925

2026
def test_get_grpc_channel_secure():
@@ -26,15 +32,15 @@ def test_get_grpc_channel_secure():
2632
args, kwargs = mock_channel.call_args
2733
assert args[0] == HOST_ADDRESS
2834
assert args[1] == mock_credentials.return_value
29-
assert "options" in kwargs and kwargs["options"] is None
35+
assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS
3036

3137

3238
def test_get_grpc_channel_default_host_address():
3339
with patch("grpc.insecure_channel") as mock_channel:
3440
get_grpc_channel(None, False, interceptors=INTERCEPTORS)
3541
args, kwargs = mock_channel.call_args
3642
assert args[0] == get_default_host_address()
37-
assert "options" in kwargs and kwargs["options"] is None
43+
assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS
3844

3945

4046
def test_get_grpc_channel_with_metadata():
@@ -45,7 +51,7 @@ def test_get_grpc_channel_with_metadata():
4551
get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS)
4652
args, kwargs = mock_channel.call_args
4753
assert args[0] == HOST_ADDRESS
48-
assert "options" in kwargs and kwargs["options"] is None
54+
assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS
4955
mock_intercept_channel.assert_called_once()
5056

5157
# Capture and check the arguments passed to intercept_channel()
@@ -66,61 +72,61 @@ def test_grpc_channel_with_host_name_protocol_stripping():
6672
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
6773
args, kwargs = mock_insecure_channel.call_args
6874
assert args[0] == host_name
69-
assert "options" in kwargs and kwargs["options"] is None
75+
assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS
7076

7177
prefix = "http://"
7278
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
7379
args, kwargs = mock_insecure_channel.call_args
7480
assert args[0] == host_name
75-
assert "options" in kwargs and kwargs["options"] is None
81+
assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS
7682

7783
prefix = "HTTP://"
7884
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
7985
args, kwargs = mock_insecure_channel.call_args
8086
assert args[0] == host_name
81-
assert "options" in kwargs and kwargs["options"] is None
87+
assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS
8288

8389
prefix = "GRPC://"
8490
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
8591
args, kwargs = mock_insecure_channel.call_args
8692
assert args[0] == host_name
87-
assert "options" in kwargs and kwargs["options"] is None
93+
assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS
8894

8995
prefix = ""
9096
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
9197
args, kwargs = mock_insecure_channel.call_args
9298
assert args[0] == host_name
93-
assert "options" in kwargs and kwargs["options"] is None
99+
assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS
94100

95101
prefix = "grpcs://"
96102
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
97103
args, kwargs = mock_secure_channel.call_args
98104
assert args[0] == host_name
99-
assert "options" in kwargs and kwargs["options"] is None
105+
assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS
100106

101107
prefix = "https://"
102108
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
103109
args, kwargs = mock_secure_channel.call_args
104110
assert args[0] == host_name
105-
assert "options" in kwargs and kwargs["options"] is None
111+
assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS
106112

107113
prefix = "HTTPS://"
108114
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
109115
args, kwargs = mock_secure_channel.call_args
110116
assert args[0] == host_name
111-
assert "options" in kwargs and kwargs["options"] is None
117+
assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS
112118

113119
prefix = "GRPCS://"
114120
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
115121
args, kwargs = mock_secure_channel.call_args
116122
assert args[0] == host_name
117-
assert "options" in kwargs and kwargs["options"] is None
123+
assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS
118124

119125
prefix = ""
120126
get_grpc_channel(prefix + host_name, True, interceptors=INTERCEPTORS)
121127
args, kwargs = mock_secure_channel.call_args
122128
assert args[0] == host_name
123-
assert "options" in kwargs and kwargs["options"] is None
129+
assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS
124130

125131

126132
def test_sync_channel_passes_base_options_and_max_lengths():

tests/durabletask/test_shared.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

0 commit comments

Comments
 (0)