From 5763bca5324bd094165984b9506b11bc0db53efc Mon Sep 17 00:00:00 2001 From: monoxgas Date: Wed, 23 Jul 2025 17:29:23 -0600 Subject: [PATCH] Fix error handling for pipeline callbacks so they respect on_failed/catch settings --- rigging/chat.py | 28 +++++++++++++------ tests/test_chat_pipeline.py | 56 +++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 9 deletions(-) diff --git a/rigging/chat.py b/rigging/chat.py index f2842b5..8f5f964 100644 --- a/rigging/chat.py +++ b/rigging/chat.py @@ -1845,13 +1845,12 @@ async def _step( # noqa: PLR0915, PLR0912 *[state.ready_event.wait() for state in states if not state.completed], ) - # TODO(nick): Are we good to throw exceptions here? - for task in tasks: + for state, task in zip(states, tasks, strict=True): if task.done() and (exception := task.exception()): - raise exception + state.chat.error = exception + state.chat.failed = True - for state in states: - if state.ready_event.is_set() and state.step: + elif state.ready_event.is_set() and state.step: step = state.step.with_parent(current_step) if step.depth > max_depth: @@ -1875,10 +1874,12 @@ async def _step( # noqa: PLR0915, PLR0912 for task in tasks: if not task.done(): task.cancel() - await asyncio.gather(*tasks) # TODO(nick): return_exceptions=True ? + await asyncio.gather(*tasks, return_exceptions=True) chats = ChatList([state.chat for state in states if state.chat]) + self._raise_if_failed(chats, on_failed) + current_step = PipelineStep( state="callback", chats=chats, @@ -1917,9 +1918,18 @@ async def _step( # noqa: PLR0915, PLR0912 ) async with contextlib.AsyncExitStack() as exit_stack: - result = map_task(chats) - if inspect.isawaitable(result): - result = await result + try: + result = map_task(chats) + if inspect.isawaitable(result): + result = await result + except Exception as e: # noqa: BLE001 + # If the map raised an exception, assign it to all the chats + for chat in chats: + chat.error = e + chat.failed = True + + self._raise_if_failed(chats, on_failed) + continue if isinstance(result, contextlib.AbstractAsyncContextManager): result = await exit_stack.enter_async_context(result) diff --git a/tests/test_chat_pipeline.py b/tests/test_chat_pipeline.py index ae16f04..0042d73 100644 --- a/tests/test_chat_pipeline.py +++ b/tests/test_chat_pipeline.py @@ -233,3 +233,59 @@ async def watch_function(chats: list[Chat]) -> None: # Watch should be called at least once assert len(watch_calls) >= 1 assert all(calls >= 1 for calls in watch_calls) + + +@pytest.mark.asyncio +async def test_map_callback_exception_handling() -> None: + """Test that exceptions in map callback functions are properly caught and assigned to chat.error and chat.failed.""" + + generator = FixedGenerator(model="fixed", text="Response", params=GenerateParams()) + + async def failing_map_callback(chats: list[Chat]) -> list[Chat]: + # Simulate an exception in the map callback + raise RuntimeError("Map callback failure") + + pipeline = generator.chat("test").map(failing_map_callback) + + # Test with default on_failed behavior (should raise) + with pytest.raises(RuntimeError): + await pipeline.run() + + # Should still raise as RuntimeError is not in the default catch list + with pytest.raises(RuntimeError): + await pipeline.run(on_failed="include") + + # Should capture now + chat = await pipeline.catch(RuntimeError, on_failed="include").run() + + assert chat.failed is True + assert isinstance(chat.error, RuntimeError) + assert str(chat.error) == "Map callback failure" + + +@pytest.mark.asyncio +async def test_then_callback_exception_handling() -> None: + """Test that exceptions in then callback functions are properly caught and assigned to chat.error and chat.failed.""" + + generator = FixedGenerator(model="fixed", text="Response", params=GenerateParams()) + + async def failing_then_callback(chat: Chat) -> PipelineStepContextManager: + # Simulate an exception in the then callback + raise RuntimeError("Then callback failure") + + pipeline = generator.chat("test").then(failing_then_callback) + + # Test with default on_failed behavior (should raise) + with pytest.raises(RuntimeError): + await pipeline.run() + + # Should still raise as RuntimeError is not in the default catch list + with pytest.raises(RuntimeError): + await pipeline.run(on_failed="include") + + # Should capture now + chat = await pipeline.catch(RuntimeError, on_failed="include").run() + + assert chat.failed is True + assert isinstance(chat.error, RuntimeError) + assert str(chat.error) == "Then callback failure"