@@ -3826,6 +3826,7 @@ def __init__(
38263826 "broadcast" : self .broadcast ,
38273827 "proxy" : self .proxy ,
38283828 "ncores" : self .get_ncores ,
3829+ "ncores_running" : self .get_ncores_running ,
38293830 "has_what" : self .get_has_what ,
38303831 "who_has" : self .get_who_has ,
38313832 "processing" : self .get_processing ,
@@ -5710,18 +5711,24 @@ async def scatter(
57105711 Scheduler.broadcast:
57115712 """
57125713 parent : SchedulerState = cast (SchedulerState , self )
5714+ ws : WorkerState
5715+
57135716 start = time ()
5714- while not parent ._workers_dv :
5715- await asyncio .sleep (0.2 )
5717+ while True :
5718+ if workers is None :
5719+ wss = parent ._running
5720+ else :
5721+ workers = [self .coerce_address (w ) for w in workers ]
5722+ wss = {parent ._workers_dv [w ] for w in workers }
5723+ wss = {ws for ws in wss if ws ._status == Status .running }
5724+
5725+ if wss :
5726+ break
57165727 if time () > start + timeout :
5717- raise TimeoutError ("No workers found" )
5728+ raise TimeoutError ("No valid workers found" )
5729+ await asyncio .sleep (0.1 )
57185730
5719- if workers is None :
5720- ws : WorkerState
5721- nthreads = {w : ws ._nthreads for w , ws in parent ._workers_dv .items ()}
5722- else :
5723- workers = [self .coerce_address (w ) for w in workers ]
5724- nthreads = {w : parent ._workers_dv [w ].nthreads for w in workers }
5731+ nthreads = {ws ._address : ws .nthreads for ws in wss }
57255732
57265733 assert isinstance (data , dict )
57275734
@@ -5732,10 +5739,7 @@ async def scatter(
57325739 self .update_data (who_has = who_has , nbytes = nbytes , client = client )
57335740
57345741 if broadcast :
5735- if broadcast == True : # noqa: E712
5736- n = len (nthreads )
5737- else :
5738- n = broadcast
5742+ n = len (nthreads ) if broadcast is True else broadcast
57395743 await self .replicate (keys = keys , workers = workers , n = n )
57405744
57415745 self .log_event (
@@ -6451,7 +6455,12 @@ async def replicate(
64516455
64526456 assert branching_factor > 0
64536457 async with self ._lock if lock else empty_context :
6454- workers = {parent ._workers_dv [w ] for w in self .workers_list (workers )}
6458+ if workers is not None :
6459+ workers = {parent ._workers_dv [w ] for w in self .workers_list (workers )}
6460+ workers = {ws for ws in workers if ws ._status == Status .running }
6461+ else :
6462+ workers = parent ._running
6463+
64556464 if n is None :
64566465 n = len (workers )
64576466 else :
@@ -6989,6 +6998,15 @@ def get_ncores(self, comm=None, workers=None):
69896998 else :
69906999 return {w : ws ._nthreads for w , ws in parent ._workers_dv .items ()}
69917000
7001+ def get_ncores_running (self , comm = None , workers = None ):
7002+ parent : SchedulerState = cast (SchedulerState , self )
7003+ ncores = self .get_ncores (workers = workers )
7004+ return {
7005+ w : n
7006+ for w , n in ncores .items ()
7007+ if parent ._workers_dv [w ].status == Status .running
7008+ }
7009+
69927010 async def get_call_stack (self , comm = None , keys = None ):
69937011 parent : SchedulerState = cast (SchedulerState , self )
69947012 ts : TaskState
0 commit comments