Skip to content

Commit 3268b45

Browse files
Bernd Verst and Copilot committed
Fire-and-forget channel recreate; hoist client state; tighten async interceptor exception handling
Addresses three follow-up review comments on the resiliency interceptor refactor:

[1/3] Channel recreate now runs fire-and-forget (daemon thread for sync, asyncio.create_task for async). The original RPC error propagates to the caller without being delayed by DNS, TLS handshake, or contention on _recreate_lock. A client-side single-flight guard avoids spawning duplicate work when many failures land in a burst; the existing cooldown still prevents thrash. close() waits for any in-flight recreate to finish so the teardown path stays deterministic. A _recreate_done_event (test seam) lets tests synchronise on completion without polling.

[2/3] Hoisted _closing, _recreate_lock, _last_recreate_time, _retired_channels / _retired_channel_close_tasks above ClientResiliencyInterceptor construction in both __init__ methods so the bound recreate callback is safe to invoke at any time during construction.

[3/3] AsyncClientResiliencyInterceptor now uses 'except Exception' (so asyncio.CancelledError, KeyboardInterrupt and SystemExit propagate unchanged) and mirrors the sync interceptor's policy by resetting the failure counter on non-AioRpcError exceptions. _record_outcome is now synchronous on both interceptors because the on_recreate callback no longer awaits the recreate.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 053b579 commit 3268b45

3 files changed

Lines changed: 176 additions & 27 deletions

File tree

durabletask/client.py

Lines changed: 121 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -198,12 +198,28 @@ def __init__(self, *,
198198
if resiliency_options is not None
199199
else GrpcClientResiliencyOptions()
200200
)
201+
# Resiliency state must be initialised BEFORE the interceptor is
202+
# constructed because the interceptor receives a bound reference to
203+
# ``self._schedule_recreate``; any failure handled during construction
204+
# of the underlying channel could otherwise observe a half-built
205+
# client.
206+
self._closing = False
207+
self._last_recreate_time = 0.0
208+
self._recreate_lock = threading.Lock()
209+
self._retired_channels: dict[grpc.Channel, threading.Timer] = {}
210+
self._recreate_thread_lock = threading.Lock()
211+
self._recreate_thread: Optional[threading.Thread] = None
212+
# Test seam: set after each fire-and-forget recreate attempt finishes
213+
# (whether it actually recreated the channel or short-circuited on
214+
# close / cooldown). Lets tests synchronise without polling and lets
215+
# ``close()`` wait deterministically for an in-flight recreate.
216+
self._recreate_done_event = threading.Event()
201217
self._client_failure_tracker = FailureTracker(
202218
self._resiliency_options.channel_recreate_failure_threshold
203219
)
204220
self._resiliency_interceptor = ClientResiliencyInterceptor(
205221
self._client_failure_tracker,
206-
self._maybe_recreate_channel,
222+
self._schedule_recreate,
207223
)
208224
resolved_interceptors = (
209225
prepare_sync_interceptors(metadata, interceptors) if channel is None else interceptors
@@ -230,10 +246,6 @@ def __init__(self, *,
230246
# can prepend the interceptor themselves via grpc.intercept_channel.
231247
self._channel = channel
232248
self._stub = stubs.TaskHubSidecarServiceStub(channel)
233-
self._closing = False
234-
self._last_recreate_time = 0.0
235-
self._recreate_lock = threading.Lock()
236-
self._retired_channels: dict[grpc.Channel, threading.Timer] = {}
237249
self._logger = shared.get_logger("client", log_handler, log_formatter)
238250
self.default_version = default_version
239251
self._payload_store = payload_store
@@ -252,6 +264,48 @@ def _compose_interceptors(
252264
composed.extend(user_interceptors)
253265
return composed
254266

267+
def _schedule_recreate(self) -> None:
268+
"""Spawn a daemon thread that recreates the channel fire-and-forget.
269+
270+
Called from the resiliency interceptor on the caller's thread when a
271+
unary RPC fails with a transport error. The interceptor returns to its
272+
caller as soon as this method returns, so the failing RPC's original
273+
error propagates without being delayed by DNS, TLS handshake, or
274+
contention on ``_recreate_lock``.
275+
276+
Single-flight under ``_recreate_thread_lock``: if a recreate thread is
277+
still alive, the new trigger is dropped. The in-flight recreate will
278+
pick up the latest channel state on completion; the cooldown inside
279+
``_maybe_recreate_channel`` further prevents thrash. ``thread.start()``
280+
is called under the lock so a follow-up caller's ``is_alive()`` check
281+
observes the running state rather than racing the start.
282+
"""
283+
try:
284+
if self._closing:
285+
return
286+
with self._recreate_thread_lock:
287+
existing = self._recreate_thread
288+
if existing is not None and existing.is_alive():
289+
return
290+
self._recreate_done_event.clear()
291+
thread = threading.Thread(
292+
target=self._run_recreate,
293+
name="durabletask-client-recreate",
294+
daemon=True,
295+
)
296+
self._recreate_thread = thread
297+
thread.start()
298+
except Exception:
299+
self._logger.exception("Failed to schedule channel recreate")
300+
301+
def _run_recreate(self) -> None:
302+
try:
303+
self._maybe_recreate_channel()
304+
except Exception:
305+
self._logger.exception("Channel recreate failed")
306+
finally:
307+
self._recreate_done_event.set()
308+
255309
def _maybe_recreate_channel(self) -> None:
256310
if not self._owns_channel or self._closing:
257311
return
@@ -296,8 +350,14 @@ def close(self) -> None:
296350
it.
297351
"""
298352
if self._owns_channel:
353+
# Signal early so any in-flight recreate thread bails out of
354+
# ``_maybe_recreate_channel`` before we tear the channel down.
355+
self._closing = True
356+
with self._recreate_thread_lock:
357+
recreate_thread = self._recreate_thread
358+
if recreate_thread is not None and recreate_thread.is_alive():
359+
recreate_thread.join(timeout=5.0)
299360
with self._recreate_lock:
300-
self._closing = True
301361
retired_channels = list(self._retired_channels.items())
302362
self._retired_channels.clear()
303363
current_channel = self._channel
@@ -628,12 +688,28 @@ def __init__(self, *,
628688
if resiliency_options is not None
629689
else GrpcClientResiliencyOptions()
630690
)
691+
# Resiliency state must be initialised BEFORE the interceptor is
692+
# constructed because the interceptor receives a bound reference to
693+
# ``self._schedule_recreate``; any failure handled during construction
694+
# of the underlying channel could otherwise observe a half-built
695+
# client.
696+
self._closing = False
697+
self._recreate_lock = asyncio.Lock()
698+
self._last_recreate_time = 0.0
699+
self._retired_channels: list[grpc.aio.Channel] = []
700+
self._retired_channel_close_tasks: set[asyncio.Task[None]] = set()
701+
self._recreate_task: Optional[asyncio.Task[None]] = None
702+
# Test seam: set after each fire-and-forget recreate attempt finishes
703+
# (whether it actually recreated the channel or short-circuited on
704+
# close / cooldown). Lets tests synchronise without polling and lets
705+
# ``close()`` await an in-flight recreate deterministically.
706+
self._recreate_done_event = asyncio.Event()
631707
self._client_failure_tracker = FailureTracker(
632708
self._resiliency_options.channel_recreate_failure_threshold
633709
)
634710
self._resiliency_interceptor = AsyncClientResiliencyInterceptor(
635711
self._client_failure_tracker,
636-
self._maybe_recreate_channel,
712+
self._schedule_recreate,
637713
)
638714
resolved_interceptors = (
639715
prepare_async_interceptors(metadata, interceptors) if channel is None else interceptors
@@ -660,11 +736,6 @@ def __init__(self, *,
660736
# resiliency should let us create the channel.
661737
self._channel = channel
662738
self._stub = stubs.TaskHubSidecarServiceStub(channel)
663-
self._closing = False
664-
self._recreate_lock = asyncio.Lock()
665-
self._last_recreate_time = 0.0
666-
self._retired_channels: list[grpc.aio.Channel] = []
667-
self._retired_channel_close_tasks: set[asyncio.Task[None]] = set()
668739
self._logger = shared.get_logger("async_client", log_handler, log_formatter)
669740
self.default_version = default_version
670741
self._payload_store = payload_store
@@ -688,7 +759,17 @@ async def close(self) -> None:
688759
it.
689760
"""
690761
if self._owns_channel:
762+
# Signal early so any in-flight recreate task bails out of
763+
# ``_maybe_recreate_channel`` before we tear the channel down.
691764
self._closing = True
765+
recreate_task = self._recreate_task
766+
if recreate_task is not None and not recreate_task.done():
767+
try:
768+
await recreate_task
769+
except Exception:
770+
# Already logged by ``_run_recreate``; suppressing here
771+
# ensures close() always tears down cleanly.
772+
pass
692773
async with self._recreate_lock:
693774
retired_channels = list(self._retired_channels)
694775
self._retired_channels.clear()
@@ -708,6 +789,34 @@ async def __aenter__(self):
708789
async def __aexit__(self, exc_type, exc_val, exc_tb):
709790
await self.close()
710791

792+
def _schedule_recreate(self) -> None:
793+
"""Schedule a fire-and-forget channel recreate on the event loop.
794+
795+
Called from the resiliency interceptor when a unary RPC fails with a
796+
transport error. Single-flight: if ``_recreate_task`` is still
797+
pending, the trigger is dropped — the in-flight recreate will pick up
798+
the latest channel state on completion. asyncio is single-threaded
799+
so ``done()`` is race-free; no extra lock is required.
800+
"""
801+
try:
802+
if self._closing:
803+
return
804+
existing = self._recreate_task
805+
if existing is not None and not existing.done():
806+
return
807+
self._recreate_done_event.clear()
808+
self._recreate_task = asyncio.create_task(self._run_recreate())
809+
except Exception:
810+
self._logger.exception("Failed to schedule channel recreate")
811+
812+
async def _run_recreate(self) -> None:
813+
try:
814+
await self._maybe_recreate_channel()
815+
except Exception:
816+
self._logger.exception("Channel recreate failed")
817+
finally:
818+
self._recreate_done_event.set()
819+
711820
async def _maybe_recreate_channel(self) -> None:
712821
if not self._owns_channel or self._closing:
713822
return

durabletask/internal/grpc_resiliency.py

Lines changed: 31 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,10 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33

4-
import inspect
54
import random
65
import threading
76
from dataclasses import dataclass, field
8-
from typing import Awaitable, Callable, Optional, Union
7+
from typing import Callable, Optional
98

109
import grpc
1110
import grpc.aio
@@ -92,6 +91,13 @@ class ClientResiliencyInterceptor(grpc.UnaryUnaryClientInterceptor):
9291
need to wrap every stub call: any unary RPC sent through the intercepted
9392
channel automatically participates in failure tracking and channel
9493
recreation, including future RPCs that are added to the service stub.
94+
95+
The ``on_recreate`` callback is invoked **fire-and-forget** by the owning
96+
client (it schedules the actual recreate on a daemon thread / asyncio
97+
task), so this interceptor never blocks the calling thread or event loop
98+
on DNS, TLS handshake, or any other channel construction work. The
99+
triggering call's original error propagates to the caller without added
100+
latency; subsequent calls benefit from the recreated channel.
95101
"""
96102

97103
def __init__(
@@ -121,37 +127,48 @@ def _record_outcome(self, method: str, error: Optional[BaseException]) -> None:
121127

122128

123129
class AsyncClientResiliencyInterceptor(grpc.aio.UnaryUnaryClientInterceptor):
124-
"""Async counterpart of :class:`ClientResiliencyInterceptor`."""
130+
"""Async counterpart of :class:`ClientResiliencyInterceptor`.
131+
132+
The ``on_recreate`` callback is a *synchronous* function (it schedules an
133+
``asyncio.Task`` for the actual recreate); this keeps the original RPC
134+
error free of any extra latency from DNS / TLS handshake / lock waits and
135+
guarantees the caller sees its original ``AioRpcError`` rather than an
136+
exception that happened during recreate scheduling.
137+
138+
Non-``AioRpcError`` exceptions reset the failure counter (matching the
139+
sync interceptor's policy, where ``.exception()`` returning a non-RpcError
140+
falls through to ``record_success``). ``CancelledError`` and other
141+
non-``Exception`` ``BaseException`` subclasses propagate without bookkeeping,
142+
which is the correct asyncio convention.
143+
"""
125144

126145
def __init__(
127146
self,
128147
failure_tracker: FailureTracker,
129-
on_recreate: Callable[[], Union[None, Awaitable[object]]],
148+
on_recreate: Callable[[], None],
130149
):
131150
self._failure_tracker = failure_tracker
132151
self._on_recreate = on_recreate
133152

134153
async def intercept_unary_unary(self, continuation, client_call_details, request):
135154
try:
136155
response = await continuation(client_call_details, request)
137-
except grpc.aio.AioRpcError as rpc_error:
138-
await self._record_outcome(client_call_details.method, rpc_error)
156+
except Exception as exc:
157+
if isinstance(exc, grpc.aio.AioRpcError):
158+
self._record_outcome(client_call_details.method, exc)
159+
else:
160+
self._failure_tracker.record_success()
139161
raise
140-
await self._record_outcome(client_call_details.method, None)
162+
self._record_outcome(client_call_details.method, None)
141163
return response
142164

143-
async def _record_outcome(self, method: str, error: Optional[BaseException]) -> None:
165+
def _record_outcome(self, method: str, error: Optional[BaseException]) -> None:
144166
if error is None:
145167
self._failure_tracker.record_success()
146168
return
147169
status_code = getattr(error, "code", lambda: None)()
148170
if status_code is not None and is_client_transport_failure(method, status_code):
149171
if self._failure_tracker.record_failure():
150-
result = self._on_recreate()
151-
if inspect.isawaitable(result):
152-
# ``_ =`` signals that we await purely for the side effect
153-
# of running the recreate callback to completion; the
154-
# awaitable's return value is intentionally discarded.
155-
_ = await result
172+
self._on_recreate()
156173
else:
157174
self._failure_tracker.record_success()

tests/durabletask/test_client.py

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -209,6 +209,11 @@ def install_resilient_test_stubs(client):
209209
interception pipeline. This helper wraps the current stub in a thin
210210
interceptor-aware shim, and re-wraps after each ``_maybe_recreate_channel``
211211
call so newly created stubs continue to participate in failure tracking.
212+
213+
The interceptor's ``_on_recreate`` callback is intentionally left alone:
214+
it is the client's ``_schedule_recreate`` (fire-and-forget), and tests
215+
that need to observe the recreate's completion can wait on
216+
``client._recreate_done_event``.
212217
"""
213218
is_async = inspect.iscoroutinefunction(client._maybe_recreate_channel)
214219
wrapper_cls = _ResilientAsyncTestStub if is_async else _ResilientSyncTestStub
@@ -231,7 +236,6 @@ def wrapped_recreate():
231236
wrap_if_needed()
232237

233238
client._maybe_recreate_channel = wrapped_recreate
234-
client._resiliency_interceptor._on_recreate = wrapped_recreate
235239

236240

237241
class FakePayloadStore(PayloadStore):
@@ -557,6 +561,9 @@ def test_sync_client_recreates_sdk_owned_channel_with_original_transport_inputs(
557561
install_resilient_test_stubs(client)
558562
with pytest.raises(FakeRpcError):
559563
client.get_orchestration_state("abc")
564+
# Fire-and-forget recreate runs on a daemon thread; wait for it to
565+
# complete before issuing the call that should see the new channel.
566+
assert client._recreate_done_event.wait(timeout=5.0)
560567
client.get_orchestration_state("abc")
561568

562569
expected_channel_call = call(
@@ -641,8 +648,13 @@ def test_sync_client_close_closes_all_retired_sdk_channels_immediately():
641648
install_resilient_test_stubs(client)
642649
with pytest.raises(FakeRpcError):
643650
client.get_orchestration_state("abc")
651+
# Wait for the first fire-and-forget recreate to complete so the
652+
# single-flight guard in _schedule_recreate does not drop the second
653+
# trigger.
654+
assert client._recreate_done_event.wait(timeout=5.0)
644655
with pytest.raises(FakeRpcError):
645656
client.get_orchestration_state("abc")
657+
assert client._recreate_done_event.wait(timeout=5.0)
646658

647659
client.close()
648660

@@ -746,16 +758,23 @@ def test_sync_client_recreate_cooldown_prevents_immediate_repeated_recreation():
746758
install_resilient_test_stubs(client)
747759
with pytest.raises(FakeRpcError):
748760
client.get_orchestration_state("abc")
761+
# Wait for the fire-and-forget recreate to complete before asserting
762+
# the channel was swapped.
763+
assert client._recreate_done_event.wait(timeout=5.0)
749764
assert client._channel is second_channel
750765
assert mock_get_channel.call_count == 2
751766

752767
with pytest.raises(FakeRpcError):
753768
client.get_orchestration_state("abc")
769+
# Cooldown should fire-and-forget but exit without recreating; wait
770+
# for the no-op recreate to complete so the assertion is deterministic.
771+
assert client._recreate_done_event.wait(timeout=5.0)
754772
assert client._channel is second_channel
755773
assert mock_get_channel.call_count == 2
756774

757775
with pytest.raises(FakeRpcError):
758776
client.get_orchestration_state("abc")
777+
assert client._recreate_done_event.wait(timeout=5.0)
759778
assert client._channel is third_channel
760779

761780
expected_channel_call = call(
@@ -861,6 +880,10 @@ async def test_async_client_recreates_sdk_owned_channel_with_original_transport_
861880
try:
862881
with pytest.raises(grpc.aio.AioRpcError):
863882
await client.get_orchestration_state("abc")
883+
# Fire-and-forget recreate runs as an asyncio task; await its
884+
# completion before issuing the call that should see the new
885+
# channel.
886+
await asyncio.wait_for(client._recreate_done_event.wait(), timeout=5.0)
864887
await client.get_orchestration_state("abc")
865888
finally:
866889
await client.close()

0 commit comments

Comments
 (0)