Commit 1579834: retire_workers to use AMM

1 parent 1585f85

3 files changed: 171 additions & 78 deletions

distributed/active_memory_manager.py

Lines changed: 56 additions & 0 deletions
@@ -344,3 +344,59 @@ def run(self):
             # ndrop could be negative, which for range() is the same as 0.
             for _ in range(ndrop):
                 yield "drop", ts, None
+
+
+class RetireWorker(ActiveMemoryManagerPolicy):
+    """Replicate somewhere else all unique keys on a worker, preparing for its shutdown.
+    Once the worker has been retired, this policy automatically removes itself from the
+    Active Memory Manager it's attached to.
+
+    **Retiring a worker with spilled keys**
+
+    On its very first iteration, this policy suggests that other workers fetch all unique
+    in-memory tasks. Frequently, this means that in the next few moments the worker to
+    be retired will be bombarded with ``Worker.get_data`` calls from the rest of the
+    cluster. This can be a problem if most of the managed memory of the worker has been
+    spilled out, as it could send the worker above the terminate threshold.
+    Two things are in place in order to avoid this:
+
+    1. At every iteration, this policy drops all keys that have already been replicated
+       somewhere else. This makes room for further keys to be moved out of the spill
+       file in order to be replicated onto another worker.
+    2. Once a worker passes the ``pause`` threshold, ``Worker.get_data`` throttles the
+       number of outgoing connections to 1.
+    """
+
+    address: str
+
+    def __init__(self, address: str):
+        self.address = address
+
+    def __repr__(self) -> str:
+        return f"RetireWorker({self.address}, done={self.done})"
+
+    def run(self):
+        ws = self.manager.scheduler.workers.get(self.address)
+        if ws is None:
+            self.manager.policies.remove(self)
+            return
+
+        for ts in ws.has_what:
+            if len(ts.who_has) == 1:
+                yield "replicate", ts, None
+            else:
+                # This may be rejected by either the AMM (see _find_dropper) or by the
+                # Worker; if so we'll try again at the next iteration
+                yield "drop", ts, {ws}
+
+    def done(self) -> bool:
+        """Return True if it is safe to close the worker down, or False otherwise. True
+        doesn't necessarily mean that run() won't issue any more suggestions - it could
+        continue issuing ``drop`` suggestions afterwards.
+        """
+        ws = self.manager.scheduler.workers.get(self.address)
+        if ws is None:
+            return True
+        if ws.processing:
+            return False
+        return all(len(ts.who_has) > 1 for ts in ws.has_what)
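For orientation, here is a minimal sketch (not part of this diff) of how a policy like RetireWorker is driven: the scheduler-side ActiveMemoryManagerExtension is looked up under extensions["amm"] and the policy is registered with add_policy(), exactly as the new Scheduler code further down does. The helper name add_retire_policy is made up for illustration.

from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker


def add_retire_policy(scheduler, address: str) -> RetireWorker:
    # Look up the scheduler-side AMM extension (registered under the "amm" key)
    # and attach a RetireWorker policy for the given worker address.
    amm: ActiveMemoryManagerExtension = scheduler.extensions["amm"]
    policy = RetireWorker(address)
    amm.add_policy(policy)  # picked up on the next AMM iteration
    return policy

The caller can then poll policy.done() to learn when every unique key has a replica elsewhere and the worker is safe to shut down, which is what the new Scheduler._retire_worker below does.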

distributed/scheduler.py

Lines changed: 91 additions & 76 deletions
@@ -51,7 +51,7 @@

 from . import preloading, profile
 from . import versions as version_module
-from .active_memory_manager import ActiveMemoryManagerExtension
+from .active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
 from .batched import BatchedSend
 from .comm import (
     get_address_host,
@@ -6610,11 +6610,11 @@ def _key(group):
     async def retire_workers(
         self,
         comm=None,
-        workers=None,
-        remove=True,
-        close_workers=False,
-        names=None,
-        lock=True,
+        *,
+        workers: "list[str] | None" = None,
+        names: "list[str] | None" = None,
+        close_workers: bool = False,
+        remove: bool = True,
         **kwargs,
     ) -> dict:
         """Gracefully retire workers from cluster
@@ -6623,16 +6623,18 @@ async def retire_workers(
         ----------
         workers: list (optional)
             List of worker addresses to retire.
-            If not provided we call ``workers_to_close`` which finds a good set
         names: list (optional)
             List of worker names to retire.
-        remove: bool (defaults to True)
-            Whether or not to remove the worker metadata immediately or else
-            wait for the worker to contact us
+            Mutually exclusive with ``workers``.
+            If neither ``workers`` nor ``names`` are provided, we call
+            ``workers_to_close`` which finds a good set.
         close_workers: bool (defaults to False)
             Whether or not to actually close the worker explicitly from here.
             Otherwise we expect some external job scheduler to finish off the
             worker.
+        remove: bool (defaults to True)
+            Whether or not to remove the worker metadata immediately or else
+            wait for the worker to contact us
         **kwargs: dict
             Extra options to pass to workers_to_close to determine which
             workers we should drop
@@ -6650,78 +6652,91 @@ async def retire_workers(
         ws: WorkerState
         ts: TaskState
         with log_errors():
-            async with self._lock if lock else empty_context:
-                if names is not None:
-                    if workers is not None:
-                        raise TypeError("names and workers are mutually exclusive")
-                    if names:
-                        logger.info("Retire worker names %s", names)
-                    names = set(map(str, names))
-                    workers = {
-                        ws._address
-                        for ws in parent._workers_dv.values()
-                        if str(ws._name) in names
-                    }
-                elif workers is None:
-                    while True:
-                        try:
-                            workers = self.workers_to_close(**kwargs)
-                            if not workers:
-                                return {}
-                            return await self.retire_workers(
-                                workers=workers,
-                                remove=remove,
-                                close_workers=close_workers,
-                                lock=False,
-                            )
-                        except KeyError:  # keys left during replicate
-                            pass
-
-                workers = {
-                    parent._workers_dv[w] for w in workers if w in parent._workers_dv
+            if names is not None:
+                if workers is not None:
+                    raise TypeError("names and workers are mutually exclusive")
+                if names:
+                    logger.info("Retire worker names %s", names)
+                names_set = {str(name) for name in names}
+                wss = {
+                    ws
+                    for ws in parent._workers_dv.values()
+                    if str(ws._name) in names_set
                 }
-                if not workers:
-                    return {}
-                logger.info("Retire workers %s", workers)
-
-                # Keys orphaned by retiring those workers
-                keys = {k for w in workers for k in w.has_what}
-                keys = {ts._key for ts in keys if ts._who_has.issubset(workers)}
-
-                if keys:
-                    other_workers = set(parent._workers_dv.values()) - workers
-                    if not other_workers:
-                        return {}
-                    logger.info("Moving %d keys to other workers", len(keys))
-                    await self.replicate(
-                        keys=keys,
-                        workers=[ws._address for ws in other_workers],
-                        n=1,
-                        delete=False,
-                        lock=False,
-                    )
+            elif workers is not None:
+                wss = {
+                    parent._workers_dv[address]
+                    for address in workers
+                    if address in parent._workers_dv
+                }
+            else:
+                wss = {
+                    parent._workers_dv[address]
+                    for address in self.workers_to_close(**kwargs)
+                }
+            if not wss:
+                return {}

-                worker_keys = {ws._address: ws.identity() for ws in workers}
-                if close_workers:
-                    await asyncio.gather(
-                        *[self.close_worker(worker=w, safe=True) for w in worker_keys]
-                    )
-                if remove:
+            workers_info = {ws._address: ws.identity() for ws in wss}
+
+            stop_amm = False
+            amm: ActiveMemoryManagerExtension = self.extensions["amm"]
+            if not amm.started:
+                amm = ActiveMemoryManagerExtension(
+                    self, register=False, start=True, interval=2.0
+                )
+                stop_amm = True
+
+            # This lock makes retire_workers, rebalance, and replicate mutually
+            # exclusive and will no longer be necessary once rebalance and replicate are
+            # migrated to the Active Memory Manager.
+            async with self._lock:
+                try:
                     await asyncio.gather(
-                        *[self.remove_worker(address=w, safe=True) for w in worker_keys]
+                        *(
+                            self._retire_worker(ws, amm, close_workers, remove)
+                            for ws in wss
+                        )
                     )
+                finally:
+                    if stop_amm:
+                        amm.stop()

-                self.log_event(
-                    "all",
-                    {
-                        "action": "retire-workers",
-                        "workers": worker_keys,
-                        "moved-keys": len(keys),
-                    },
-                )
-                self.log_event(list(worker_keys), {"action": "retired"})
+            self.log_event("all", {"action": "retire-workers", "workers": workers_info})
+            self.log_event(list(workers_info), {"action": "retired"})
+
+            return workers_info
+
+    async def _retire_worker(
+        self,
+        ws: WorkerState,
+        amm: ActiveMemoryManagerExtension,
+        close_workers: bool,
+        remove: bool,
+    ) -> None:
+        logger.info("Retiring worker %s", ws)
+
+        policy = RetireWorker(ws._address)
+        amm.add_policy(policy)
+
+        ws.status = Status.closing_gracefully
+        self.running.discard(ws)
+        self.stream_comms[ws.address].send(
+            {"op": "worker-status-change", "status": ws.status.name}
+        )
+
+        while not policy.done():
+            # Sleep 0.1s when there are 10 in-memory tasks or less
+            # Sleep 3s when there are 300 or more
+            poll_interval = max(0.1, min(3.0, len(ws.has_what) / 100))
+            await asyncio.sleep(poll_interval)
+
+        if close_workers and ws._address in self._workers_dv:
+            await self.close_worker(worker=ws._address, safe=True)
+        if remove:
+            await self.remove_worker(address=ws._address, safe=True)

-                return worker_keys
+        logger.info("Retired worker %s", ws)

     def add_keys(self, comm=None, worker=None, keys=(), stimulus_id=None):
         """

distributed/worker.py

Lines changed: 24 additions & 2 deletions
@@ -885,6 +885,7 @@ def __init__(
             "free-keys": self.handle_free_keys,
             "remove-replicas": self.handle_remove_replicas,
             "steal-request": self.handle_steal_request,
+            "worker-status-change": self.handle_worker_status_change,
         }

         super().__init__(
@@ -1551,7 +1552,9 @@ async def close_gracefully(self, restart=None):
             restart = self.lifetime_restart

         logger.info("Closing worker gracefully: %s", self.address)
-        self.status = Status.closing_gracefully
+        # Wait for all tasks to leave the worker and don't accept any new ones.
+        # Scheduler.retire_workers will set the status to closing_gracefully and push it
+        # back to this worker.
         await self.scheduler.retire_workers(workers=[self.address], remove=False)
         await self.close(safe=True, nanny=not restart)

@@ -2862,6 +2865,24 @@ def handle_steal_request(self, key, stimulus_id):
             # `transition_constrained_executing`
             self.transition(ts, "forgotten", stimulus_id=stimulus_id)

+    def handle_worker_status_change(self, status: str) -> None:
+        new_status = Status.lookup[status]  # type: ignore
+
+        if new_status == Status.closing_gracefully and self.status not in (
+            Status.running,
+            Status.paused,
+        ):
+            logger.error(
+                "Invalid Worker.status transition: %s -> %s", self.status, new_status
+            )
+            # Reiterate the current status to the scheduler to restore sync
+            # (see status.setter)
+            self.status = self.status
+            return
+
+        # Update status and send confirmation to the Scheduler (see status.setter).
+        self.status = new_status
+
     def release_key(
         self,
         key: str,
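To make the round trip concrete, here is a toy sketch (not part of this diff): the scheduler sends {"op": "worker-status-change", "status": ...} over the batched stream, and the handler added above maps the status name back onto the Status enum. WorkerStub is a made-up stand-in for the real Worker class, showing only the lookup and the rejected-transition branch.

from distributed.core import Status


class WorkerStub:
    status = Status.running

    def handle_worker_status_change(self, status: str) -> None:
        # e.g. "closing_gracefully" -> Status.closing_gracefully
        new_status = Status.lookup[status]
        if new_status == Status.closing_gracefully and self.status not in (
            Status.running,
            Status.paused,
        ):
            # Invalid or out-of-order request: keep the current status.
            return
        self.status = new_status


w = WorkerStub()
w.handle_worker_status_change("closing_gracefully")
assert w.status is Status.closing_gracefully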
@@ -3073,7 +3094,8 @@ async def _maybe_deserialize_task(self, ts, *, stimulus_id):
             raise

     def ensure_computing(self):
-        if self.status == Status.paused:
+        if self.status in (Status.paused, Status.closing_gracefully):
+            # Pending tasks shall be stolen
             return
         try:
             stimulus_id = f"ensure-computing-{time()}"
