Skip to content

Commit 7193bb9

Browse files
authored
scatter and replicate to avoid paused workers (#5441)
1 parent: 7bab52f · commit: 7193bb9

4 files changed

Lines changed: 76 additions & 23 deletions

File tree

distributed/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2048,9 +2048,10 @@ async def _scatter(
20482048
await asyncio.sleep(0.1)
20492049
if time() > start + timeout:
20502050
raise TimeoutError("No valid workers found")
2051-
nthreads = await self.scheduler.ncores(workers=workers)
2051+
# Exclude paused and closing_gracefully workers
2052+
nthreads = await self.scheduler.ncores_running(workers=workers)
20522053
if not nthreads:
2053-
raise ValueError("No valid workers")
2054+
raise ValueError("No valid workers found")
20542055

20552056
_, who_has, nbytes = await scatter_to_workers(
20562057
nthreads, data2, report=False, rpc=self.rpc

distributed/scheduler.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3826,6 +3826,7 @@ def __init__(
38263826
"broadcast": self.broadcast,
38273827
"proxy": self.proxy,
38283828
"ncores": self.get_ncores,
3829+
"ncores_running": self.get_ncores_running,
38293830
"has_what": self.get_has_what,
38303831
"who_has": self.get_who_has,
38313832
"processing": self.get_processing,
@@ -5710,18 +5711,24 @@ async def scatter(
57105711
Scheduler.broadcast:
57115712
"""
57125713
parent: SchedulerState = cast(SchedulerState, self)
5714+
ws: WorkerState
5715+
57135716
start = time()
5714-
while not parent._workers_dv:
5715-
await asyncio.sleep(0.2)
5717+
while True:
5718+
if workers is None:
5719+
wss = parent._running
5720+
else:
5721+
workers = [self.coerce_address(w) for w in workers]
5722+
wss = {parent._workers_dv[w] for w in workers}
5723+
wss = {ws for ws in wss if ws._status == Status.running}
5724+
5725+
if wss:
5726+
break
57165727
if time() > start + timeout:
5717-
raise TimeoutError("No workers found")
5728+
raise TimeoutError("No valid workers found")
5729+
await asyncio.sleep(0.1)
57185730

5719-
if workers is None:
5720-
ws: WorkerState
5721-
nthreads = {w: ws._nthreads for w, ws in parent._workers_dv.items()}
5722-
else:
5723-
workers = [self.coerce_address(w) for w in workers]
5724-
nthreads = {w: parent._workers_dv[w].nthreads for w in workers}
5731+
nthreads = {ws._address: ws.nthreads for ws in wss}
57255732

57265733
assert isinstance(data, dict)
57275734

@@ -5732,10 +5739,7 @@ async def scatter(
57325739
self.update_data(who_has=who_has, nbytes=nbytes, client=client)
57335740

57345741
if broadcast:
5735-
if broadcast == True: # noqa: E712
5736-
n = len(nthreads)
5737-
else:
5738-
n = broadcast
5742+
n = len(nthreads) if broadcast is True else broadcast
57395743
await self.replicate(keys=keys, workers=workers, n=n)
57405744

57415745
self.log_event(
@@ -6451,7 +6455,12 @@ async def replicate(
64516455

64526456
assert branching_factor > 0
64536457
async with self._lock if lock else empty_context:
6454-
workers = {parent._workers_dv[w] for w in self.workers_list(workers)}
6458+
if workers is not None:
6459+
workers = {parent._workers_dv[w] for w in self.workers_list(workers)}
6460+
workers = {ws for ws in workers if ws._status == Status.running}
6461+
else:
6462+
workers = parent._running
6463+
64556464
if n is None:
64566465
n = len(workers)
64576466
else:
@@ -6989,6 +6998,15 @@ def get_ncores(self, comm=None, workers=None):
69896998
else:
69906999
return {w: ws._nthreads for w, ws in parent._workers_dv.items()}
69917000

7001+
def get_ncores_running(self, comm=None, workers=None):
7002+
parent: SchedulerState = cast(SchedulerState, self)
7003+
ncores = self.get_ncores(workers=workers)
7004+
return {
7005+
w: n
7006+
for w, n in ncores.items()
7007+
if parent._workers_dv[w].status == Status.running
7008+
}
7009+
69927010
async def get_call_stack(self, comm=None, keys=None):
69937011
parent: SchedulerState = cast(SchedulerState, self)
69947012
ts: TaskState

distributed/tests/test_client.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5734,6 +5734,33 @@ def bad_fn(x):
57345734
assert y.status == "error" # not cancelled
57355735

57365736

5737+
@pytest.mark.parametrize("workers_arg", [False, True])
5738+
@pytest.mark.parametrize("direct", [False, True])
5739+
@pytest.mark.parametrize("broadcast", [False, True, 10])
5740+
@gen_cluster(client=True, nthreads=[("", 1)] * 10)
5741+
async def test_scatter_and_replicate_avoid_paused_workers(
5742+
c, s, *workers, workers_arg, direct, broadcast
5743+
):
5744+
paused_workers = [w for i, w in enumerate(workers) if i not in (3, 7)]
5745+
for w in paused_workers:
5746+
w.memory_pause_fraction = 1e-15
5747+
while any(s.workers[w.address].status != Status.paused for w in paused_workers):
5748+
await asyncio.sleep(0.01)
5749+
5750+
f = await c.scatter(
5751+
{"x": 1},
5752+
workers=[w.address for w in workers[1:-1]] if workers_arg else None,
5753+
broadcast=broadcast,
5754+
direct=direct,
5755+
)
5756+
if not broadcast:
5757+
await c.replicate(f, n=10)
5758+
5759+
expect = [i in (3, 7) for i in range(10)]
5760+
actual = [("x" in w.data) for w in workers]
5761+
assert actual == expect
5762+
5763+
57375764
@pytest.mark.xfail(reason="GH#5409 Dask-Default-Threads are frequently detected")
57385765
def test_no_threads_lingering():
57395766
if threading.active_count() < 40:

distributed/tests/test_scheduler.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -784,20 +784,27 @@ async def test_story(c, s, a, b):
784784
assert s.story(x.key) == s.story(s.tasks[x.key])
785785

786786

787-
@gen_cluster(nthreads=[], client=True)
788-
async def test_scatter_no_workers(c, s):
787+
@pytest.mark.parametrize("direct", [False, True])
788+
@gen_cluster(client=True, nthreads=[])
789+
async def test_scatter_no_workers(c, s, direct):
789790
with pytest.raises(TimeoutError):
790791
await s.scatter(data={"x": 1}, client="alice", timeout=0.1)
791792

792793
start = time()
793794
with pytest.raises(TimeoutError):
794-
await c.scatter(123, timeout=0.1)
795+
await c.scatter(123, timeout=0.1, direct=direct)
795796
assert time() < start + 1.5
796797

797-
w = Worker(s.address, nthreads=3)
798-
await asyncio.gather(c.scatter(data={"y": 2}, timeout=5), w)
799-
800-
assert w.data["y"] == 2
798+
fut = c.scatter({"y": 2}, timeout=5, direct=direct)
799+
await asyncio.sleep(0.1)
800+
async with Worker(s.address) as w:
801+
await fut
802+
assert w.data["y"] == 2
803+
804+
# Test race condition between worker init and scatter
805+
w = Worker(s.address)
806+
await asyncio.gather(c.scatter({"z": 3}, timeout=5, direct=direct), w)
807+
assert w.data["z"] == 3
801808
await w.close()
802809

803810

0 commit comments

Comments (0)