From 40508f68111c7e13e7fc3836745bf253644290e7 Mon Sep 17 00:00:00 2001 From: nadzhou Date: Sun, 1 Feb 2026 15:08:24 -0800 Subject: [PATCH 1/3] Update client.py --- distributed/client.py | 408 +++++++++++------------------------------- 1 file changed, 105 insertions(+), 303 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 415394ab17..b45c8b98ef 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -135,9 +135,7 @@ logger = logging.getLogger(__name__) -_global_clients: weakref.WeakValueDictionary[int, Client] = ( - weakref.WeakValueDictionary() -) +_global_clients: weakref.WeakValueDictionary[int, Client] = weakref.WeakValueDictionary() _global_client_index = [0] _current_client: ContextVar[Client | None] = ContextVar("_current_client", default=None) @@ -173,16 +171,12 @@ class FuturesCancelledError(CancelledError): error_groups: list[CancelledFuturesGroup] def __init__(self, error_groups: list[CancelledFuturesGroup]): - self.error_groups = sorted( - error_groups, key=lambda group: len(group.errors), reverse=True - ) + self.error_groups = sorted(error_groups, key=lambda group: len(group.errors), reverse=True) def __str__(self): count = sum(map(lambda group: len(group.errors), self.error_groups)) result = f"{count} Future{'s' if count > 1 else ''} cancelled:" - return "\n".join( - [result, "Reasons:"] + [str(group) for group in self.error_groups] - ) + return "\n".join([result, "Reasons:"] + [str(group) for group in self.error_groups]) class CancelledFuturesGroup: @@ -358,13 +352,20 @@ def executor(self): return self.client @property - def status(self): + def status(self) -> Literal["pending", "cancelled", "finished", "lost", "error"] | None: """Returns the status Returns ------- str The status + The status of the future. Possible values: + - "pending": The future is waiting to be computed + - "finished": The future has completed successfully + - "error": The future encountered an error during computation + - "cancelled": The future was cancelled + - "lost": The future's data was lost from memory + - None: The future is not yet bound to a client """ if self._state: return self._state.status @@ -476,9 +477,7 @@ def add_done_callback(self, fn): cls = Future if cls._cb_executor is None or cls._cb_executor_pid != os.getpid(): try: - cls._cb_executor = ThreadPoolExecutor( - 1, thread_name_prefix="Dask-Callback-Thread" - ) + cls._cb_executor = ThreadPoolExecutor(1, thread_name_prefix="Dask-Callback-Thread") except TypeError: cls._cb_executor = ThreadPoolExecutor(1) cls._cb_executor_pid = os.getpid() @@ -490,9 +489,7 @@ def execute_callback(fut): logger.exception("Error in callback %s of %s:", fn, fut) raise - self.client.loop.add_callback( - done_callback, self, partial(cls._cb_executor.submit, execute_callback) - ) + self.client.loop.add_callback(done_callback, self, partial(cls._cb_executor.submit, execute_callback)) def cancel(self, reason=None, msg=None, **kwargs): """Cancel the request to run this future @@ -610,9 +607,7 @@ def __str__(self): def __repr__(self): if self.type: - return ( - f"" - ) + return f"" else: return f"" @@ -645,7 +640,7 @@ def __init__(self, key: str): self._event = None self.key = key self.exception = None - self.status = "pending" + self.status: Literal["pending", "cancelled", "finished", "lost", "error"] = "pending" self.traceback = None self.type = None @@ -768,9 +763,7 @@ def _handle_print(event): if not isinstance(args, tuple): # worker.print() will always send us a tuple of args, even if it's an # empty tuple. 
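# ---- Illustrative sketch (editor's example, not part of this patch) ----
# The Literal annotation added to Future.status above enumerates the full
# lifecycle of a Future. A minimal demonstration, assuming an in-process
# cluster can be started:
#
#     from distributed import Client
#
#     client = Client(processes=False)
#     fut = client.submit(sum, [1, 2, 3])
#     fut.add_done_callback(lambda f: print("callback sees:", f.status))
#     assert fut.result() == 6
#     assert fut.status == "finished"
#     client.close()
# ------------------------------------------------------------------------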
- raise TypeError( - f"_handle_print: client received non-tuple print args: {args!r}" - ) + raise TypeError(f"_handle_print: client received non-tuple print args: {args!r}") file = msg.get("file") if file == 1: @@ -778,13 +771,9 @@ def _handle_print(event): elif file == 2: file = sys.stderr elif file is not None: - raise TypeError( - f"_handle_print: client received unsupported file kwarg: {file!r}" - ) + raise TypeError(f"_handle_print: client received unsupported file kwarg: {file!r}") - print( - *args, sep=msg.get("sep"), end=msg.get("end"), file=file, flush=msg.get("flush") - ) + print(*args, sep=msg.get("sep"), end=msg.get("end"), file=file, flush=msg.get("flush")) def _handle_warn(event): @@ -798,10 +787,7 @@ def _handle_warn(event): if "message" not in msg: # TypeError makes sense here because it's analogous to calling a # function without a required positional argument - raise TypeError( - "_handle_warn: client received a warn event missing the required " - '"message" argument.' - ) + raise TypeError('_handle_warn: client received a warn event missing the required "message" argument.') if "category" in msg: category = pickle.loads(msg["category"]) else: @@ -838,11 +824,7 @@ class VersionsDict(TypedDict): def _is_nested(iterable): for item in iterable: - if ( - isinstance(item, Iterable) - and not isinstance(item, str) - and not isinstance(item, bytes) - ): + if isinstance(item, Iterable) and not isinstance(item, str) and not isinstance(item, bytes): return True return False @@ -891,10 +873,7 @@ def keys(self) -> Iterable[Key]: else: uid = str(uuid.uuid4()) keys = ( - [ - f"{self.key}-{uid}-{i}" - for i in range(min(map(len, self.iterables))) - ] + [f"{self.key}-{uid}-{i}" for i in range(min(map(len, self.iterables)))] if self.iterables else [] ) @@ -1071,11 +1050,7 @@ def __init__( self._handle_report_task = None if name is None: name = dask.config.get("client-name", None) - self.id = ( - type(self).__name__ - + ("-" + name + "-" if name else "-") - + str(uuid.uuid1(clock_seq=os.getpid())) - ) + self.id = type(self).__name__ + ("-" + name + "-" if name else "-") + str(uuid.uuid1(clock_seq=os.getpid())) self.generation = 0 self.status = "newly-created" self._pending_msg_buffer = [] @@ -1118,17 +1093,13 @@ def __init__( self.cluster = address status = self.cluster.status if status in (Status.closed, Status.closing): - raise RuntimeError( - f"Trying to connect to an already closed or closing Cluster {self.cluster}." - ) + raise RuntimeError(f"Trying to connect to an already closed or closing Cluster {self.cluster}.") with suppress(AttributeError): loop = address.loop if security is None: security = getattr(self.cluster, "security", None) elif address is not None and not isinstance(address, str): - raise TypeError( - f"Scheduler address must be a string or a Cluster instance, got {type(address)}" - ) + raise TypeError(f"Scheduler address must be a string or a Cluster instance, got {type(address)}") # If connecting to an address and no explicit security is configured, attempt # to load security credentials with a security loader (if configured). 
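# ---- Illustrative sketch (editor's example, not part of this patch) ----
# The TypeError above fires when `address` is neither a string nor a
# Cluster instance. Both accepted forms, assuming a scheduler is actually
# reachable at the example address:

from distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=2, processes=False)
client_from_cluster = Client(cluster)                 # Cluster instance
client_from_address = Client("tcp://127.0.0.1:8786")  # address string
# ------------------------------------------------------------------------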
@@ -1171,9 +1142,7 @@ def __init__( self._periodic_callbacks["scheduler-info"] = PeriodicCallback( self._update_scheduler_info, scheduler_info_interval * 1000 ) - self._periodic_callbacks["heartbeat"] = PeriodicCallback( - self._heartbeat, heartbeat_interval * 1000 - ) + self._periodic_callbacks["heartbeat"] = PeriodicCallback(self._heartbeat, heartbeat_interval * 1000) self._start_arg = address self._set_as_default = set_as_default @@ -1209,9 +1178,7 @@ def __init__( server=self, ) - self.extensions = { - name: extension(self) for name, extension in extensions.items() - } + self.extensions = {name: extension(self) for name, extension in extensions.items()} preload = dask.config.get("distributed.client.preload") preload_argv = dask.config.get("distributed.client.preload-argv") @@ -1226,16 +1193,12 @@ def __init__( @property def io_loop(self) -> IOLoop | None: - warnings.warn( - "The io_loop property is deprecated", DeprecationWarning, stacklevel=2 - ) + warnings.warn("The io_loop property is deprecated", DeprecationWarning, stacklevel=2) return self.loop @io_loop.setter def io_loop(self, value: IOLoop) -> None: - warnings.warn( - "The io_loop property is deprecated", DeprecationWarning, stacklevel=2 - ) + warnings.warn("The io_loop property is deprecated", DeprecationWarning, stacklevel=2) self.loop = value @property @@ -1251,9 +1214,7 @@ def loop(self) -> IOLoop | None: @loop.setter def loop(self, value: IOLoop) -> None: - warnings.warn( - "setting the loop property is deprecated", DeprecationWarning, stacklevel=2 - ) + warnings.warn("setting the loop property is deprecated", DeprecationWarning, stacklevel=2) self.__loop = value @contextmanager @@ -1342,16 +1303,10 @@ def dashboard_link(self): def _get_scheduler_info(self, n_workers): from distributed.scheduler import Scheduler - if ( - self.cluster - and hasattr(self.cluster, "scheduler") - and isinstance(self.cluster.scheduler, Scheduler) - ): + if self.cluster and hasattr(self.cluster, "scheduler") and isinstance(self.cluster.scheduler, Scheduler): info = self.cluster.scheduler.identity(n_workers=n_workers) scheduler = self.cluster.scheduler - elif ( - self._loop_runner.is_started() and self.scheduler and not self.asynchronous - ): + elif self._loop_runner.is_started() and self.scheduler and not self.asynchronous: info = sync(self.loop, self.scheduler.identity, n_workers=n_workers) scheduler = self.scheduler else: @@ -1440,9 +1395,7 @@ def _send_to_scheduler(self, msg): if self.status in ("running", "closing", "connecting", "newly-created"): self.loop.add_callback(self._send_to_scheduler_safe, msg) else: - raise ClosedClientError( - f"Client is {self.status}. Can't send {msg['op']} message." - ) + raise ClosedClientError(f"Client is {self.status}. Can't send {msg['op']} message.") async def _start(self, timeout=no_default, **kwargs): self.status = "connecting" @@ -1519,10 +1472,7 @@ async def _reconnect(self): for st in self.futures.values(): st.cancel( reason="scheduler-connection-lost", - msg=( - "Client lost the connection to the scheduler. " - "Please check your connection and re-run your work." - ), + msg=("Client lost the connection to the scheduler. 
Please check your connection and re-run your work."), ) self.futures.clear() @@ -1542,8 +1492,7 @@ async def _reconnect(self): else: logger.error( - "Failed to reconnect to scheduler after %.2f " - "seconds, closing client", + "Failed to reconnect to scheduler after %.2f seconds, closing client", self._timeout, ) await self._close() @@ -1560,9 +1509,7 @@ async def _ensure_connected(self, timeout=None): self._connecting_to_scheduler = True try: - comm = await connect( - self.scheduler.address, timeout=timeout, **self.connection_args - ) + comm = await connect(self.scheduler.address, timeout=timeout, **self.connection_args) comm.name = "Client->Scheduler" if timeout is not None: await wait_for(self._update_scheduler_info(), timeout) @@ -1610,15 +1557,11 @@ async def _update_scheduler_info(self, n_workers=5): if self.status not in ("running", "connecting") or self.scheduler is None: return try: - self._scheduler_identity = SchedulerInfo( - await self.scheduler.identity(n_workers=n_workers) - ) + self._scheduler_identity = SchedulerInfo(await self.scheduler.identity(n_workers=n_workers)) except OSError: logger.debug("Not able to query scheduler for identity") - async def _wait_for_workers( - self, n_workers: int, timeout: float | None = None - ) -> None: + async def _wait_for_workers(self, n_workers: int, timeout: float | None = None) -> None: info = await self.scheduler.identity(n_workers=-1) self._scheduler_identity = SchedulerInfo(info) if timeout: @@ -1627,13 +1570,7 @@ async def _wait_for_workers( deadline = None def running_workers(info): - return len( - [ - ws - for ws in info["workers"].values() - if ws["status"] == Status.running.name - ] - ) + return len([ws for ws in info["workers"].values() if ws["status"] == Status.running.name]) while running_workers(info) < n_workers: if deadline and time() > deadline: @@ -1655,9 +1592,7 @@ def wait_for_workers(self, n_workers: int, timeout: float | None = None) -> None ``dask.distributed.TimeoutError`` """ if not isinstance(n_workers, int) or n_workers < 1: - raise ValueError( - f"`n_workers` must be a positive integer. Instead got {n_workers}." - ) + raise ValueError(f"`n_workers` must be a positive integer. 
Instead got {n_workers}.") if self.cluster and hasattr(self.cluster, "wait_for_workers"): return self.cluster.wait_for_workers(n_workers, timeout) @@ -1698,8 +1633,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): if not e.args[0].endswith(" was created in a different Context"): raise # pragma: nocover warnings.warn( - "It is deprecated to enter and exit the Client context " - "manager from different tasks", + "It is deprecated to enter and exit the Client context manager from different tasks", DeprecationWarning, stacklevel=2, ) @@ -1717,8 +1651,7 @@ def __exit__(self, exc_type, exc_value, traceback): if not e.args[0].endswith(" was created in a different Context"): raise # pragma: nocover warnings.warn( - "It is deprecated to enter and exit the Client context " - "manager from different threads", + "It is deprecated to enter and exit the Client context manager from different threads", DeprecationWarning, stacklevel=2, ) @@ -1748,9 +1681,7 @@ def _release_key(self, key): if st is not None: st.cancel() if self.status != "closed": - self._send_to_scheduler( - {"op": "client-releases-keys", "keys": [key], "client": self.id} - ) + self._send_to_scheduler({"op": "client-releases-keys", "keys": [key], "client": self.id}) @log_errors async def _handle_report(self): @@ -1866,9 +1797,7 @@ async def _wait_for_handle_report_task(self, fast=False): handle_report_task = self._handle_report_task # Give the scheduler 'stream-closed' message 100ms to come through # This makes the shutdown slightly smoother and quieter - should_wait = ( - handle_report_task is not None and handle_report_task is not current_task - ) + should_wait = handle_report_task is not None and handle_report_task is not current_task if should_wait: with suppress(asyncio.CancelledError, TimeoutError): await wait_for(asyncio.shield(handle_report_task), 0.1) @@ -1910,19 +1839,11 @@ async def _close(self, fast: bool = False) -> None: if self.get == dask.config.get("get", None): del dask.config.config["get"] - if ( - self.scheduler_comm - and self.scheduler_comm.comm - and not self.scheduler_comm.comm.closed() - ): + if self.scheduler_comm and self.scheduler_comm.comm and not self.scheduler_comm.comm.closed(): self._send_to_scheduler({"op": "close-client"}) self._send_to_scheduler({"op": "close-stream"}) async with self._wait_for_handle_report_task(fast=fast): - if ( - self.scheduler_comm - and self.scheduler_comm.comm - and not self.scheduler_comm.comm.closed() - ): + if self.scheduler_comm and self.scheduler_comm.comm and not self.scheduler_comm.comm.closed(): await self.scheduler_comm.close() for key in list(self.futures): @@ -2280,18 +2201,14 @@ def map( if not callable(func): raise TypeError("First input to map must be a callable function") - if all(isinstance(it, pyQueue) for it in iterables) or all( - isinstance(i, Iterator) for i in iterables - ): + if all(isinstance(it, pyQueue) for it in iterables) or all(isinstance(i, Iterator) for i in iterables): raise TypeError( "Dask no longer supports mapping over Iterators or Queues." 
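# ---- Illustrative sketch (editor's example, not part of this patch) ----
# This TypeError rejects iterators and queues: Client.map needs sized
# sequences so it can build keys up front. Assuming a connected `client`:
#
#     futures = client.map(lambda x: x + 1, range(10_000), batch_size=1000)
#     results = client.gather(futures)      # [1, 2, ..., 10000]
#
#     gen = (x for x in range(10))
#     client.map(lambda x: x, gen)          # raises this TypeError
# ------------------------------------------------------------------------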
"Consider using a normal for loop and Client.submit" ) total_length = sum(len(x) for x in iterables) if batch_size and batch_size > 1 and total_length > batch_size: - batches = list( - zip(*(partition_all(batch_size, iterable) for iterable in iterables)) - ) + batches = list(zip(*(partition_all(batch_size, iterable) for iterable in iterables))) keys: list[list[Any]] | list[Any] if isinstance(key, list): keys = [list(element) for element in partition_all(batch_size, key)] @@ -2427,9 +2344,7 @@ async def wait(k): keys = [k for k in keys if k not in bad_keys and k not in data] if local_worker: # look inside local worker - data.update( - {k: local_worker.data[k] for k in keys if k in local_worker.data} - ) + data.update({k: local_worker.data[k] for k in keys if k in local_worker.data}) keys = [k for k in keys if k not in data] # We now do an actual remote communication with workers or scheduler @@ -2438,9 +2353,7 @@ async def wait(k): response = await self._gather_future else: # no one waiting, go ahead self._gather_keys = set(keys) - future = asyncio.ensure_future( - self._gather_remote(direct, local_worker) - ) + future = asyncio.ensure_future(self._gather_remote(direct, local_worker)) if self._gather_keys is None: self._gather_future = None else: @@ -2485,14 +2398,10 @@ async def _gather_remote(self, direct: bool, local_worker: bool) -> dict[str, An if direct or local_worker: # gather directly from workers who_has = await retry_operation(self.scheduler.who_has, keys=keys) - data, missing_keys, failed_keys, _ = await gather_from_workers( - who_has, rpc=self.rpc - ) + data, missing_keys, failed_keys, _ = await gather_from_workers(who_has, rpc=self.rpc) response: dict[str, Any] = {"status": "OK", "data": data} if missing_keys or failed_keys: - response = await retry_operation( - self.scheduler.gather, keys=missing_keys + failed_keys - ) + response = await retry_operation(self.scheduler.gather, keys=missing_keys + failed_keys) if response["status"] == "OK": response["data"].update(data) @@ -2641,9 +2550,7 @@ async def _scatter( _, who_has, nbytes = await scatter_to_workers(workers, data2, self.rpc) - await self.scheduler.update_data( - who_has=who_has, nbytes=nbytes, client=self.id - ) + await self.scheduler.update_data(who_has=who_has, nbytes=nbytes, client=self.id) else: await self.scheduler.scatter( data=data2, @@ -2876,10 +2783,7 @@ async def _(): if name: if len(args) == 0: - raise ValueError( - "If name is provided, expecting call signature like" - " publish_dataset(df, name='ds')" - ) + raise ValueError("If name is provided, expecting call signature like publish_dataset(df, name='ds')") # in case this is a singleton, collapse it elif len(args) == 1: args = args[0] @@ -3105,10 +3009,7 @@ async def _run( elif on_error == "return": results[key] = exc elif on_error != "ignore": - raise ValueError( - "on_error must be 'raise', 'return', or 'ignore'; " - f"got {on_error!r}" - ) + raise ValueError(f"on_error must be 'raise', 'return', or 'ignore'; got {on_error!r}") if wait: return results @@ -3208,9 +3109,7 @@ def run( ) @staticmethod - def _get_computation_code( - stacklevel: int | None = None, nframes: int = 1 - ) -> tuple[SourceCode, ...]: + def _get_computation_code(stacklevel: int | None = None, nframes: int = 1) -> tuple[SourceCode, ...]: """Walk up the stack to the user code and extract the code surrounding the compute/submit/persist call. 
All modules encountered which are ignored through the option @@ -3224,42 +3123,26 @@ def _get_computation_code( if nframes <= 0: return () - ignore_modules = dask.config.get( - "distributed.diagnostics.computations.ignore-modules" - ) + ignore_modules = dask.config.get("distributed.diagnostics.computations.ignore-modules") if not isinstance(ignore_modules, list): - raise TypeError( - "Ignored modules must be a list. Instead got " - f"({type(ignore_modules)}, {ignore_modules})" - ) - ignore_files = dask.config.get( - "distributed.diagnostics.computations.ignore-files" - ) + raise TypeError(f"Ignored modules must be a list. Instead got ({type(ignore_modules)}, {ignore_modules})") + ignore_files = dask.config.get("distributed.diagnostics.computations.ignore-files") if not isinstance(ignore_files, list): - raise TypeError( - "Ignored files must be a list. Instead got " - f"({type(ignore_files)}, {ignore_files})" - ) + raise TypeError(f"Ignored files must be a list. Instead got ({type(ignore_files)}, {ignore_files})") mod_pattern: re.Pattern | None = None fname_pattern: re.Pattern | None = None if stacklevel is None: if ignore_modules: - mod_pattern = re.compile( - "|".join([f"(?:{mod})" for mod in ignore_modules]) - ) + mod_pattern = re.compile("|".join([f"(?:{mod})" for mod in ignore_modules])) if ignore_files: - fname_pattern = re.compile( - r".*[\\/](" + "|".join(mod for mod in ignore_files) + r")([\\/]|$)" - ) + fname_pattern = re.compile(r".*[\\/](" + "|".join(mod for mod in ignore_files) + r")([\\/]|$)") else: # stacklevel 0 or less - shows dask internals which likely isn't helpful stacklevel = stacklevel if stacklevel > 0 else 1 code: list[SourceCode] = [] - for i, (fr, lineno_frame) in enumerate( - traceback.walk_stack(sys._getframe().f_back), 1 - ): + for i, (fr, lineno_frame) in enumerate(traceback.walk_stack(sys._getframe().f_back), 1): if len(code) >= nframes: break if stacklevel is not None and i != stacklevel: @@ -3368,9 +3251,7 @@ def _graph_to_futures( expr_ser = Serialized(*serialize(to_serialize(expr), on_error="raise")) pickled_size = sum(map(nbytes, [expr_ser.header] + expr_ser.frames)) - if pickled_size > parse_bytes( - dask.config.get("distributed.admin.large-graph-warning-threshold") - ): + if pickled_size > parse_bytes(dask.config.get("distributed.admin.large-graph-warning-threshold")): warnings.warn( f"Sending large graph of size {format_bytes(pickled_size)}.\n" "This may cause some slowdown.\n" @@ -3659,18 +3540,11 @@ def compute( if traverse: collections = tuple( - ( - dask.delayed(a) - if isinstance(a, (list, set, tuple, dict, Iterator)) - else a - ) - for a in collections + (dask.delayed(a) if isinstance(a, (list, set, tuple, dict, Iterator)) else a) for a in collections ) variables = [a for a in collections if dask.is_dask_collection(a)] - metadata = SpanMetadata( - collections=[get_collections_metadata(v) for v in variables] - ) + metadata = SpanMetadata(collections=[get_collections_metadata(v) for v in variables]) futures_dict = {} if variables: expr = collections_to_expr(variables, optimize_graph, **kwargs) @@ -3781,9 +3655,7 @@ def persist( collections = [collections] assert all(map(dask.is_dask_collection, collections)) - metadata = SpanMetadata( - collections=[get_collections_metadata(v) for v in collections] - ) + metadata = SpanMetadata(collections=[get_collections_metadata(v) for v in collections]) expr = collections_to_expr(collections, optimize_graph) expr2 = expr.optimize() @@ -3806,17 +3678,12 @@ def persist( postpersists = [c.__dask_postpersist__() 
for c in collections] assert len(postpersists) == len(keys) - result = [ - func({k: futures[k] for k in flatten(ks)}, *args) - for (func, args), ks in zip(postpersists, keys) - ] + result = [func({k: futures[k] for k in flatten(ks)}, *args) for (func, args), ks in zip(postpersists, keys)] if singleton: return result[0] return result - async def _restart( - self, timeout: str | int | float | NoDefault, wait_for_workers: bool - ) -> None: + async def _restart(self, timeout: str | int | float | NoDefault, wait_for_workers: bool) -> None: if timeout is no_default: timeout = self._timeout * 4 timeout = parse_timedelta(cast("str|int|float", timeout), "s") @@ -3861,9 +3728,7 @@ def restart( Scheduler.restart Client.restart_workers """ - return self.sync( - self._restart, timeout=timeout, wait_for_workers=wait_for_workers - ) + return self.sync(self._restart, timeout=timeout, wait_for_workers=wait_for_workers) async def _restart_workers( self, @@ -3879,13 +3744,11 @@ async def _restart_workers( name_to_addr = {meta["name"]: addr for addr, meta in info["workers"].items()} worker_addrs = [name_to_addr.get(w, w) for w in workers] - out: dict[str, Literal["OK", "removed", "timed out"]] = ( - await self.scheduler.restart_workers( - workers=worker_addrs, - timeout=timeout, - on_error="raise" if raise_for_error else "return", - stimulus_id=f"client-restart-workers-{time()}", - ) + out: dict[str, Literal["OK", "removed", "timed out"]] = await self.scheduler.restart_workers( + workers=worker_addrs, + timeout=timeout, + on_error="raise" if raise_for_error else "return", + stimulus_id=f"client-restart-workers-{time()}", ) # Map keys back to original `workers` input names/addresses out = {w: out[w_addr] for w, w_addr in zip(workers, worker_addrs)} @@ -4024,9 +3887,7 @@ def upload_file(self, filename, load: bool = True): async def _(): results = await asyncio.gather( - self.register_plugin( - SchedulerUploadFile(filename, load=load), name=name - ), + self.register_plugin(SchedulerUploadFile(filename, load=load), name=name), # FIXME: Make scheduler plugin responsible for (de)registering worker plugin self.register_plugin(UploadFile(filename, load=load), name=name), ) @@ -4074,9 +3935,7 @@ async def _replicate(self, futures, n=None, workers=None, branching_factor=2): futures = self.futures_of(futures) await _wait(futures) keys = {f.key for f in futures} - await self.scheduler.replicate( - keys=list(keys), n=n, workers=workers, branching_factor=branching_factor - ) + await self.scheduler.replicate(keys=list(keys), n=n, workers=workers, branching_factor=branching_factor) def replicate(self, futures, n=None, workers=None, branching_factor=2, **kwargs): """Set replication of futures within network @@ -4154,9 +4013,7 @@ def nthreads(self, workers=None, **kwargs): Client.who_has Client.has_what """ - if isinstance(workers, tuple) and all( - isinstance(i, (str, tuple)) for i in workers - ): + if isinstance(workers, tuple) and all(isinstance(i, (str, tuple)) for i in workers): workers = list(workers) if workers is not None and not isinstance(workers, (tuple, list, set)): workers = [workers] @@ -4231,9 +4088,7 @@ def has_what(self, workers=None, **kwargs): Client.nthreads Client.processing """ - if isinstance(workers, tuple) and all( - isinstance(i, (str, tuple)) for i in workers - ): + if isinstance(workers, tuple) and all(isinstance(i, (str, tuple)) for i in workers): workers = list(workers) if workers is not None and not isinstance(workers, (tuple, list, set)): workers = [workers] @@ -4265,9 +4120,7 @@ def 
processing(self, workers=None): Client.has_what Client.nthreads """ - if isinstance(workers, tuple) and all( - isinstance(i, (str, tuple)) for i in workers - ): + if isinstance(workers, tuple) and all(isinstance(i, (str, tuple)) for i in workers): workers = list(workers) if workers is not None and not isinstance(workers, (tuple, list, set)): workers = [workers] @@ -4702,9 +4555,7 @@ def log_event(self, topic: str | Collection[str], msg: Any): >>> client.log_event("current-time", time()) """ if not _is_dumpable(msg): - raise TypeError( - f"Message must be msgpack serializable. Got {type(msg)=} instead." - ) + raise TypeError(f"Message must be msgpack serializable. Got {type(msg)=} instead.") return self.sync(self.scheduler.log_event, topic=topic, msg=msg) def get_events(self, topic: str | None = None): @@ -4773,9 +4624,7 @@ def unsubscribe_topic(self, topic): else: raise ValueError(f"No event handler known for topic {topic}.") - def retire_workers( - self, workers: list[str] | None = None, close_workers: bool = True, **kwargs - ): + def retire_workers(self, workers: list[str] | None = None, close_workers: bool = True, **kwargs): """Retire certain workers on the scheduler See :meth:`distributed.Scheduler.retire_workers` for the full docstring. @@ -4876,9 +4725,7 @@ def get_versions( """ return self.sync(self._get_versions, check=check, packages=packages or []) - async def _get_versions( - self, check: bool = False, packages: Sequence[str] | None = None - ) -> VersionsDict: + async def _get_versions(self, check: bool = False, packages: Sequence[str] | None = None) -> VersionsDict: packages = packages or [] client = version_module.get_versions(packages=packages) scheduler = await self.scheduler.versions(packages=packages) @@ -4925,9 +4772,7 @@ async def _story(self, *keys_or_stimuli: str, on_error="raise"): assert on_error in ("raise", "ignore") try: - flat_stories = await self.scheduler.get_story( - keys_or_stimuli=keys_or_stimuli - ) + flat_stories = await self.scheduler.get_story(keys_or_stimuli=keys_or_stimuli) flat_stories = [("scheduler", *msg) for msg in flat_stories] except Exception: if on_error == "raise": @@ -5102,10 +4947,7 @@ def _register_plugin( ): if isinstance(plugin, type): raise TypeError("Please provide an instance of a plugin, not a type.") - if any( - "dask.distributed.diagnostics.plugin" in str(c) - for c in plugin.__class__.__bases__ - ): + if any("dask.distributed.diagnostics.plugin" in str(c) for c in plugin.__class__.__bases__): raise TypeError( "Importing plugin base classes from `dask.distributed.diagnostics.plugin` is not supported. " "Please import directly from `distributed.diagnostics.plugin` instead." 
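# ---- Illustrative sketch (editor's example, not part of this patch) ----
# The checks above require a plugin *instance* whose base class is
# imported from `distributed.diagnostics.plugin`. A minimal conforming
# worker plugin, assuming a connected `client`:

from distributed.diagnostics.plugin import WorkerPlugin

class CounterPlugin(WorkerPlugin):
    def setup(self, worker):
        worker.my_counter = 0  # demo-only attribute, not a real Worker field

client.register_plugin(CounterPlugin(), name="counter", idempotent=True)
# ------------------------------------------------------------------------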
@@ -5125,9 +4967,7 @@ def _(self, plugin: SchedulerPlugin, name: str, idempotent: bool): ) @_register_plugin.register - def _( - self, plugin: NannyPlugin, name: str, idempotent: bool - ) -> dict[str, OKMessage]: + def _(self, plugin: NannyPlugin, name: str, idempotent: bool) -> dict[str, OKMessage]: return self.sync( self._register_nanny_plugin, plugin=plugin, @@ -5144,9 +4984,7 @@ def _(self, plugin: WorkerPlugin, name: str, idempotent: bool): idempotent=idempotent, ) - async def _register_scheduler_plugin( - self, plugin: SchedulerPlugin, name: str, idempotent: bool - ): + async def _register_scheduler_plugin(self, plugin: SchedulerPlugin, name: str, idempotent: bool): return await self.scheduler.register_scheduler_plugin( plugin=dumps(plugin), name=name, @@ -5178,8 +5016,7 @@ def register_scheduler_plugin( Do not re-register if a plugin of the given name already exists. """ warnings.warn( - "`Client.register_scheduler_plugin` has been deprecated; " - "please `Client.register_plugin` instead", + "`Client.register_scheduler_plugin` has been deprecated; please `Client.register_plugin` instead", DeprecationWarning, stacklevel=2, ) @@ -5243,32 +5080,20 @@ def register_worker_callbacks(self, setup=None): """ return self.register_plugin(_WorkerSetupPlugin(setup)) - async def _register_worker_plugin( - self, plugin: WorkerPlugin, name: str, idempotent: bool - ) -> dict[str, OKMessage]: - responses = await self.scheduler.register_worker_plugin( - plugin=dumps(plugin), name=name, idempotent=idempotent - ) + async def _register_worker_plugin(self, plugin: WorkerPlugin, name: str, idempotent: bool) -> dict[str, OKMessage]: + responses = await self.scheduler.register_worker_plugin(plugin=dumps(plugin), name=name, idempotent=idempotent) for response in responses.values(): if response["status"] == "error": - _, exc, tb = clean_exception( - response["exception"], response["traceback"] - ) + _, exc, tb = clean_exception(response["exception"], response["traceback"]) assert exc raise exc.with_traceback(tb) return cast(dict[str, OKMessage], responses) - async def _register_nanny_plugin( - self, plugin: NannyPlugin, name: str, idempotent: bool - ) -> dict[str, OKMessage]: - responses = await self.scheduler.register_nanny_plugin( - plugin=dumps(plugin), name=name, idempotent=idempotent - ) + async def _register_nanny_plugin(self, plugin: NannyPlugin, name: str, idempotent: bool) -> dict[str, OKMessage]: + responses = await self.scheduler.register_nanny_plugin(plugin=dumps(plugin), name=name, idempotent=idempotent) for response in responses.values(): if response["status"] == "error": - _, exc, tb = clean_exception( - response["exception"], response["traceback"] - ) + _, exc, tb = clean_exception(response["exception"], response["traceback"]) assert exc raise exc.with_traceback(tb) return cast(dict[str, OKMessage], responses) @@ -5348,8 +5173,7 @@ def register_worker_plugin( unregister_worker_plugin """ warnings.warn( - "`Client.register_worker_plugin` has been deprecated; " - "please use `Client.register_plugin` instead", + "`Client.register_worker_plugin` has been deprecated; please use `Client.register_plugin` instead", DeprecationWarning, stacklevel=2, ) @@ -5593,9 +5417,7 @@ class : logging.StreamHandler # removed and torn down (see distributed.worker.Worker.plugin_add()), so # this is effectively idempotent, i.e., forwarding the same logger twice # won't cause every LogRecord to be forwarded twice - return self.register_plugin( - ForwardLoggingPlugin(logger_name, level, topic), plugin_name - ) + return 
self.register_plugin(ForwardLoggingPlugin(logger_name, level, topic), plugin_name) def unforward_logging(self, logger_name=None): """ @@ -5654,9 +5476,7 @@ async def _wait(fs, timeout=None, return_when=ALL_COMPLETED): elif return_when == FIRST_COMPLETED: future = distributed.utils.Any({f._state.wait() for f in fs}) else: - raise NotImplementedError( - "Only return_when='ALL_COMPLETED' and 'FIRST_COMPLETED' are supported" - ) + raise NotImplementedError("Only return_when='ALL_COMPLETED' and 'FIRST_COMPLETED' are supported") if timeout is not None: future = wait_for(future, timeout) @@ -5674,10 +5494,7 @@ async def _wait(fs, timeout=None, return_when=ALL_COMPLETED): assert isinstance(exception, FutureCancelledError) cancelled_errors[exception.reason].append(exception) if cancelled_errors: - groups = [ - CancelledFuturesGroup(errors=errors, reason=reason) - for reason, errors in cancelled_errors.items() - ] + groups = [CancelledFuturesGroup(errors=errors, reason=reason) for reason, errors in cancelled_errors.items()] raise FuturesCancelledError(groups) return DoneAndNotDoneFutures(done, not_done) @@ -5710,9 +5527,7 @@ async def _as_completed(fs, queue): fs = futures_of(fs) groups = groupby(lambda f: f.key, fs) firsts = [v[0] for v in groups.values()] - wait_iterator = gen.WaitIterator( - *map(asyncio.ensure_future, [f._state.wait() for f in firsts]) - ) + wait_iterator = gen.WaitIterator(*map(asyncio.ensure_future, [f._state.wait() for f in firsts])) while not wait_iterator.done(): await wait_iterator.next() @@ -5892,9 +5707,7 @@ def count(self): return len(self.futures) + len(self.queue.queue) def __repr__(self): - return ( - f"" - ) + return f"" def __iter__(self): return self @@ -6128,8 +5941,7 @@ def futures_of(o, client=None): cancelled_errors[exception.reason].append(exception) if cancelled_errors: groups = [ - CancelledFuturesGroup(errors=errors, reason=reason) - for reason, errors in cancelled_errors.items() + CancelledFuturesGroup(errors=errors, reason=reason) for reason, errors in cancelled_errors.items() ] raise FuturesCancelledError(groups) return futures[::-1] @@ -6231,9 +6043,7 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): - L = self.client.get_task_stream( - start=self.start, plot=self._plot, filename=self._filename - ) + L = self.client.get_task_stream(start=self.start, plot=self._plot, filename=self._filename) if self._plot: L, self.figure = L self.data.extend(L) @@ -6242,9 +6052,7 @@ async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_value, traceback): - L = await self.client.get_task_stream( - start=self.start, plot=self._plot, filename=self._filename - ) + L = await self.client.get_task_stream(start=self.start, plot=self._plot, filename=self._filename) if self._plot: L, self.figure = L self.data.extend(L) @@ -6281,9 +6089,7 @@ class performance_report: ... 
x.compute() """ - def __init__( - self, filename="dask-report.html", stacklevel=1, mode=None, storage_options=None - ): + def __init__(self, filename="dask-report.html", stacklevel=1, mode=None, storage_options=None): self.filename = filename # stacklevel 0 or less - shows dask internals which likely isn't helpful self._stacklevel = stacklevel if stacklevel > 0 else 1 @@ -6292,9 +6098,7 @@ def __init__( async def __aenter__(self): self.start = time() - self.last_count = await get_client().run_on_scheduler( - lambda dask_scheduler: dask_scheduler.monitor.count - ) + self.last_count = await get_client().run_on_scheduler(lambda dask_scheduler: dask_scheduler.monitor.count) await get_client().get_task_stream(start=0, stop=0) # ensure plugin async def __aexit__(self, exc_type, exc_value, traceback, code=None): @@ -6307,9 +6111,7 @@ async def __aexit__(self, exc_type, exc_value, traceback, code=None): data = await client.scheduler.performance_report( start=self.start, last_count=self.last_count, code=code, mode=self.mode ) - with fsspec.open( - self.filename, mode="w", compression="infer", **self.storage_options - ) as f: + with fsspec.open(self.filename, mode="w", compression="infer", **self.storage_options) as f: f.write(data) def __enter__(self): From 67d2198e027d01dfefc33721b58fde6eba9e6611 Mon Sep 17 00:00:00 2001 From: nadzhou Date: Sun, 1 Feb 2026 15:16:03 -0800 Subject: [PATCH 2/3] Update client.py --- distributed/client.py | 380 ++++++++++++++++++++++++++++++++---------- 1 file changed, 288 insertions(+), 92 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 658f48394b..5e3ef5d00e 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -135,7 +135,9 @@ logger = logging.getLogger(__name__) -_global_clients: weakref.WeakValueDictionary[int, Client] = weakref.WeakValueDictionary() +_global_clients: weakref.WeakValueDictionary[int, Client] = ( + weakref.WeakValueDictionary() +) _global_client_index = [0] _current_client: ContextVar[Client | None] = ContextVar("_current_client", default=None) @@ -171,12 +173,16 @@ class FuturesCancelledError(CancelledError): error_groups: list[CancelledFuturesGroup] def __init__(self, error_groups: list[CancelledFuturesGroup]): - self.error_groups = sorted(error_groups, key=lambda group: len(group.errors), reverse=True) + self.error_groups = sorted( + error_groups, key=lambda group: len(group.errors), reverse=True + ) def __str__(self): count = sum(map(lambda group: len(group.errors), self.error_groups)) result = f"{count} Future{'s' if count > 1 else ''} cancelled:" - return "\n".join([result, "Reasons:"] + [str(group) for group in self.error_groups]) + return "\n".join( + [result, "Reasons:"] + [str(group) for group in self.error_groups] + ) class CancelledFuturesGroup: @@ -352,7 +358,9 @@ def executor(self): return self.client @property - def status(self) -> Literal["pending", "cancelled", "finished", "lost", "error"] | None: + def status( + self, + ) -> Literal["pending", "cancelled", "finished", "lost", "error"] | None: """Returns the status Returns @@ -477,7 +485,9 @@ def add_done_callback(self, fn): cls = Future if cls._cb_executor is None or cls._cb_executor_pid != os.getpid(): try: - cls._cb_executor = ThreadPoolExecutor(1, thread_name_prefix="Dask-Callback-Thread") + cls._cb_executor = ThreadPoolExecutor( + 1, thread_name_prefix="Dask-Callback-Thread" + ) except TypeError: cls._cb_executor = ThreadPoolExecutor(1) cls._cb_executor_pid = os.getpid() @@ -489,7 +499,9 @@ def execute_callback(fut): 
logger.exception("Error in callback %s of %s:", fn, fut) raise - self.client.loop.add_callback(done_callback, self, partial(cls._cb_executor.submit, execute_callback)) + self.client.loop.add_callback( + done_callback, self, partial(cls._cb_executor.submit, execute_callback) + ) def cancel(self, reason=None, msg=None, **kwargs): """Cancel the request to run this future @@ -607,7 +619,9 @@ def __str__(self): def __repr__(self): if self.type: - return f"" + return ( + f"" + ) else: return f"" @@ -640,7 +654,9 @@ def __init__(self, key: str): self._event = None self.key = key self.exception = None - self.status: Literal["pending", "cancelled", "finished", "lost", "error"] = "pending" + self.status: Literal["pending", "cancelled", "finished", "lost", "error"] = ( + "pending" + ) self.traceback = None self.type = None @@ -763,7 +779,9 @@ def _handle_print(event): if not isinstance(args, tuple): # worker.print() will always send us a tuple of args, even if it's an # empty tuple. - raise TypeError(f"_handle_print: client received non-tuple print args: {args!r}") + raise TypeError( + f"_handle_print: client received non-tuple print args: {args!r}" + ) file = msg.get("file") if file == 1: @@ -771,9 +789,13 @@ def _handle_print(event): elif file == 2: file = sys.stderr elif file is not None: - raise TypeError(f"_handle_print: client received unsupported file kwarg: {file!r}") + raise TypeError( + f"_handle_print: client received unsupported file kwarg: {file!r}" + ) - print(*args, sep=msg.get("sep"), end=msg.get("end"), file=file, flush=msg.get("flush")) + print( + *args, sep=msg.get("sep"), end=msg.get("end"), file=file, flush=msg.get("flush") + ) def _handle_warn(event): @@ -787,7 +809,9 @@ def _handle_warn(event): if "message" not in msg: # TypeError makes sense here because it's analogous to calling a # function without a required positional argument - raise TypeError('_handle_warn: client received a warn event missing the required "message" argument.') + raise TypeError( + '_handle_warn: client received a warn event missing the required "message" argument.' + ) if "category" in msg: category = pickle.loads(msg["category"]) else: @@ -824,7 +848,11 @@ class VersionsDict(TypedDict): def _is_nested(iterable): for item in iterable: - if isinstance(item, Iterable) and not isinstance(item, str) and not isinstance(item, bytes): + if ( + isinstance(item, Iterable) + and not isinstance(item, str) + and not isinstance(item, bytes) + ): return True return False @@ -873,7 +901,10 @@ def keys(self) -> Iterable[Key]: else: uid = str(uuid.uuid4()) keys = ( - [f"{self.key}-{uid}-{i}" for i in range(min(map(len, self.iterables)))] + [ + f"{self.key}-{uid}-{i}" + for i in range(min(map(len, self.iterables))) + ] if self.iterables else [] ) @@ -1050,7 +1081,11 @@ def __init__( self._handle_report_task = None if name is None: name = dask.config.get("client-name", None) - self.id = type(self).__name__ + ("-" + name + "-" if name else "-") + str(uuid.uuid1(clock_seq=os.getpid())) + self.id = ( + type(self).__name__ + + ("-" + name + "-" if name else "-") + + str(uuid.uuid1(clock_seq=os.getpid())) + ) self.generation = 0 self.status = "newly-created" self._pending_msg_buffer = [] @@ -1093,13 +1128,17 @@ def __init__( self.cluster = address status = self.cluster.status if status in (Status.closed, Status.closing): - raise RuntimeError(f"Trying to connect to an already closed or closing Cluster {self.cluster}.") + raise RuntimeError( + f"Trying to connect to an already closed or closing Cluster {self.cluster}." 
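# ---- Illustrative sketch (editor's example, not part of this patch) ----
# The RuntimeError below rejects attaching a client to a cluster that has
# already begun shutting down:
#
#     from distributed import Client, LocalCluster
#
#     cluster = LocalCluster(processes=False)
#     cluster.close()
#     Client(cluster)   # RuntimeError: already closed or closing Cluster
# ------------------------------------------------------------------------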
+ ) with suppress(AttributeError): loop = address.loop if security is None: security = getattr(self.cluster, "security", None) elif address is not None and not isinstance(address, str): - raise TypeError(f"Scheduler address must be a string or a Cluster instance, got {type(address)}") + raise TypeError( + f"Scheduler address must be a string or a Cluster instance, got {type(address)}" + ) # If connecting to an address and no explicit security is configured, attempt # to load security credentials with a security loader (if configured). @@ -1142,7 +1181,9 @@ def __init__( self._periodic_callbacks["scheduler-info"] = PeriodicCallback( self._update_scheduler_info, scheduler_info_interval * 1000 ) - self._periodic_callbacks["heartbeat"] = PeriodicCallback(self._heartbeat, heartbeat_interval * 1000) + self._periodic_callbacks["heartbeat"] = PeriodicCallback( + self._heartbeat, heartbeat_interval * 1000 + ) self._start_arg = address self._set_as_default = set_as_default @@ -1178,7 +1219,9 @@ def __init__( server=self, ) - self.extensions = {name: extension(self) for name, extension in extensions.items()} + self.extensions = { + name: extension(self) for name, extension in extensions.items() + } preload = dask.config.get("distributed.client.preload") preload_argv = dask.config.get("distributed.client.preload-argv") @@ -1193,12 +1236,16 @@ def __init__( @property def io_loop(self) -> IOLoop | None: - warnings.warn("The io_loop property is deprecated", DeprecationWarning, stacklevel=2) + warnings.warn( + "The io_loop property is deprecated", DeprecationWarning, stacklevel=2 + ) return self.loop @io_loop.setter def io_loop(self, value: IOLoop) -> None: - warnings.warn("The io_loop property is deprecated", DeprecationWarning, stacklevel=2) + warnings.warn( + "The io_loop property is deprecated", DeprecationWarning, stacklevel=2 + ) self.loop = value @property @@ -1214,7 +1261,9 @@ def loop(self) -> IOLoop | None: @loop.setter def loop(self, value: IOLoop) -> None: - warnings.warn("setting the loop property is deprecated", DeprecationWarning, stacklevel=2) + warnings.warn( + "setting the loop property is deprecated", DeprecationWarning, stacklevel=2 + ) self.__loop = value @contextmanager @@ -1303,10 +1352,16 @@ def dashboard_link(self): def _get_scheduler_info(self, n_workers): from distributed.scheduler import Scheduler - if self.cluster and hasattr(self.cluster, "scheduler") and isinstance(self.cluster.scheduler, Scheduler): + if ( + self.cluster + and hasattr(self.cluster, "scheduler") + and isinstance(self.cluster.scheduler, Scheduler) + ): info = self.cluster.scheduler.identity(n_workers=n_workers) scheduler = self.cluster.scheduler - elif self._loop_runner.is_started() and self.scheduler and not self.asynchronous: + elif ( + self._loop_runner.is_started() and self.scheduler and not self.asynchronous + ): info = sync(self.loop, self.scheduler.identity, n_workers=n_workers) scheduler = self.scheduler else: @@ -1395,7 +1450,9 @@ def _send_to_scheduler(self, msg): if self.status in ("running", "closing", "connecting", "newly-created"): self.loop.add_callback(self._send_to_scheduler_safe, msg) else: - raise ClosedClientError(f"Client is {self.status}. Can't send {msg['op']} message.") + raise ClosedClientError( + f"Client is {self.status}. Can't send {msg['op']} message." 
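# ---- Illustrative sketch (editor's example, not part of this patch) ----
# Messages are forwarded only in the four live states checked above; in
# any other status, operations that must message the scheduler raise
# ClosedClientError:
#
#     from distributed import Client
#
#     client = Client(processes=False)
#     client.close()
#     client.submit(sum, [1, 2])   # fails: client is closed
# ------------------------------------------------------------------------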
+ ) async def _start(self, timeout=no_default, **kwargs): self.status = "connecting" @@ -1472,7 +1529,9 @@ async def _reconnect(self): for st in self.futures.values(): st.cancel( reason="scheduler-connection-lost", - msg=("Client lost the connection to the scheduler. Please check your connection and re-run your work."), + msg=( + "Client lost the connection to the scheduler. Please check your connection and re-run your work." + ), ) self.futures.clear() @@ -1509,7 +1568,9 @@ async def _ensure_connected(self, timeout=None): self._connecting_to_scheduler = True try: - comm = await connect(self.scheduler.address, timeout=timeout, **self.connection_args) + comm = await connect( + self.scheduler.address, timeout=timeout, **self.connection_args + ) comm.name = "Client->Scheduler" if timeout is not None: await wait_for(self._update_scheduler_info(), timeout) @@ -1557,11 +1618,15 @@ async def _update_scheduler_info(self, n_workers=5): if self.status not in ("running", "connecting") or self.scheduler is None: return try: - self._scheduler_identity = SchedulerInfo(await self.scheduler.identity(n_workers=n_workers)) + self._scheduler_identity = SchedulerInfo( + await self.scheduler.identity(n_workers=n_workers) + ) except OSError: logger.debug("Not able to query scheduler for identity") - async def _wait_for_workers(self, n_workers: int, timeout: float | None = None) -> None: + async def _wait_for_workers( + self, n_workers: int, timeout: float | None = None + ) -> None: info = await self.scheduler.identity(n_workers=-1) self._scheduler_identity = SchedulerInfo(info) if timeout: @@ -1570,7 +1635,13 @@ async def _wait_for_workers(self, n_workers: int, timeout: float | None = None) deadline = None def running_workers(info): - return len([ws for ws in info["workers"].values() if ws["status"] == Status.running.name]) + return len( + [ + ws + for ws in info["workers"].values() + if ws["status"] == Status.running.name + ] + ) while running_workers(info) < n_workers: if deadline and time() > deadline: @@ -1592,7 +1663,9 @@ def wait_for_workers(self, n_workers: int, timeout: float | None = None) -> None ``dask.distributed.TimeoutError`` """ if not isinstance(n_workers, int) or n_workers < 1: - raise ValueError(f"`n_workers` must be a positive integer. Instead got {n_workers}.") + raise ValueError( + f"`n_workers` must be a positive integer. Instead got {n_workers}." 
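# ---- Illustrative sketch (editor's example, not part of this patch) ----
# `n_workers` must be a positive integer; `timeout` is optional and, per
# the docstring above, surfaces as dask.distributed.TimeoutError on
# expiry. Typical usage, assuming a scalable cluster:
#
#     client.cluster.scale(4)
#     client.wait_for_workers(n_workers=4, timeout=60)
#     client.wait_for_workers(0)   # raises this ValueError
# ------------------------------------------------------------------------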
+ ) if self.cluster and hasattr(self.cluster, "wait_for_workers"): return self.cluster.wait_for_workers(n_workers, timeout) @@ -1681,7 +1754,9 @@ def _release_key(self, key): if st is not None: st.cancel() if self.status != "closed": - self._send_to_scheduler({"op": "client-releases-keys", "keys": [key], "client": self.id}) + self._send_to_scheduler( + {"op": "client-releases-keys", "keys": [key], "client": self.id} + ) @log_errors async def _handle_report(self): @@ -1797,7 +1872,9 @@ async def _wait_for_handle_report_task(self, fast=False): handle_report_task = self._handle_report_task # Give the scheduler 'stream-closed' message 100ms to come through # This makes the shutdown slightly smoother and quieter - should_wait = handle_report_task is not None and handle_report_task is not current_task + should_wait = ( + handle_report_task is not None and handle_report_task is not current_task + ) if should_wait: with suppress(asyncio.CancelledError, TimeoutError): await wait_for(asyncio.shield(handle_report_task), 0.1) @@ -1839,11 +1916,19 @@ async def _close(self, fast: bool = False) -> None: if self.get == dask.config.get("get", None): del dask.config.config["get"] - if self.scheduler_comm and self.scheduler_comm.comm and not self.scheduler_comm.comm.closed(): + if ( + self.scheduler_comm + and self.scheduler_comm.comm + and not self.scheduler_comm.comm.closed() + ): self._send_to_scheduler({"op": "close-client"}) self._send_to_scheduler({"op": "close-stream"}) async with self._wait_for_handle_report_task(fast=fast): - if self.scheduler_comm and self.scheduler_comm.comm and not self.scheduler_comm.comm.closed(): + if ( + self.scheduler_comm + and self.scheduler_comm.comm + and not self.scheduler_comm.comm.closed() + ): await self.scheduler_comm.close() for key in list(self.futures): @@ -2201,14 +2286,18 @@ def map( if not callable(func): raise TypeError("First input to map must be a callable function") - if all(isinstance(it, pyQueue) for it in iterables) or all(isinstance(i, Iterator) for i in iterables): + if all(isinstance(it, pyQueue) for it in iterables) or all( + isinstance(i, Iterator) for i in iterables + ): raise TypeError( "Dask no longer supports mapping over Iterators or Queues." 
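# ---- Illustrative sketch (editor's example, not part of this patch) ----
# When batch_size is set, the branch just below partitions each iterable
# client-side with tlz.partition_all before submission:
#
#     from tlz import partition_all
#
#     list(partition_all(3, range(8)))
#     # -> [(0, 1, 2), (3, 4, 5), (6, 7)]
# ------------------------------------------------------------------------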
"Consider using a normal for loop and Client.submit" ) total_length = sum(len(x) for x in iterables) if batch_size and batch_size > 1 and total_length > batch_size: - batches = list(zip(*(partition_all(batch_size, iterable) for iterable in iterables))) + batches = list( + zip(*(partition_all(batch_size, iterable) for iterable in iterables)) + ) keys: list[list[Any]] | list[Any] if isinstance(key, list): keys = [list(element) for element in partition_all(batch_size, key)] @@ -2344,7 +2433,9 @@ async def wait(k): keys = [k for k in keys if k not in bad_keys and k not in data] if local_worker: # look inside local worker - data.update({k: local_worker.data[k] for k in keys if k in local_worker.data}) + data.update( + {k: local_worker.data[k] for k in keys if k in local_worker.data} + ) keys = [k for k in keys if k not in data] # We now do an actual remote communication with workers or scheduler @@ -2353,7 +2444,9 @@ async def wait(k): response = await self._gather_future else: # no one waiting, go ahead self._gather_keys = set(keys) - future = asyncio.ensure_future(self._gather_remote(direct, local_worker)) + future = asyncio.ensure_future( + self._gather_remote(direct, local_worker) + ) if self._gather_keys is None: self._gather_future = None else: @@ -2398,10 +2491,14 @@ async def _gather_remote(self, direct: bool, local_worker: bool) -> dict[str, An if direct or local_worker: # gather directly from workers who_has = await retry_operation(self.scheduler.who_has, keys=keys) - data, missing_keys, failed_keys, _ = await gather_from_workers(who_has, rpc=self.rpc) + data, missing_keys, failed_keys, _ = await gather_from_workers( + who_has, rpc=self.rpc + ) response: dict[str, Any] = {"status": "OK", "data": data} if missing_keys or failed_keys: - response = await retry_operation(self.scheduler.gather, keys=missing_keys + failed_keys) + response = await retry_operation( + self.scheduler.gather, keys=missing_keys + failed_keys + ) if response["status"] == "OK": response["data"].update(data) @@ -2550,7 +2647,9 @@ async def _scatter( _, who_has, nbytes = await scatter_to_workers(workers, data2, self.rpc) - await self.scheduler.update_data(who_has=who_has, nbytes=nbytes, client=self.id) + await self.scheduler.update_data( + who_has=who_has, nbytes=nbytes, client=self.id + ) else: await self.scheduler.scatter( data=data2, @@ -2783,7 +2882,9 @@ async def _(): if name: if len(args) == 0: - raise ValueError("If name is provided, expecting call signature like publish_dataset(df, name='ds')") + raise ValueError( + "If name is provided, expecting call signature like publish_dataset(df, name='ds')" + ) # in case this is a singleton, collapse it elif len(args) == 1: args = args[0] @@ -3010,7 +3111,9 @@ async def _run( elif on_error == "return": results[key] = exc elif on_error != "ignore": - raise ValueError(f"on_error must be 'raise', 'return', or 'ignore'; got {on_error!r}") + raise ValueError( + f"on_error must be 'raise', 'return', or 'ignore'; got {on_error!r}" + ) if wait: return results @@ -3110,7 +3213,9 @@ def run( ) @staticmethod - def _get_computation_code(stacklevel: int | None = None, nframes: int = 1) -> tuple[SourceCode, ...]: + def _get_computation_code( + stacklevel: int | None = None, nframes: int = 1 + ) -> tuple[SourceCode, ...]: """Walk up the stack to the user code and extract the code surrounding the compute/submit/persist call. 
All modules encountered which are ignored through the option @@ -3124,18 +3229,28 @@ def _get_computation_code(stacklevel: int | None = None, nframes: int = 1) -> tu if nframes <= 0: return () - ignore_modules = dask.config.get("distributed.diagnostics.computations.ignore-modules") + ignore_modules = dask.config.get( + "distributed.diagnostics.computations.ignore-modules" + ) if not isinstance(ignore_modules, list): - raise TypeError(f"Ignored modules must be a list. Instead got ({type(ignore_modules)}, {ignore_modules})") - ignore_files = dask.config.get("distributed.diagnostics.computations.ignore-files") + raise TypeError( + f"Ignored modules must be a list. Instead got ({type(ignore_modules)}, {ignore_modules})" + ) + ignore_files = dask.config.get( + "distributed.diagnostics.computations.ignore-files" + ) if not isinstance(ignore_files, list): - raise TypeError(f"Ignored files must be a list. Instead got ({type(ignore_files)}, {ignore_files})") + raise TypeError( + f"Ignored files must be a list. Instead got ({type(ignore_files)}, {ignore_files})" + ) mod_pattern: re.Pattern | None = None fname_pattern: re.Pattern | None = None if stacklevel is None: if ignore_modules: - mod_pattern = re.compile("|".join([f"(?:{mod})" for mod in ignore_modules])) + mod_pattern = re.compile( + "|".join([f"(?:{mod})" for mod in ignore_modules]) + ) if ignore_files: # Given ignore-files = [foo], match: # /path/to/foo @@ -3157,7 +3272,9 @@ def _get_computation_code(stacklevel: int | None = None, nframes: int = 1) -> tu stacklevel = stacklevel if stacklevel > 0 else 1 code: list[SourceCode] = [] - for i, (fr, lineno_frame) in enumerate(traceback.walk_stack(sys._getframe().f_back), 1): + for i, (fr, lineno_frame) in enumerate( + traceback.walk_stack(sys._getframe().f_back), 1 + ): if len(code) >= nframes: break if stacklevel is not None and i != stacklevel: @@ -3266,7 +3383,9 @@ def _graph_to_futures( expr_ser = Serialized(*serialize(to_serialize(expr), on_error="raise")) pickled_size = sum(map(nbytes, [expr_ser.header] + expr_ser.frames)) - if pickled_size > parse_bytes(dask.config.get("distributed.admin.large-graph-warning-threshold")): + if pickled_size > parse_bytes( + dask.config.get("distributed.admin.large-graph-warning-threshold") + ): warnings.warn( f"Sending large graph of size {format_bytes(pickled_size)}.\n" "This may cause some slowdown.\n" @@ -3555,11 +3674,18 @@ def compute( if traverse: collections = tuple( - (dask.delayed(a) if isinstance(a, (list, set, tuple, dict, Iterator)) else a) for a in collections + ( + dask.delayed(a) + if isinstance(a, (list, set, tuple, dict, Iterator)) + else a + ) + for a in collections ) variables = [a for a in collections if dask.is_dask_collection(a)] - metadata = SpanMetadata(collections=[get_collections_metadata(v) for v in variables]) + metadata = SpanMetadata( + collections=[get_collections_metadata(v) for v in variables] + ) futures_dict = {} if variables: expr = collections_to_expr(variables, optimize_graph, **kwargs) @@ -3670,7 +3796,9 @@ def persist( collections = [collections] assert all(map(dask.is_dask_collection, collections)) - metadata = SpanMetadata(collections=[get_collections_metadata(v) for v in collections]) + metadata = SpanMetadata( + collections=[get_collections_metadata(v) for v in collections] + ) expr = collections_to_expr(collections, optimize_graph) expr2 = expr.optimize() @@ -3693,12 +3821,17 @@ def persist( postpersists = [c.__dask_postpersist__() for c in collections] assert len(postpersists) == len(keys) - result = [func({k: 
futures[k] for k in flatten(ks)}, *args) for (func, args), ks in zip(postpersists, keys)] + result = [ + func({k: futures[k] for k in flatten(ks)}, *args) + for (func, args), ks in zip(postpersists, keys) + ] if singleton: return result[0] return result - async def _restart(self, timeout: str | int | float | NoDefault, wait_for_workers: bool) -> None: + async def _restart( + self, timeout: str | int | float | NoDefault, wait_for_workers: bool + ) -> None: if timeout is no_default: timeout = self._timeout * 4 timeout = parse_timedelta(cast("str|int|float", timeout), "s") @@ -3743,7 +3876,9 @@ def restart( Scheduler.restart Client.restart_workers """ - return self.sync(self._restart, timeout=timeout, wait_for_workers=wait_for_workers) + return self.sync( + self._restart, timeout=timeout, wait_for_workers=wait_for_workers + ) async def _restart_workers( self, @@ -3759,11 +3894,13 @@ async def _restart_workers( name_to_addr = {meta["name"]: addr for addr, meta in info["workers"].items()} worker_addrs = [name_to_addr.get(w, w) for w in workers] - out: dict[str, Literal["OK", "removed", "timed out"]] = await self.scheduler.restart_workers( - workers=worker_addrs, - timeout=timeout, - on_error="raise" if raise_for_error else "return", - stimulus_id=f"client-restart-workers-{time()}", + out: dict[str, Literal["OK", "removed", "timed out"]] = ( + await self.scheduler.restart_workers( + workers=worker_addrs, + timeout=timeout, + on_error="raise" if raise_for_error else "return", + stimulus_id=f"client-restart-workers-{time()}", + ) ) # Map keys back to original `workers` input names/addresses out = {w: out[w_addr] for w, w_addr in zip(workers, worker_addrs)} @@ -3902,7 +4039,9 @@ def upload_file(self, filename, load: bool = True): async def _(): results = await asyncio.gather( - self.register_plugin(SchedulerUploadFile(filename, load=load), name=name), + self.register_plugin( + SchedulerUploadFile(filename, load=load), name=name + ), # FIXME: Make scheduler plugin responsible for (de)registering worker plugin self.register_plugin(UploadFile(filename, load=load), name=name), ) @@ -3950,7 +4089,9 @@ async def _replicate(self, futures, n=None, workers=None, branching_factor=2): futures = self.futures_of(futures) await _wait(futures) keys = {f.key for f in futures} - await self.scheduler.replicate(keys=list(keys), n=n, workers=workers, branching_factor=branching_factor) + await self.scheduler.replicate( + keys=list(keys), n=n, workers=workers, branching_factor=branching_factor + ) def replicate(self, futures, n=None, workers=None, branching_factor=2, **kwargs): """Set replication of futures within network @@ -4028,7 +4169,9 @@ def nthreads(self, workers=None, **kwargs): Client.who_has Client.has_what """ - if isinstance(workers, tuple) and all(isinstance(i, (str, tuple)) for i in workers): + if isinstance(workers, tuple) and all( + isinstance(i, (str, tuple)) for i in workers + ): workers = list(workers) if workers is not None and not isinstance(workers, (tuple, list, set)): workers = [workers] @@ -4103,7 +4246,9 @@ def has_what(self, workers=None, **kwargs): Client.nthreads Client.processing """ - if isinstance(workers, tuple) and all(isinstance(i, (str, tuple)) for i in workers): + if isinstance(workers, tuple) and all( + isinstance(i, (str, tuple)) for i in workers + ): workers = list(workers) if workers is not None and not isinstance(workers, (tuple, list, set)): workers = [workers] @@ -4135,7 +4280,9 @@ def processing(self, workers=None): Client.has_what Client.nthreads """ - if isinstance(workers, 
tuple) and all(isinstance(i, (str, tuple)) for i in workers): + if isinstance(workers, tuple) and all( + isinstance(i, (str, tuple)) for i in workers + ): workers = list(workers) if workers is not None and not isinstance(workers, (tuple, list, set)): workers = [workers] @@ -4570,7 +4717,9 @@ def log_event(self, topic: str | Collection[str], msg: Any): >>> client.log_event("current-time", time()) """ if not _is_dumpable(msg): - raise TypeError(f"Message must be msgpack serializable. Got {type(msg)=} instead.") + raise TypeError( + f"Message must be msgpack serializable. Got {type(msg)=} instead." + ) return self.sync(self.scheduler.log_event, topic=topic, msg=msg) def get_events(self, topic: str | None = None): @@ -4639,7 +4788,9 @@ def unsubscribe_topic(self, topic): else: raise ValueError(f"No event handler known for topic {topic}.") - def retire_workers(self, workers: list[str] | None = None, close_workers: bool = True, **kwargs): + def retire_workers( + self, workers: list[str] | None = None, close_workers: bool = True, **kwargs + ): """Retire certain workers on the scheduler See :meth:`distributed.Scheduler.retire_workers` for the full docstring. @@ -4740,7 +4891,9 @@ def get_versions( """ return self.sync(self._get_versions, check=check, packages=packages or []) - async def _get_versions(self, check: bool = False, packages: Sequence[str] | None = None) -> VersionsDict: + async def _get_versions( + self, check: bool = False, packages: Sequence[str] | None = None + ) -> VersionsDict: packages = packages or [] client = version_module.get_versions(packages=packages) scheduler = await self.scheduler.versions(packages=packages) @@ -4787,7 +4940,9 @@ async def _story(self, *keys_or_stimuli: str, on_error="raise"): assert on_error in ("raise", "ignore") try: - flat_stories = await self.scheduler.get_story(keys_or_stimuli=keys_or_stimuli) + flat_stories = await self.scheduler.get_story( + keys_or_stimuli=keys_or_stimuli + ) flat_stories = [("scheduler", *msg) for msg in flat_stories] except Exception: if on_error == "raise": @@ -4962,7 +5117,10 @@ def _register_plugin( ): if isinstance(plugin, type): raise TypeError("Please provide an instance of a plugin, not a type.") - if any("dask.distributed.diagnostics.plugin" in str(c) for c in plugin.__class__.__bases__): + if any( + "dask.distributed.diagnostics.plugin" in str(c) + for c in plugin.__class__.__bases__ + ): raise TypeError( "Importing plugin base classes from `dask.distributed.diagnostics.plugin` is not supported. " "Please import directly from `distributed.diagnostics.plugin` instead." 
@@ -4982,7 +5140,9 @@ def _(self, plugin: SchedulerPlugin, name: str, idempotent: bool): ) @_register_plugin.register - def _(self, plugin: NannyPlugin, name: str, idempotent: bool) -> dict[str, OKMessage]: + def _( + self, plugin: NannyPlugin, name: str, idempotent: bool + ) -> dict[str, OKMessage]: return self.sync( self._register_nanny_plugin, plugin=plugin, @@ -4999,7 +5159,9 @@ def _(self, plugin: WorkerPlugin, name: str, idempotent: bool): idempotent=idempotent, ) - async def _register_scheduler_plugin(self, plugin: SchedulerPlugin, name: str, idempotent: bool): + async def _register_scheduler_plugin( + self, plugin: SchedulerPlugin, name: str, idempotent: bool + ): return await self.scheduler.register_scheduler_plugin( plugin=dumps(plugin), name=name, @@ -5095,20 +5257,32 @@ def register_worker_callbacks(self, setup=None): """ return self.register_plugin(_WorkerSetupPlugin(setup)) - async def _register_worker_plugin(self, plugin: WorkerPlugin, name: str, idempotent: bool) -> dict[str, OKMessage]: - responses = await self.scheduler.register_worker_plugin(plugin=dumps(plugin), name=name, idempotent=idempotent) + async def _register_worker_plugin( + self, plugin: WorkerPlugin, name: str, idempotent: bool + ) -> dict[str, OKMessage]: + responses = await self.scheduler.register_worker_plugin( + plugin=dumps(plugin), name=name, idempotent=idempotent + ) for response in responses.values(): if response["status"] == "error": - _, exc, tb = clean_exception(response["exception"], response["traceback"]) + _, exc, tb = clean_exception( + response["exception"], response["traceback"] + ) assert exc raise exc.with_traceback(tb) return cast(dict[str, OKMessage], responses) - async def _register_nanny_plugin(self, plugin: NannyPlugin, name: str, idempotent: bool) -> dict[str, OKMessage]: - responses = await self.scheduler.register_nanny_plugin(plugin=dumps(plugin), name=name, idempotent=idempotent) + async def _register_nanny_plugin( + self, plugin: NannyPlugin, name: str, idempotent: bool + ) -> dict[str, OKMessage]: + responses = await self.scheduler.register_nanny_plugin( + plugin=dumps(plugin), name=name, idempotent=idempotent + ) for response in responses.values(): if response["status"] == "error": - _, exc, tb = clean_exception(response["exception"], response["traceback"]) + _, exc, tb = clean_exception( + response["exception"], response["traceback"] + ) assert exc raise exc.with_traceback(tb) return cast(dict[str, OKMessage], responses) @@ -5432,7 +5606,9 @@ class : logging.StreamHandler # removed and torn down (see distributed.worker.Worker.plugin_add()), so # this is effectively idempotent, i.e., forwarding the same logger twice # won't cause every LogRecord to be forwarded twice - return self.register_plugin(ForwardLoggingPlugin(logger_name, level, topic), plugin_name) + return self.register_plugin( + ForwardLoggingPlugin(logger_name, level, topic), plugin_name + ) def unforward_logging(self, logger_name=None): """ @@ -5491,7 +5667,9 @@ async def _wait(fs, timeout=None, return_when=ALL_COMPLETED): elif return_when == FIRST_COMPLETED: future = distributed.utils.Any({f._state.wait() for f in fs}) else: - raise NotImplementedError("Only return_when='ALL_COMPLETED' and 'FIRST_COMPLETED' are supported") + raise NotImplementedError( + "Only return_when='ALL_COMPLETED' and 'FIRST_COMPLETED' are supported" + ) if timeout is not None: future = wait_for(future, timeout) @@ -5509,7 +5687,10 @@ async def _wait(fs, timeout=None, return_when=ALL_COMPLETED): assert isinstance(exception, 
FutureCancelledError)
             cancelled_errors[exception.reason].append(exception)
     if cancelled_errors:
-        groups = [CancelledFuturesGroup(errors=errors, reason=reason) for reason, errors in cancelled_errors.items()]
+        groups = [
+            CancelledFuturesGroup(errors=errors, reason=reason)
+            for reason, errors in cancelled_errors.items()
+        ]
         raise FuturesCancelledError(groups)
     return DoneAndNotDoneFutures(done, not_done)
@@ -5542,7 +5723,9 @@ async def _as_completed(fs, queue):
     fs = futures_of(fs)
     groups = groupby(lambda f: f.key, fs)
     firsts = [v[0] for v in groups.values()]
-    wait_iterator = gen.WaitIterator(*map(asyncio.ensure_future, [f._state.wait() for f in firsts]))
+    wait_iterator = gen.WaitIterator(
+        *map(asyncio.ensure_future, [f._state.wait() for f in firsts])
+    )
     while not wait_iterator.done():
         await wait_iterator.next()
@@ -5722,7 +5905,9 @@ def count(self):
         return len(self.futures) + len(self.queue.queue)
     def __repr__(self):
-        return f"<as_completed: waiting={len(self.futures)} done={len(self.queue.queue)}>"
+        return (
+            f"<as_completed: waiting={len(self.futures)} done={len(self.queue.queue)}>"
+        )
     def __iter__(self):
         return self
@@ -5956,7 +6141,8 @@ def futures_of(o, client=None):
             cancelled_errors[exception.reason].append(exception)
         if cancelled_errors:
             groups = [
-                CancelledFuturesGroup(errors=errors, reason=reason) for reason, errors in cancelled_errors.items()
+                CancelledFuturesGroup(errors=errors, reason=reason)
+                for reason, errors in cancelled_errors.items()
             ]
             raise FuturesCancelledError(groups)
         return futures[::-1]
@@ -6058,7 +6244,9 @@ def __enter__(self):
         return self
     def __exit__(self, exc_type, exc_value, traceback):
-        L = self.client.get_task_stream(start=self.start, plot=self._plot, filename=self._filename)
+        L = self.client.get_task_stream(
+            start=self.start, plot=self._plot, filename=self._filename
+        )
         if self._plot:
             L, self.figure = L
         self.data.extend(L)
     async def __aenter__(self):
         return self
     async def __aexit__(self, exc_type, exc_value, traceback):
-        L = await self.client.get_task_stream(start=self.start, plot=self._plot, filename=self._filename)
+        L = await self.client.get_task_stream(
+            start=self.start, plot=self._plot, filename=self._filename
+        )
         if self._plot:
             L, self.figure = L
         self.data.extend(L)
@@ -6104,7 +6294,9 @@ class performance_report:
     ...
x.compute() """ - def __init__(self, filename="dask-report.html", stacklevel=1, mode=None, storage_options=None): + def __init__( + self, filename="dask-report.html", stacklevel=1, mode=None, storage_options=None + ): self.filename = filename # stacklevel 0 or less - shows dask internals which likely isn't helpful self._stacklevel = stacklevel if stacklevel > 0 else 1 @@ -6113,7 +6305,9 @@ def __init__(self, filename="dask-report.html", stacklevel=1, mode=None, storage async def __aenter__(self): self.start = time() - self.last_count = await get_client().run_on_scheduler(lambda dask_scheduler: dask_scheduler.monitor.count) + self.last_count = await get_client().run_on_scheduler( + lambda dask_scheduler: dask_scheduler.monitor.count + ) await get_client().get_task_stream(start=0, stop=0) # ensure plugin async def __aexit__(self, exc_type, exc_value, traceback, code=None): @@ -6126,7 +6320,9 @@ async def __aexit__(self, exc_type, exc_value, traceback, code=None): data = await client.scheduler.performance_report( start=self.start, last_count=self.last_count, code=code, mode=self.mode ) - with fsspec.open(self.filename, mode="w", compression="infer", **self.storage_options) as f: + with fsspec.open( + self.filename, mode="w", compression="infer", **self.storage_options + ) as f: f.write(data) def __enter__(self): From 3040922cba48b385142695238780f3240a8fb908 Mon Sep 17 00:00:00 2001 From: nadzhou Date: Sun, 1 Feb 2026 15:34:53 -0800 Subject: [PATCH 3/3] Update --- distributed/client.py | 63 +++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 38 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 5e3ef5d00e..415394ab17 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -358,22 +358,13 @@ def executor(self): return self.client @property - def status( - self, - ) -> Literal["pending", "cancelled", "finished", "lost", "error"] | None: + def status(self): """Returns the status Returns ------- str The status - The status of the future. Possible values: - - "pending": The future is waiting to be computed - - "finished": The future has completed successfully - - "error": The future encountered an error during computation - - "cancelled": The future was cancelled - - "lost": The future's data was lost from memory - - None: The future is not yet bound to a client """ if self._state: return self._state.status @@ -654,9 +645,7 @@ def __init__(self, key: str): self._event = None self.key = key self.exception = None - self.status: Literal["pending", "cancelled", "finished", "lost", "error"] = ( - "pending" - ) + self.status = "pending" self.traceback = None self.type = None @@ -810,7 +799,8 @@ def _handle_warn(event): # TypeError makes sense here because it's analogous to calling a # function without a required positional argument raise TypeError( - '_handle_warn: client received a warn event missing the required "message" argument.' + "_handle_warn: client received a warn event missing the required " + '"message" argument.' ) if "category" in msg: category = pickle.loads(msg["category"]) @@ -1530,7 +1520,8 @@ async def _reconnect(self): st.cancel( reason="scheduler-connection-lost", msg=( - "Client lost the connection to the scheduler. Please check your connection and re-run your work." + "Client lost the connection to the scheduler. " + "Please check your connection and re-run your work." 
), ) self.futures.clear() @@ -1551,7 +1542,8 @@ async def _reconnect(self): else: logger.error( - "Failed to reconnect to scheduler after %.2f seconds, closing client", + "Failed to reconnect to scheduler after %.2f " + "seconds, closing client", self._timeout, ) await self._close() @@ -1706,7 +1698,8 @@ async def __aexit__(self, exc_type, exc_value, traceback): if not e.args[0].endswith(" was created in a different Context"): raise # pragma: nocover warnings.warn( - "It is deprecated to enter and exit the Client context manager from different tasks", + "It is deprecated to enter and exit the Client context " + "manager from different tasks", DeprecationWarning, stacklevel=2, ) @@ -1724,7 +1717,8 @@ def __exit__(self, exc_type, exc_value, traceback): if not e.args[0].endswith(" was created in a different Context"): raise # pragma: nocover warnings.warn( - "It is deprecated to enter and exit the Client context manager from different threads", + "It is deprecated to enter and exit the Client context " + "manager from different threads", DeprecationWarning, stacklevel=2, ) @@ -2883,7 +2877,8 @@ async def _(): if name: if len(args) == 0: raise ValueError( - "If name is provided, expecting call signature like publish_dataset(df, name='ds')" + "If name is provided, expecting call signature like" + " publish_dataset(df, name='ds')" ) # in case this is a singleton, collapse it elif len(args) == 1: @@ -3099,7 +3094,6 @@ async def _run( elif resp["status"] == "error": # Exception raised by the remote function _, exc, tb = clean_exception(**resp) - assert exc is not None exc = exc.with_traceback(tb) else: assert resp["status"] == "OK" @@ -3112,7 +3106,8 @@ async def _run( results[key] = exc elif on_error != "ignore": raise ValueError( - f"on_error must be 'raise', 'return', or 'ignore'; got {on_error!r}" + "on_error must be 'raise', 'return', or 'ignore'; " + f"got {on_error!r}" ) if wait: @@ -3234,14 +3229,16 @@ def _get_computation_code( ) if not isinstance(ignore_modules, list): raise TypeError( - f"Ignored modules must be a list. Instead got ({type(ignore_modules)}, {ignore_modules})" + "Ignored modules must be a list. Instead got " + f"({type(ignore_modules)}, {ignore_modules})" ) ignore_files = dask.config.get( "distributed.diagnostics.computations.ignore-files" ) if not isinstance(ignore_files, list): raise TypeError( - f"Ignored files must be a list. Instead got ({type(ignore_files)}, {ignore_files})" + "Ignored files must be a list. Instead got " + f"({type(ignore_files)}, {ignore_files})" ) mod_pattern: re.Pattern | None = None @@ -3252,20 +3249,8 @@ def _get_computation_code( "|".join([f"(?:{mod})" for mod in ignore_modules]) ) if ignore_files: - # Given ignore-files = [foo], match: - # /path/to/foo - # /path/to/foo.py[c] - # /path/to/foo/bar.py[c] - # \path\to\foo - # \path\to\foo.py[c] - # \path\to\foo\bar.py[c] - # - # Do not match files that have 'foo' as a substring, - # unless the user explicitly states '.*foo.*'. - ignore_files_or = "|".join(mod for mod in ignore_files) fname_pattern = re.compile( - rf".*[\\/]({ignore_files_or})([\\/]|\.pyc?$|$)" - rf"|$" + r".*[\\/](" + "|".join(mod for mod in ignore_files) + r")([\\/]|$)" ) else: # stacklevel 0 or less - shows dask internals which likely isn't helpful @@ -5193,7 +5178,8 @@ def register_scheduler_plugin( Do not re-register if a plugin of the given name already exists. 
""" warnings.warn( - "`Client.register_scheduler_plugin` has been deprecated; please `Client.register_plugin` instead", + "`Client.register_scheduler_plugin` has been deprecated; " + "please `Client.register_plugin` instead", DeprecationWarning, stacklevel=2, ) @@ -5362,7 +5348,8 @@ def register_worker_plugin( unregister_worker_plugin """ warnings.warn( - "`Client.register_worker_plugin` has been deprecated; please use `Client.register_plugin` instead", + "`Client.register_worker_plugin` has been deprecated; " + "please use `Client.register_plugin` instead", DeprecationWarning, stacklevel=2, )