Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 76 additions & 48 deletions python/packages/core/agent_framework/_workflows/_executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
import contextlib
import copy
import functools
Expand Down Expand Up @@ -198,6 +199,18 @@ def __init__(
self.type = resolved_type
self.type_ = resolved_type

# Serialize per-executor message processing. Within a superstep the runner may
# dispatch deliveries to the same target executor from multiple sources
# concurrently; this lock guarantees the executor processes them one at a time
# (and, per source, in the order they were sent).
#
# The lock must be created lazily under the running loop (see ``_get_execution_lock``)
# in order to support executors that are constructed outside an event loop and reused
# across multiple async loops (e.g., successive ``asyncio.run`` calls on the same workflow).
# Binding a single lock to one loop would raise "bound to a different event loop" on reuse.
self._execution_lock: asyncio.Lock | None = None
self._execution_lock_loop: asyncio.AbstractEventLoop | None = None

from builtins import type as builtin_type

self._handlers: dict[
Expand All @@ -216,6 +229,20 @@ def __init__(
# Initialize RequestInfoMixin to discover response handlers
self._discover_response_handlers()

def _get_execution_lock(self) -> asyncio.Lock:
"""Return this executor's serialization lock, bound to the running event loop.

The lock is created lazily and re-created if the running loop has changed since it
was last used (for example, the executor is reused across successive
``asyncio.run`` calls), avoiding ``asyncio.Lock`` "bound to a different event loop"
errors. Must be called from within a running loop.
"""
loop = asyncio.get_running_loop()
if self._execution_lock is None or self._execution_lock_loop is not loop:
self._execution_lock = asyncio.Lock()
self._execution_lock_loop = loop
return self._execution_lock

async def execute(
self,
message: Any,
Expand All @@ -241,57 +268,58 @@ async def execute(
Returns:
An awaitable that resolves to the result of the execution.
"""
# Create processing span for tracing (gracefully handles disabled tracing)
with create_processing_span(
self.id,
self.__class__.__name__,
str(MessageType.STANDARD if not isinstance(message, WorkflowMessage) else message.type),
type(message).__name__,
source_trace_contexts=trace_contexts,
source_span_ids=source_span_ids,
):
# Find the handler and handler spec that matches the message type.
handler = self._find_handler(message)

original_message = message
if isinstance(message, WorkflowMessage):
# Unwrap raw data for handler call
message = message.data

# Create the appropriate WorkflowContext based on handler specs
context = self._create_context_for_handler(
source_executor_ids=source_executor_ids,
state=state,
runner_context=runner_context,
trace_contexts=trace_contexts,
async with self._get_execution_lock():
# Create processing span for tracing (gracefully handles disabled tracing)
with create_processing_span(
self.id,
self.__class__.__name__,
str(MessageType.STANDARD if not isinstance(message, WorkflowMessage) else message.type),
type(message).__name__,
source_trace_contexts=trace_contexts,
source_span_ids=source_span_ids,
request_id=original_message.original_request_info_event.request_id
if isinstance(original_message, WorkflowMessage) and original_message.original_request_info_event
else None,
)
):
# Find the handler and handler spec that matches the message type.
handler = self._find_handler(message)

original_message = message
if isinstance(message, WorkflowMessage):
# Unwrap raw data for handler call
message = message.data

# Create the appropriate WorkflowContext based on handler specs
context = self._create_context_for_handler(
source_executor_ids=source_executor_ids,
state=state,
runner_context=runner_context,
trace_contexts=trace_contexts,
source_span_ids=source_span_ids,
request_id=original_message.original_request_info_event.request_id
if isinstance(original_message, WorkflowMessage) and original_message.original_request_info_event
else None,
)

# Invoke the handler with the message and context
# Use deepcopy to capture original input state before handler can mutate it
with _framework_event_origin():
invoke_event = WorkflowEvent.executor_invoked(self.id, copy.deepcopy(message))
await context.add_event(invoke_event)
try:
await handler(message, context)
except Exception as exc:
# Surface structured executor failure before propagating
# Invoke the handler with the message and context
# Use deepcopy to capture original input state before handler can mutate it
with _framework_event_origin():
failure_event = WorkflowEvent.executor_failed(self.id, WorkflowErrorDetails.from_exception(exc))
await context.add_event(failure_event)
raise
with _framework_event_origin():
# Include sent messages and yielded outputs as the completion data
sent_messages = context.get_sent_messages()
yielded_outputs = context.get_yielded_outputs()
completion_data = sent_messages + yielded_outputs
completed_event = WorkflowEvent.executor_completed(
self.id, completion_data if completion_data else None
)
await context.add_event(completed_event)
invoke_event = WorkflowEvent.executor_invoked(self.id, copy.deepcopy(message))
await context.add_event(invoke_event)
try:
await handler(message, context)
except Exception as exc:
# Surface structured executor failure before propagating
with _framework_event_origin():
failure_event = WorkflowEvent.executor_failed(self.id, WorkflowErrorDetails.from_exception(exc))
await context.add_event(failure_event)
raise
with _framework_event_origin():
# Include sent messages and yielded outputs as the completion data
sent_messages = context.get_sent_messages()
yielded_outputs = context.get_yielded_outputs()
completion_data = sent_messages + yielded_outputs
completed_event = WorkflowEvent.executor_completed(
self.id, completion_data if completion_data else None
)
await context.add_event(completed_event)

def _create_context_for_handler(
self,
Expand Down
32 changes: 24 additions & 8 deletions python/packages/core/agent_framework/_workflows/_runner_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,13 @@ def __init__(self, checkpoint_storage: CheckpointStorage | None = None):
checkpoint_storage: Optional storage to enable checkpointing.
"""
self._messages: dict[str, list[WorkflowMessage]] = {}
# Event queue for immediate streaming of events
self._event_queue: asyncio.Queue[WorkflowEvent] = asyncio.Queue()

# The queue must be created lazily under the running loop (see ``_get_event_queue``)
# in order to support contexts that are constructed outside an event loop and reused
# across multiple async loops (e.g., successive ``asyncio.run`` calls on the same workflow).
# Binding a single queue to one loop would raise "bound to a different event loop" on reuse.
self._event_queue: asyncio.Queue[WorkflowEvent] | None = None
self._event_queue_loop: asyncio.AbstractEventLoop | None = None

# An additional storage for pending request info events
self._pending_request_info_events: dict[str, WorkflowEvent[Any]] = {}
Expand All @@ -312,33 +317,42 @@ async def drain_messages(self) -> dict[str, list[WorkflowMessage]]:
async def has_messages(self) -> bool:
return bool(self._messages)

def _get_event_queue(self) -> asyncio.Queue[WorkflowEvent]:
"""Return the event queue bound to the running loop, re-creating it on loop change."""
loop = asyncio.get_running_loop()
if self._event_queue is None or self._event_queue_loop is not loop:
self._event_queue = asyncio.Queue()
self._event_queue_loop = loop
return self._event_queue

async def add_event(self, event: WorkflowEvent) -> None:
"""Add an event to the context immediately.

Events are enqueued so runners can stream them in real time instead of
waiting for superstep boundaries.
"""
await self._event_queue.put(event)
await self._get_event_queue().put(event)

async def drain_events(self) -> list[WorkflowEvent]:
"""Drain all currently queued events without blocking for new ones."""
events: list[WorkflowEvent] = []
queue = self._get_event_queue()
while True:
try:
events.append(self._event_queue.get_nowait())
events.append(queue.get_nowait())
except asyncio.QueueEmpty:
break
return events

async def has_events(self) -> bool:
return not self._event_queue.empty()
return not self._get_event_queue().empty()

async def next_event(self) -> WorkflowEvent:
"""Wait for and return the next event.

Used by the runner to interleave event emission with ongoing iteration work.
"""
return await self._event_queue.get()
return await self._get_event_queue().get()

# endregion Messaging and Events

Expand Down Expand Up @@ -407,8 +421,10 @@ def reset_for_new_run(self) -> None:
Runtime checkpoint storage is NOT cleared here as it's managed at the workflow level.
"""
self._messages.clear()
# Clear any pending events (best-effort) by recreating the queue
self._event_queue = asyncio.Queue()
# Drop any pending events. The queue and its loop marker are cleared so the queue
# rebinds lazily under the running loop on next use.
self._event_queue = None
self._event_queue_loop = None
self._streaming = False # Reset streaming flag

async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None:
Expand Down
112 changes: 112 additions & 0 deletions python/packages/core/tests/workflow/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
WorkflowEvent,
WorkflowException,
WorkflowMessage,
WorkflowRunResult,
WorkflowRunState,
handler,
response_handler,
Expand Down Expand Up @@ -1079,6 +1080,117 @@ async def test_workflow_partial_stream_does_not_clobber_successor_runtime_storag
await asyncio.sleep(0)


async def test_workflow_serializes_concurrent_delivery_to_same_executor():
"""Messages delivered to one executor within a superstep must be processed serially.

A start executor fans out to two intermediate executors that both send to a single
target in the same superstep. The runner dispatches those two deliveries concurrently
(they have different source executors), but the target must process them one at a time
rather than interleaving at ``await`` points.
"""

class _FanSource(Executor):
def __init__(self, id: str, label: str) -> None:
super().__init__(id=id)
self._label = label

@handler
async def run(self, message: str, ctx: WorkflowContext[str]) -> None:
await ctx.send_message(self._label)

class _SerialTarget(Executor):
def __init__(self, id: str) -> None:
super().__init__(id=id)
self.active = 0
self.max_active = 0
self.received: list[str] = []

@handler
async def run(self, message: str, ctx: WorkflowContext[str]) -> None:
self.active += 1
self.max_active = max(self.max_active, self.active)
# Yield control. If execution were not serialized per executor, a concurrent
# delivery would enter here and push ``active`` (and ``max_active``) to 2.
await asyncio.sleep(0)
self.received.append(message)
self.active -= 1

start = _FanSource(id="start", label="go")
source_a = _FanSource(id="a", label="from_a")
source_b = _FanSource(id="b", label="from_b")
target = _SerialTarget(id="target")

# superstep 1: start -> {a, b}; superstep 2: a -> target, b -> target (both in the
# same superstep); superstep 3: target receives both deliveries and processes them serially.
workflow = (
WorkflowBuilder(start_executor=start)
.add_edge(start, source_a)
.add_edge(start, source_b)
.add_edge(source_a, target)
.add_edge(source_b, target)
.build()
)

await workflow.run("go")

assert target.max_active == 1, "Target processed concurrent deliveries (executions overlapped)"
assert sorted(target.received) == ["from_a", "from_b"]


def test_executor_serialization_lock_is_loop_scoped():
"""The per-executor serialization lock must be created under the running loop.

Executors are often constructed outside an event loop and may be reused across loops
(e.g. successive ``asyncio.run`` calls). The lock is created lazily and re-created when
the running loop changes, so reuse never raises ``asyncio.Lock`` "bound to a different
event loop". Creating it eagerly in ``__init__`` and binding it to the first loop would.
"""

class _Noop(Executor):
@handler
async def run(self, message: str, ctx: WorkflowContext[str]) -> None: ...

executor = _Noop(id="noop")

async def _grab_lock() -> asyncio.Lock:
return executor._get_execution_lock() # pyright: ignore[reportPrivateUsage]

# Each ``asyncio.run`` uses a fresh event loop; the lock must be re-created per loop.
lock_loop_1 = asyncio.run(_grab_lock())
lock_loop_2 = asyncio.run(_grab_lock())

assert lock_loop_1 is not lock_loop_2


def test_workflow_instance_can_be_reused_across_event_loops():
"""A workflow built once can be re-run across separate event loops.

Both the per-executor ``asyncio.Lock`` and the runner context's ``asyncio.Queue`` bind to
the first event loop they are awaited under. They are re-created lazily under the running
loop, so successive ``asyncio.run`` calls on the same workflow instance do not raise
"bound to a different event loop".
"""

class _Echo(Executor):
@handler
async def run(self, message: str, ctx: WorkflowContext[Any, str]) -> None:
await ctx.yield_output(message)

workflow = WorkflowBuilder(start_executor=_Echo(id="echo")).build()

async def _run(message: str) -> WorkflowRunResult:
return await workflow.run(message)

# A fresh event loop per run; reuse must not raise "bound to a different event loop".
result_1 = asyncio.run(_run("a"))
result_2 = asyncio.run(_run("b"))

assert result_1.get_final_state() == WorkflowRunState.IDLE
assert result_2.get_final_state() == WorkflowRunState.IDLE
assert result_1.get_outputs() == ["a"]
assert result_2.get_outputs() == ["b"]


class _StreamingTestAgent(BaseAgent):
"""Test agent that supports both streaming and non-streaming modes."""

Expand Down
Loading