diff --git a/distributed/__init__.py b/distributed/__init__.py
index 3f075b977c..091c14e0eb 100644
--- a/distributed/__init__.py
+++ b/distributed/__init__.py
@@ -146,3 +146,4 @@
     "widgets",
     "worker_client",
 ]
+from distributed.condition import Condition
diff --git a/distributed/condition.py b/distributed/condition.py
new file mode 100644
index 0000000000..6714935df8
--- /dev/null
+++ b/distributed/condition.py
@@ -0,0 +1,238 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+import uuid
+from collections import defaultdict
+from contextlib import suppress
+
+from distributed.lock import Lock
+from distributed.utils import log_errors
+from distributed.worker import get_client
+
+logger = logging.getLogger(__name__)
+
+
+class ConditionExtension:
+    """Scheduler extension for managing distributed Conditions.
+
+    Waiters are keyed by (name, waiter_id) so multiple waiters from the
+    same client can coexist without overwriting each other's events.
+    """
+
+    def __init__(self, scheduler):
+        self.scheduler = scheduler
+        # name -> {waiter_id -> asyncio.Event}
+        self.waiters = defaultdict(dict)
+        self._closed = False
+
+        self.scheduler.handlers.update(
+            {
+                "condition_wait": self.wait,
+                "condition_notify": self.notify,
+                "condition_notify_all": self.notify_all,
+            }
+        )
+        self.scheduler.extensions["conditions"] = self
+
+    async def close(self):
+        """Cancel all pending waiters so scheduler can shut down cleanly."""
+        self._closed = True
+        for name in list(self.waiters):
+            for event in self.waiters[name].values():
+                event.set()
+            del self.waiters[name]
+
+    @log_errors
+    async def wait(self, name=None, waiter_id=None, **kwargs):
+        """Register a waiter and block until notified or closed."""
+        if self._closed:
+            return
+
+        event = asyncio.Event()
+        self.waiters[name][waiter_id] = event
+
+        try:
+            await event.wait()
+        finally:
+            with suppress(KeyError):
+                del self.waiters[name][waiter_id]
+                if not self.waiters[name]:
+                    del self.waiters[name]
+
+    async def notify(self, name=None, n=1, **kwargs):
+        """Wake up to n waiters."""
+        if name not in self.waiters:
+            return
+
+        notified = 0
+        for wid in list(self.waiters[name]):
+            if notified >= n:
+                break
+            event = self.waiters[name].get(wid)
+            if event and not event.is_set():
+                event.set()
+                notified += 1
+
+    async def notify_all(self, name=None, **kwargs):
+        """Wake all waiters."""
+        if name not in self.waiters:
+            return
+
+        for event in self.waiters[name].values():
+            if not event.is_set():
+                event.set()
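+
+
+# Note: the handlers registered above are reached over the scheduler RPC by
+# operation name; a client-side call such as
+# ``client.scheduler.condition_wait(name=..., waiter_id=...)`` (used by
+# Condition.wait below) is dispatched to ConditionExtension.wait.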
+
+
+class Condition:
+    """Distributed Condition Variable
+
+    A distributed version of asyncio.Condition. Allows one or more clients
+    to wait until notified by another client.
+
+    Like asyncio.Condition, this must be used with a lock. The lock is
+    released before waiting and reacquired afterwards.
+
+    Parameters
+    ----------
+    name : str, optional
+        Name of the condition. If not provided, a random name is generated.
+    client : Client, optional
+        Client instance. If not provided, uses the default client.
+    lock : Lock, optional
+        Lock to use with this condition. If not provided, creates a new Lock.
+
+    Examples
+    --------
+    >>> from distributed import Client, Condition
+    >>> client = Client()  # doctest: +SKIP
+    >>> condition = Condition()  # doctest: +SKIP
+
+    >>> async with condition:  # doctest: +SKIP
+    ...     await condition.wait()
+
+    >>> async with condition:  # doctest: +SKIP
+    ...     await condition.notify()  # Wake one waiter
+    """
+
+    def __init__(self, name=None, client=None, lock=None):
+        self._client = client
+        self.name = name or "condition-" + uuid.uuid4().hex
+
+        if lock is None:
+            lock = Lock()
+        elif not isinstance(lock, Lock):
+            raise TypeError(f"lock must be a Lock, not {type(lock)}")
+
+        self._lock = lock
+
+    @property
+    def client(self):
+        if not self._client:
+            try:
+                self._client = get_client()
+            except ValueError:
+                pass
+        return self._client
+
+    def _verify_running(self):
+        if not self.client:
+            raise RuntimeError(
+                f"{type(self)} object not properly initialized. "
+                "Ensure it's created within a Client context."
+            )
+
+    async def __aenter__(self):
+        await self.acquire()
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        await self.release()
+
+    def __repr__(self):
+        return f"<Condition: {self.name}>"
+
+    async def acquire(self, timeout=None):
+        self._verify_running()
+        return await self._lock.acquire(timeout=timeout)
+
+    async def release(self):
+        self._verify_running()
+        return await self._lock.release()
+
+    async def locked(self):
+        return await self._lock.locked()
+
+    async def wait(self, timeout=None):
+        """Wait until notified.
+
+        Releases the underlying lock, waits until notified, then reacquires
+        the lock before returning. Must be called with the lock held.
+
+        Returns True if woken by notify, False on timeout.
+        """
+        self._verify_running()
+        await self.release()
+
+        # Each wait() call gets a unique ID so the scheduler can track
+        # multiple waiters from the same client independently.
+        waiter_id = uuid.uuid4().hex
+
+        try:
+            coro = self.client.scheduler.condition_wait(
+                name=self.name, waiter_id=waiter_id
+            )
+            if timeout is not None:
+                try:
+                    await asyncio.wait_for(coro, timeout=timeout)
+                    return True
+                except asyncio.TimeoutError:
+                    return False
+            else:
+                await coro
+                return True
+        finally:
+            # Always reacquire lock — mirrors asyncio.Condition semantics
+            try:
+                await self.acquire()
+            except asyncio.CancelledError:
+                with suppress(Exception):
+                    await asyncio.shield(self.acquire())
+                raise
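+
+    # wait_for() below re-checks ``predicate`` after every wakeup; when a
+    # timeout is given, the remaining budget is recomputed from a fixed
+    # deadline on each pass, so unrelated notifications cannot extend the
+    # overall timeout.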
+
+    async def wait_for(self, predicate, timeout=None):
+        """Wait until predicate() returns a true value.
+
+        Returns the final value of ``predicate()``, which is falsy only if
+        the timeout expired before the predicate became true.
+        """
+        result = predicate()
+        if result:
+            return result
+
+        if timeout is not None:
+            import time
+
+            deadline = time.monotonic() + timeout
+            while not result:
+                remaining = deadline - time.monotonic()
+                if remaining <= 0:
+                    return predicate()
+                if not await self.wait(timeout=remaining):
+                    return predicate()
+                result = predicate()
+        else:
+            while not result:
+                await self.wait()
+                result = predicate()
+
+        return result
+
+    async def notify(self, n=1):
+        """Wake up n waiters (default: 1)."""
+        self._verify_running()
+        await self.client.scheduler.condition_notify(name=self.name, n=n)
+
+    async def notify_all(self):
+        """Wake up all waiters."""
+        self._verify_running()
+        await self.client.scheduler.condition_notify_all(name=self.name)
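+
+
+# Minimal usage sketch (illustrative only; ``work_ready`` is a placeholder
+# predicate, not part of this module):
+#
+#     async with Client(asynchronous=True) as client:
+#         condition = Condition(name="work-ready")
+#         async with condition:
+#             await condition.wait_for(work_ready)
+#             ...  # the condition's lock is held again at this point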
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index d73da7c71f..897cfbb32f 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -97,6 +97,7 @@
 )
 from distributed.comm.addressing import addresses_from_user_args
 from distributed.compatibility import PeriodicCallback
+from distributed.condition import ConditionExtension
 from distributed.core import (
     ErrorMessage,
     OKMessage,
@@ -195,6 +196,7 @@
     "semaphores": SemaphoreExtension,
     "events": EventExtension,
     "amm": ActiveMemoryManagerExtension,
+    "conditions": ConditionExtension,
     "memory_sampler": MemorySamplerExtension,
     "shuffle": ShuffleSchedulerPlugin,
     "spans": SpansSchedulerExtension,
diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py
new file mode 100644
index 0000000000..f1a7f9e435
--- /dev/null
+++ b/distributed/tests/test_condition.py
@@ -0,0 +1,365 @@
+from __future__ import annotations
+
+import asyncio
+
+import pytest
+
+from distributed import Condition, Lock
+from distributed.metrics import time
+from distributed.utils_test import gen_cluster
+
+
+@gen_cluster(client=True)
+async def test_condition_basic(c, s, a, b):
+    condition = Condition()
+    results = []
+
+    async def waiter():
+        async with condition:
+            results.append("waiting")
+            await condition.wait()
+            results.append("notified")
+
+    task = asyncio.create_task(waiter())
+    await asyncio.sleep(0.1)
+    assert results == ["waiting"]
+
+    async with condition:
+        await condition.notify()
+
+    await task
+    assert results == ["waiting", "notified"]
+
+
+@gen_cluster(client=True)
+async def test_condition_notify_one(c, s, a, b):
+    condition = Condition()
+    results = []
+
+    async def waiter(n):
+        async with condition:
+            await condition.wait()
+            results.append(n)
+
+    tasks = [asyncio.create_task(waiter(i)) for i in range(3)]
+    await asyncio.sleep(0.1)
+
+    async with condition:
+        await condition.notify()
+    await asyncio.sleep(0.1)
+    assert len(results) == 1
+
+    async with condition:
+        await condition.notify()
+    await asyncio.sleep(0.1)
+    assert len(results) == 2
+
+    async with condition:
+        await condition.notify_all()
+    await asyncio.gather(*tasks)
+    assert len(results) == 3
+
+
+@gen_cluster(client=True)
+async def test_condition_notify_n(c, s, a, b):
+    condition = Condition()
+    results = []
+
+    async def waiter(n):
+        async with condition:
+            await condition.wait()
+            results.append(n)
+
+    tasks = [asyncio.create_task(waiter(i)) for i in range(5)]
+    await asyncio.sleep(0.1)
+
+    async with condition:
+        await condition.notify(2)
+    await asyncio.sleep(0.1)
+    assert len(results) == 2
+
+    async with condition:
+        await condition.notify(3)
+    await asyncio.gather(*tasks)
+    assert len(results) == 5
+
+
+@gen_cluster(client=True)
+async def test_condition_notify_all(c, s, a, b):
+    condition = Condition()
+    results = []
+
+    async def waiter(n):
+        async with condition:
+            await condition.wait()
+            results.append(n)
+
+    tasks = [asyncio.create_task(waiter(i)) for i in range(10)]
+    await asyncio.sleep(0.1)
+
+    async with condition:
+        await condition.notify_all()
+
+    await asyncio.gather(*tasks)
+    assert len(results) == 10
+
+
+@gen_cluster(client=True)
+async def test_condition_wait_timeout(c, s, a, b):
+    condition = Condition()
+
+    async with condition:
+        start = time()
+        result = await condition.wait(timeout=0.2)
+        elapsed = time() - start
+
+    assert result is False
+    assert 0.15 < elapsed < 1.0
+
+
+@gen_cluster(client=True)
+async def test_condition_wait_timeout_then_notify(c, s, a, b):
+    condition = Condition()
+
+    async def notifier():
+        await asyncio.sleep(0.1)
+        async with condition:
+            await condition.notify()
+
+    task = asyncio.create_task(notifier())
+
+    async with condition:
+        result = await condition.wait(timeout=2.0)
+
+    await task
+    assert result is True
+
+
+@gen_cluster(client=True)
+async def test_condition_wait_for(c, s, a, b):
+    condition = Condition()
+    state = {"value": 0}
+
+    async def incrementer():
+        for _ in range(5):
+            await asyncio.sleep(0.05)
+            async with condition:
+                state["value"] += 1
+                await condition.notify_all()
+
+    task = asyncio.create_task(incrementer())
+
+    async with condition:
+        result = await condition.wait_for(lambda: state["value"] >= 3)
+
+    await task
+    assert result is True
+    assert state["value"] >= 3
+
+
+@gen_cluster(client=True)
+async def test_condition_wait_for_timeout(c, s, a, b):
+    condition = Condition()
+
+    async with condition:
+        result = await condition.wait_for(lambda: False, timeout=0.2)
+
+    assert result is False
+
+
+@gen_cluster(client=True)
+async def test_condition_wait_for_already_true(c, s, a, b):
+    """wait_for returns immediately if predicate is already true."""
+    condition = Condition()
+
+    async with condition:
+        result = await condition.wait_for(lambda: True)
+
+    assert result is True
+
+
+@gen_cluster(client=True)
+async def test_condition_context_manager(c, s, a, b):
+    condition = Condition()
+    assert not await condition.locked()
+
+    async with condition:
+        assert await condition.locked()
+
+    assert not await condition.locked()
+
+
+@gen_cluster(client=True)
+async def test_condition_with_explicit_lock(c, s, a, b):
+    lock = Lock()
+    condition = Condition(lock=lock)
+
+    async with lock:
+        assert await condition.locked()
+
+    assert not await condition.locked()
+
+
+@gen_cluster(client=True)
+async def test_condition_multiple_notify_calls(c, s, a, b):
+    condition = Condition()
+    results = []
+
+    async def waiter(n):
+        async with condition:
+            await condition.wait()
+            results.append(n)
+
+    tasks = [asyncio.create_task(waiter(i)) for i in range(3)]
+    await asyncio.sleep(0.1)
+
+    for _ in range(3):
+        async with condition:
+            await condition.notify()
+        await asyncio.sleep(0.05)
+
+    await asyncio.gather(*tasks)
+    assert set(results) == {0, 1, 2}
+
+
+@gen_cluster(client=True)
+async def test_condition_notify_without_waiters(c, s, a, b):
+    condition = Condition()
+
+    async with condition:
+        await condition.notify()
+        await condition.notify_all()
+        await condition.notify(5)
+
+
+@gen_cluster(client=True)
+async def test_condition_producer_consumer(c, s, a, b):
+    condition = Condition()
+    queue = []
+
+    async def producer():
+        for i in range(5):
+            await asyncio.sleep(0.05)
+            async with condition:
+                queue.append(i)
+                await condition.notify()
+
+    async def consumer():
+        items = []
+        for _ in range(5):
+            async with condition:
+                await condition.wait_for(lambda: len(queue) > 0)
+                items.append(queue.pop(0))
+        return items
+
+    prod = asyncio.create_task(producer())
+    cons = asyncio.create_task(consumer())
+
+    result = await cons
+    await prod
+    assert result == [0, 1, 2, 3, 4]
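+
+
+# Two Condition objects constructed with the same ``name`` share one waiter
+# set on the scheduler (ConditionExtension keys waiters by name), so a notify
+# through one instance wakes waiters registered through the other.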
+
+
+@gen_cluster(client=True)
+async def test_condition_same_name(c, s, a, b):
+    cond1 = Condition(name="shared")
+    cond2 = Condition(name="shared")
+    result = []
+
+    async def waiter():
+        async with cond1:
+            await cond1.wait()
+            result.append("done")
+
+    task = asyncio.create_task(waiter())
+    await asyncio.sleep(0.1)
+
+    async with cond2:
+        await cond2.notify()
+
+    await task
+    assert result == ["done"]
+
+
+@gen_cluster(client=True)
+async def test_condition_repr(c, s, a, b):
+    condition = Condition(name="test-cond")
+    assert "test-cond" in repr(condition)
+
+
+@gen_cluster(client=True)
+async def test_condition_waiter_cancelled(c, s, a, b):
+    condition = Condition()
+
+    async def waiter():
+        async with condition:
+            await condition.wait()
+
+    task = asyncio.create_task(waiter())
+    await asyncio.sleep(0.1)
+
+    task.cancel()
+    with pytest.raises(asyncio.CancelledError):
+        await task
+
+    # Lock should be released — verify no deadlock
+    async with condition:
+        await condition.notify()
+
+
+@gen_cluster(client=True)
+async def test_condition_locked_status(c, s, a, b):
+    condition = Condition()
+    assert not await condition.locked()
+
+    await condition.acquire()
+    assert await condition.locked()
+
+    await condition.release()
+    assert not await condition.locked()
+
+
+@gen_cluster(client=True)
+async def test_condition_reacquire_after_wait(c, s, a, b):
+    condition = Condition()
+    lock_states = []
+
+    async def waiter():
+        async with condition:
+            lock_states.append(("before_wait", await condition.locked()))
+            await condition.wait()
+            lock_states.append(("after_wait", await condition.locked()))
+
+    task = asyncio.create_task(waiter())
+    await asyncio.sleep(0.1)
+
+    async with condition:
+        lock_states.append(("notifier", await condition.locked()))
+        await condition.notify()
+
+    await task
+
+    assert lock_states[0] == ("before_wait", True)
+    assert lock_states[1] == ("notifier", True)
+    assert lock_states[2] == ("after_wait", True)
+
+
+@gen_cluster(client=True)
+async def test_condition_many_waiters(c, s, a, b):
+    condition = Condition()
+    results = []
+
+    async def waiter(n):
+        async with condition:
+            await condition.wait()
+            results.append(n)
+
+    tasks = [asyncio.create_task(waiter(i)) for i in range(50)]
+    await asyncio.sleep(0.3)
+
+    async with condition:
+        await condition.notify_all()
+
+    await asyncio.gather(*tasks)
+    assert len(results) == 50
+    assert set(results) == set(range(50))