diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index f57102b2bce..6bd7df112ee 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +import asyncio import contextlib import copy import functools @@ -198,6 +199,18 @@ def __init__( self.type = resolved_type self.type_ = resolved_type + # Serialize per-executor message processing. Within a superstep the runner may + # dispatch deliveries to the same target executor from multiple sources + # concurrently; this lock guarantees the executor processes them one at a time + # (and, per source, in the order they were sent). + # + # The lock must be created lazily under the running loop (see ``_get_execution_lock``) + # in order to support executors that are constructed outside an event loop and reused + # across multiple async loops (e.g., successive ``asyncio.run`` calls on the same workflow). + # Binding a single lock to one loop would raise "bound to a different event loop" on reuse. + self._execution_lock: asyncio.Lock | None = None + self._execution_lock_loop: asyncio.AbstractEventLoop | None = None + from builtins import type as builtin_type self._handlers: dict[ @@ -216,6 +229,20 @@ def __init__( # Initialize RequestInfoMixin to discover response handlers self._discover_response_handlers() + def _get_execution_lock(self) -> asyncio.Lock: + """Return this executor's serialization lock, bound to the running event loop. + + The lock is created lazily and re-created if the running loop has changed since it + was last used (for example, the executor is reused across successive + ``asyncio.run`` calls), avoiding ``asyncio.Lock`` "bound to a different event loop" + errors. Must be called from within a running loop. + """ + loop = asyncio.get_running_loop() + if self._execution_lock is None or self._execution_lock_loop is not loop: + self._execution_lock = asyncio.Lock() + self._execution_lock_loop = loop + return self._execution_lock + async def execute( self, message: Any, @@ -241,57 +268,58 @@ async def execute( Returns: An awaitable that resolves to the result of the execution. """ - # Create processing span for tracing (gracefully handles disabled tracing) - with create_processing_span( - self.id, - self.__class__.__name__, - str(MessageType.STANDARD if not isinstance(message, WorkflowMessage) else message.type), - type(message).__name__, - source_trace_contexts=trace_contexts, - source_span_ids=source_span_ids, - ): - # Find the handler and handler spec that matches the message type. - handler = self._find_handler(message) - - original_message = message - if isinstance(message, WorkflowMessage): - # Unwrap raw data for handler call - message = message.data - - # Create the appropriate WorkflowContext based on handler specs - context = self._create_context_for_handler( - source_executor_ids=source_executor_ids, - state=state, - runner_context=runner_context, - trace_contexts=trace_contexts, + async with self._get_execution_lock(): + # Create processing span for tracing (gracefully handles disabled tracing) + with create_processing_span( + self.id, + self.__class__.__name__, + str(MessageType.STANDARD if not isinstance(message, WorkflowMessage) else message.type), + type(message).__name__, + source_trace_contexts=trace_contexts, source_span_ids=source_span_ids, - request_id=original_message.original_request_info_event.request_id - if isinstance(original_message, WorkflowMessage) and original_message.original_request_info_event - else None, - ) + ): + # Find the handler and handler spec that matches the message type. + handler = self._find_handler(message) + + original_message = message + if isinstance(message, WorkflowMessage): + # Unwrap raw data for handler call + message = message.data + + # Create the appropriate WorkflowContext based on handler specs + context = self._create_context_for_handler( + source_executor_ids=source_executor_ids, + state=state, + runner_context=runner_context, + trace_contexts=trace_contexts, + source_span_ids=source_span_ids, + request_id=original_message.original_request_info_event.request_id + if isinstance(original_message, WorkflowMessage) and original_message.original_request_info_event + else None, + ) - # Invoke the handler with the message and context - # Use deepcopy to capture original input state before handler can mutate it - with _framework_event_origin(): - invoke_event = WorkflowEvent.executor_invoked(self.id, copy.deepcopy(message)) - await context.add_event(invoke_event) - try: - await handler(message, context) - except Exception as exc: - # Surface structured executor failure before propagating + # Invoke the handler with the message and context + # Use deepcopy to capture original input state before handler can mutate it with _framework_event_origin(): - failure_event = WorkflowEvent.executor_failed(self.id, WorkflowErrorDetails.from_exception(exc)) - await context.add_event(failure_event) - raise - with _framework_event_origin(): - # Include sent messages and yielded outputs as the completion data - sent_messages = context.get_sent_messages() - yielded_outputs = context.get_yielded_outputs() - completion_data = sent_messages + yielded_outputs - completed_event = WorkflowEvent.executor_completed( - self.id, completion_data if completion_data else None - ) - await context.add_event(completed_event) + invoke_event = WorkflowEvent.executor_invoked(self.id, copy.deepcopy(message)) + await context.add_event(invoke_event) + try: + await handler(message, context) + except Exception as exc: + # Surface structured executor failure before propagating + with _framework_event_origin(): + failure_event = WorkflowEvent.executor_failed(self.id, WorkflowErrorDetails.from_exception(exc)) + await context.add_event(failure_event) + raise + with _framework_event_origin(): + # Include sent messages and yielded outputs as the completion data + sent_messages = context.get_sent_messages() + yielded_outputs = context.get_yielded_outputs() + completion_data = sent_messages + yielded_outputs + completed_event = WorkflowEvent.executor_completed( + self.id, completion_data if completion_data else None + ) + await context.add_event(completed_event) def _create_context_for_handler( self, diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index a6bbbe3c73d..6cbee0b039c 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -285,8 +285,13 @@ def __init__(self, checkpoint_storage: CheckpointStorage | None = None): checkpoint_storage: Optional storage to enable checkpointing. """ self._messages: dict[str, list[WorkflowMessage]] = {} - # Event queue for immediate streaming of events - self._event_queue: asyncio.Queue[WorkflowEvent] = asyncio.Queue() + + # The queue must be created lazily under the running loop (see ``_get_event_queue``) + # in order to support contexts that are constructed outside an event loop and reused + # across multiple async loops (e.g., successive ``asyncio.run`` calls on the same workflow). + # Binding a single queue to one loop would raise "bound to a different event loop" on reuse. + self._event_queue: asyncio.Queue[WorkflowEvent] | None = None + self._event_queue_loop: asyncio.AbstractEventLoop | None = None # An additional storage for pending request info events self._pending_request_info_events: dict[str, WorkflowEvent[Any]] = {} @@ -312,33 +317,42 @@ async def drain_messages(self) -> dict[str, list[WorkflowMessage]]: async def has_messages(self) -> bool: return bool(self._messages) + def _get_event_queue(self) -> asyncio.Queue[WorkflowEvent]: + """Return the event queue bound to the running loop, re-creating it on loop change.""" + loop = asyncio.get_running_loop() + if self._event_queue is None or self._event_queue_loop is not loop: + self._event_queue = asyncio.Queue() + self._event_queue_loop = loop + return self._event_queue + async def add_event(self, event: WorkflowEvent) -> None: """Add an event to the context immediately. Events are enqueued so runners can stream them in real time instead of waiting for superstep boundaries. """ - await self._event_queue.put(event) + await self._get_event_queue().put(event) async def drain_events(self) -> list[WorkflowEvent]: """Drain all currently queued events without blocking for new ones.""" events: list[WorkflowEvent] = [] + queue = self._get_event_queue() while True: try: - events.append(self._event_queue.get_nowait()) + events.append(queue.get_nowait()) except asyncio.QueueEmpty: break return events async def has_events(self) -> bool: - return not self._event_queue.empty() + return not self._get_event_queue().empty() async def next_event(self) -> WorkflowEvent: """Wait for and return the next event. Used by the runner to interleave event emission with ongoing iteration work. """ - return await self._event_queue.get() + return await self._get_event_queue().get() # endregion Messaging and Events @@ -407,8 +421,10 @@ def reset_for_new_run(self) -> None: Runtime checkpoint storage is NOT cleared here as it's managed at the workflow level. """ self._messages.clear() - # Clear any pending events (best-effort) by recreating the queue - self._event_queue = asyncio.Queue() + # Drop any pending events. The queue and its loop marker are cleared so the queue + # rebinds lazily under the running loop on next use. + self._event_queue = None + self._event_queue_loop = None self._streaming = False # Reset streaming flag async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None: diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index 260613c74a1..c77688c03c2 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -30,6 +30,7 @@ WorkflowEvent, WorkflowException, WorkflowMessage, + WorkflowRunResult, WorkflowRunState, handler, response_handler, @@ -1079,6 +1080,117 @@ async def test_workflow_partial_stream_does_not_clobber_successor_runtime_storag await asyncio.sleep(0) +async def test_workflow_serializes_concurrent_delivery_to_same_executor(): + """Messages delivered to one executor within a superstep must be processed serially. + + A start executor fans out to two intermediate executors that both send to a single + target in the same superstep. The runner dispatches those two deliveries concurrently + (they have different source executors), but the target must process them one at a time + rather than interleaving at ``await`` points. + """ + + class _FanSource(Executor): + def __init__(self, id: str, label: str) -> None: + super().__init__(id=id) + self._label = label + + @handler + async def run(self, message: str, ctx: WorkflowContext[str]) -> None: + await ctx.send_message(self._label) + + class _SerialTarget(Executor): + def __init__(self, id: str) -> None: + super().__init__(id=id) + self.active = 0 + self.max_active = 0 + self.received: list[str] = [] + + @handler + async def run(self, message: str, ctx: WorkflowContext[str]) -> None: + self.active += 1 + self.max_active = max(self.max_active, self.active) + # Yield control. If execution were not serialized per executor, a concurrent + # delivery would enter here and push ``active`` (and ``max_active``) to 2. + await asyncio.sleep(0) + self.received.append(message) + self.active -= 1 + + start = _FanSource(id="start", label="go") + source_a = _FanSource(id="a", label="from_a") + source_b = _FanSource(id="b", label="from_b") + target = _SerialTarget(id="target") + + # superstep 1: start -> {a, b}; superstep 2: a -> target, b -> target (both in the + # same superstep); superstep 3: target receives both deliveries and processes them serially. + workflow = ( + WorkflowBuilder(start_executor=start) + .add_edge(start, source_a) + .add_edge(start, source_b) + .add_edge(source_a, target) + .add_edge(source_b, target) + .build() + ) + + await workflow.run("go") + + assert target.max_active == 1, "Target processed concurrent deliveries (executions overlapped)" + assert sorted(target.received) == ["from_a", "from_b"] + + +def test_executor_serialization_lock_is_loop_scoped(): + """The per-executor serialization lock must be created under the running loop. + + Executors are often constructed outside an event loop and may be reused across loops + (e.g. successive ``asyncio.run`` calls). The lock is created lazily and re-created when + the running loop changes, so reuse never raises ``asyncio.Lock`` "bound to a different + event loop". Creating it eagerly in ``__init__`` and binding it to the first loop would. + """ + + class _Noop(Executor): + @handler + async def run(self, message: str, ctx: WorkflowContext[str]) -> None: ... + + executor = _Noop(id="noop") + + async def _grab_lock() -> asyncio.Lock: + return executor._get_execution_lock() # pyright: ignore[reportPrivateUsage] + + # Each ``asyncio.run`` uses a fresh event loop; the lock must be re-created per loop. + lock_loop_1 = asyncio.run(_grab_lock()) + lock_loop_2 = asyncio.run(_grab_lock()) + + assert lock_loop_1 is not lock_loop_2 + + +def test_workflow_instance_can_be_reused_across_event_loops(): + """A workflow built once can be re-run across separate event loops. + + Both the per-executor ``asyncio.Lock`` and the runner context's ``asyncio.Queue`` bind to + the first event loop they are awaited under. They are re-created lazily under the running + loop, so successive ``asyncio.run`` calls on the same workflow instance do not raise + "bound to a different event loop". + """ + + class _Echo(Executor): + @handler + async def run(self, message: str, ctx: WorkflowContext[Any, str]) -> None: + await ctx.yield_output(message) + + workflow = WorkflowBuilder(start_executor=_Echo(id="echo")).build() + + async def _run(message: str) -> WorkflowRunResult: + return await workflow.run(message) + + # A fresh event loop per run; reuse must not raise "bound to a different event loop". + result_1 = asyncio.run(_run("a")) + result_2 = asyncio.run(_run("b")) + + assert result_1.get_final_state() == WorkflowRunState.IDLE + assert result_2.get_final_state() == WorkflowRunState.IDLE + assert result_1.get_outputs() == ["a"] + assert result_2.get_outputs() == ["b"] + + class _StreamingTestAgent(BaseAgent): """Test agent that supports both streaming and non-streaming modes."""