From 09969c584d8f3ac9ad7cb3739cf8f185a8721ec1 Mon Sep 17 00:00:00 2001
From: David Skoog
Date: Fri, 16 Jan 2026 12:20:52 -0500
Subject: [PATCH 1/4] Rework map_async to handle failures better

`map` stops the flow of items in the stream when the function raises, but
`map_async` sits outside the direct line of return, so an exception leaves it
in an ill-defined state. To address that, I added a `stop_on_exception` option
that decides whether the stream stops when the mapped function raises. Unless
the stream deliberately invokes `stop` on an exception, it keeps processing
inputs after the failure.

Since `map_async` now decides whether to stop, I added a boolean to the node
state to control the loop inside the worker task.

When the mapped function raises, `map_async` now also releases the references
held on the metadata for the offending input.

I added an example that demonstrates the failure modes of `map` and
`map_async` and shows how exceptions can leave the stream in an inconsistent
state.
---
 examples/map_failure_modes.py | 45 +++++++++++++++++++++++++++++++++++
 streamz/core.py               | 22 ++++++++++++-----
 2 files changed, 61 insertions(+), 6 deletions(-)
 create mode 100644 examples/map_failure_modes.py

diff --git a/examples/map_failure_modes.py b/examples/map_failure_modes.py
new file mode 100644
index 00000000..cb5366e9
--- /dev/null
+++ b/examples/map_failure_modes.py
@@ -0,0 +1,45 @@
+import asyncio
+from itertools import count
+from streamz import Stream
+
+
+async def flaky_async(x, from_where):
+    return flaky_sync(x, from_where)
+
+
+def flaky_sync(x, from_where):
+    if x % 5 == 4:
+        raise ValueError(f"I flaked out on {from_where}")
+    return x
+
+
+def make_counter(name):
+    return Stream.from_iterable(count(), asynchronous=True, stream_name=name)
+
+
+async def main():
+    async_non_stop_source = make_counter("async not stopping")
+    s_async = async_non_stop_source.map_async(flaky_async, async_non_stop_source)
+    s_async.rate_limit("500ms").sink(print, async_non_stop_source.name)
+
+    sync_source = make_counter("sync")
+    s_sync = sync_source.map(flaky_sync, sync_source)
+    s_sync.rate_limit("500ms").sink(print, sync_source.name)
+
+    async_stopping_source = make_counter("async stopping")
+    s_async = async_stopping_source.map_async(flaky_async, async_stopping_source, stop_on_exception=True)
+    s_async.rate_limit("500ms").sink(print, async_stopping_source.name)
+
+    async_non_stop_source.start()
+    sync_source.start()
+    async_stopping_source.start()
+    print(f"{async_non_stop_source.started=}, {sync_source.started=}, {async_stopping_source.started=}")
+    await asyncio.sleep(3)
+    print(f"{async_non_stop_source.stopped=}, {sync_source.stopped=}, {async_stopping_source.stopped=}")
+
+
+if __name__ == "__main__":
+    try:
+        asyncio.run(main())
+    except KeyboardInterrupt:
+        pass
diff --git a/streamz/core.py b/streamz/core.py
index 00b5ed4c..2efc5113 100644
--- a/streamz/core.py
+++ b/streamz/core.py
@@ -730,6 +730,8 @@ class map_async(Stream):
         The arguments to pass to the function.
     parallelism:
        The maximum number of parallel Tasks for evaluating func, default value is 1
+    stop_on_exception:
+        Whether the stream should stop when the mapped func raises an exception. Default value is False.
     **kwargs:
         Keyword arguments to pass to func
@@ -749,16 +751,23 @@ class map_async(Stream):
     6
     8
     """
-    def __init__(self, upstream, func, *args, parallelism=1, **kwargs):
+    def __init__(self, upstream, func, *args, parallelism=1, stop_on_exception=False, **kwargs):
         self.func = func
         stream_name = kwargs.pop('stream_name', None)
         self.kwargs = kwargs
         self.args = args
+        self.running = True
+        self.stop_on_exception = stop_on_exception
         self.work_queue = asyncio.Queue(maxsize=parallelism)
         Stream.__init__(self, upstream, stream_name=stream_name,
                         ensure_io_loop=True)
         self.work_task = self._create_task(self.work_callback())
 
+    def stop(self):
+        if self.running:
+            self.running = False
+            super().stop()
+
     def update(self, x, who=None, metadata=None):
         return self._create_task(self._insert_job(x, metadata))
 
@@ -768,19 +777,20 @@ def _create_task(self, coro):
         return self.loop.asyncio_loop.create_task(coro)
 
     async def work_callback(self):
-        while True:
+        while self.running:
+            task, metadata = await self.work_queue.get()
+            self.work_queue.task_done()
             try:
-                task, metadata = await self.work_queue.get()
-                self.work_queue.task_done()
                 result = await task
             except Exception as e:
                 logger.exception(e)
-                raise
+                if self.stop_on_exception:
+                    self.stop()
             else:
                 results = self._emit(result, metadata=metadata)
                 if results:
                     await asyncio.gather(*results)
-                self._release_refs(metadata)
+            self._release_refs(metadata)
 
     async def _wait_for_work_slot(self):
         while self.work_queue.full():

From c35772947f540f100db170f0fd37118c03284242 Mon Sep 17 00:00:00 2001
From: David Skoog
Date: Fri, 16 Jan 2026 18:31:48 -0500
Subject: [PATCH 2/4] from_iterable was over-consuming during stop

---
 streamz/sources.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/streamz/sources.py b/streamz/sources.py
index 777f9181..940a2e9f 100644
--- a/streamz/sources.py
+++ b/streamz/sources.py
@@ -786,6 +786,8 @@ async def run(self):
             if self.stopped:
                 break
             await asyncio.gather(*self._emit(x))
+            if self.stopped:
+                break
         self.stopped = True

From 96463b5cf67e99341e2799d5d764c402bc72f763 Mon Sep 17 00:00:00 2001
From: David Skoog
Date: Fri, 16 Jan 2026 18:32:45 -0500
Subject: [PATCH 3/4] Make map_async restartable

---
 streamz/core.py            | 43 +++++++++++++++++++++++++++++++-------
 streamz/tests/test_core.py | 22 ++++++++++++++++++++++
 2 files changed, 57 insertions(+), 8 deletions(-)

diff --git a/streamz/core.py b/streamz/core.py
index 2efc5113..8fea2813 100644
--- a/streamz/core.py
+++ b/streamz/core.py
@@ -1,4 +1,5 @@
 import asyncio
+import concurrent.futures
 from collections import deque, defaultdict
 from datetime import timedelta
 from itertools import chain
@@ -6,7 +7,7 @@
 import logging
 import threading
 from time import time
-from typing import Any, Callable, Hashable, Union
+from typing import Any, Callable, Coroutine, Hashable, Tuple, Union, overload
 import weakref
 
 import toolz
@@ -756,28 +757,54 @@ def __init__(self, upstream, func, *args, parallelism=1, stop_on_exception=False
         stream_name = kwargs.pop('stream_name', None)
         self.kwargs = kwargs
         self.args = args
-        self.running = True
         self.stop_on_exception = stop_on_exception
         self.work_queue = asyncio.Queue(maxsize=parallelism)
         Stream.__init__(self, upstream, stream_name=stream_name,
                         ensure_io_loop=True)
-        self.work_task = self._create_task(self.work_callback())
+        self.work_task = None
+
+    def _create_work_task(self) -> Tuple[asyncio.Event, asyncio.Task[None]]:
+        stop_work = asyncio.Event()
+        work_task = self._create_task(self.work_callback(stop_work))
+        return stop_work, work_task
+
+    def start(self):
+        if self.work_task:
+            stop_work, _ = self.work_task
+            stop_work.set()
+        self.work_task = self._create_work_task()
+        super().start()
 
     def stop(self):
-        if self.running:
-            self.running = False
-            super().stop()
+        stop_work, _ = self.work_task
+        stop_work.set()
+        self.work_task = None
+        super().stop()
 
     def update(self, x, who=None, metadata=None):
+        if not self.work_task:
+            self.work_task = self._create_work_task()
         return self._create_task(self._insert_job(x, metadata))
 
+    @overload
+    def _create_task(self, coro: asyncio.Future) -> asyncio.Future:
+        ...
+
+    @overload
+    def _create_task(self, coro: concurrent.futures.Future) -> concurrent.futures.Future:
+        ...
+
+    @overload
+    def _create_task(self, coro: Coroutine) -> asyncio.Task:
+        ...
+
     def _create_task(self, coro):
         if gen.is_future(coro):
             return coro
         return self.loop.asyncio_loop.create_task(coro)
 
-    async def work_callback(self):
-        while self.running:
+    async def work_callback(self, stop_work: asyncio.Event):
+        while not stop_work.is_set():
             task, metadata = await self.work_queue.get()
             self.work_queue.task_done()
             try:
diff --git a/streamz/tests/test_core.py b/streamz/tests/test_core.py
index 9245f2e6..2dcc786c 100644
--- a/streamz/tests/test_core.py
+++ b/streamz/tests/test_core.py
@@ -151,6 +151,28 @@ def fail_func():
     assert (time() - start) == pytest.approx(0.1, abs=4e-3)
 
 
+@pytest.mark.asyncio
+async def test_map_async_restart():
+    async def flake_out(x):
+        if x == 2:
+            raise RuntimeError("I fail on 2.")
+        if x > 4:
+            raise RuntimeError("I fail on > 4.")
+        return x
+
+    source = Stream.from_iterable(itertools.count())
+    mapped = source.map_async(flake_out, stop_on_exception=True)
+    results = mapped.sink_to_list()
+    source.start()
+
+    await await_for(lambda: results == [0, 1], 1)
+    await await_for(lambda: not mapped.work_task, 1)
+
+    source.start()
+
+    await await_for(lambda: results == [0, 1, 3, 4], 1)
+
+
 @pytest.mark.asyncio
 async def test_map_async():
     @gen.coroutine

From 9a7b3eda09d2268e8d62cb856d64e0ab592fde5b Mon Sep 17 00:00:00 2001
From: David Skoog
Date: Fri, 16 Jan 2026 18:36:57 -0500
Subject: [PATCH 4/4] Show off restarting map_async when it stops

---
 examples/map_failure_modes.py | 41 +++++++++++++++++++++++++----------
 1 file changed, 29 insertions(+), 12 deletions(-)

diff --git a/examples/map_failure_modes.py b/examples/map_failure_modes.py
index cb5366e9..4fd11d10 100644
--- a/examples/map_failure_modes.py
+++ b/examples/map_failure_modes.py
@@ -1,4 +1,5 @@
 import asyncio
+import sys
 from itertools import count
 from streamz import Stream
 
@@ -9,7 +10,7 @@ async def flaky_async(x, from_where):
 
 def flaky_sync(x, from_where):
     if x % 5 == 4:
-        raise ValueError(f"I flaked out on {from_where}")
+        raise ValueError(f"I flaked out on {x} for {from_where}")
     return x
 
 
@@ -17,29 +18,45 @@ def make_counter(name):
     return Stream.from_iterable(count(), asynchronous=True, stream_name=name)
 
 
-async def main():
+async def main(run_flags):
     async_non_stop_source = make_counter("async not stopping")
-    s_async = async_non_stop_source.map_async(flaky_async, async_non_stop_source)
-    s_async.rate_limit("500ms").sink(print, async_non_stop_source.name)
+    s_async = async_non_stop_source.rate_limit("500ms").map_async(flaky_async, async_non_stop_source)
+    s_async.sink(print, async_non_stop_source.name)
 
     sync_source = make_counter("sync")
-    s_sync = sync_source.map(flaky_sync, sync_source)
-    s_sync.rate_limit("500ms").sink(print, sync_source.name)
+    s_sync = sync_source.rate_limit("500ms").map(flaky_sync, sync_source)
+    s_sync.sink(print, sync_source.name)
 
     async_stopping_source = make_counter("async stopping")
-    s_async = async_stopping_source.map_async(flaky_async, async_stopping_source, stop_on_exception=True)
-    s_async.rate_limit("500ms").sink(print, async_stopping_source.name)
+    s_async_stop = async_stopping_source.rate_limit("500ms").map_async(flaky_async, async_stopping_source, stop_on_exception=True)
+    s_async_stop.sink(print, async_stopping_source.name)
+
+    if run_flags[0]:
+        async_non_stop_source.start()
+    if run_flags[1]:
+        sync_source.start()
+    if run_flags[2]:
+        async_stopping_source.start()
 
-    async_non_stop_source.start()
-    sync_source.start()
-    async_stopping_source.start()
     print(f"{async_non_stop_source.started=}, {sync_source.started=}, {async_stopping_source.started=}")
     await asyncio.sleep(3)
     print(f"{async_non_stop_source.stopped=}, {sync_source.stopped=}, {async_stopping_source.stopped=}")
 
+    if run_flags[2]:
+        print()
+        print(f"Restarting {async_stopping_source}")
+        async_stopping_source.start()
+        print()
+        await asyncio.sleep(2)
+        print(f"{async_non_stop_source.stopped=}, {sync_source.stopped=}, {async_stopping_source.stopped=}")
+
 
 if __name__ == "__main__":
     try:
-        asyncio.run(main())
+        if len(sys.argv) > 1:
+            flags = [char == "T" for char in sys.argv[1]]
+        else:
+            flags = [True, True, True]
+        asyncio.run(main(flags))
     except KeyboardInterrupt:
         pass
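
Usage note on the series: below is a minimal sketch of the stop-and-restart flow that patches 1 and 3 enable, condensed from `test_map_async_restart` above. It assumes the patched `map_async` behaviour (`stop_on_exception`, the worker task torn down on stop and recreated on the next update or `start()`); the names `flaky` and `demo` are illustrative only, not part of the patches.

import asyncio

from streamz import Stream


async def flaky(x):
    # Only x == 2 fails; with stop_on_exception=True that first failure
    # stops the stream instead of leaving it half-alive.
    if x == 2:
        raise RuntimeError("boom")
    return x


async def demo():
    # iter(...) keeps a single iterator, so restarting resumes after the
    # failed item instead of replaying the sequence from the beginning.
    source = Stream.from_iterable(iter(range(6)), asynchronous=True)
    mapped = source.map_async(flaky, stop_on_exception=True)
    results = mapped.sink_to_list()

    source.start()
    await asyncio.sleep(0.5)
    print(results)  # typically [0, 1]: the failure on 2 stopped the stream

    source.start()  # restart; map_async recreates its worker task
    await asyncio.sleep(0.5)
    print(results)  # typically [0, 1, 3, 4, 5]: processing resumed past the bad input


if __name__ == "__main__":
    asyncio.run(demo())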