-
Notifications
You must be signed in to change notification settings - Fork 26
Implement SSE streaming for message/stream and tasks/resubscribe #51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3cc1da0
3e3f9d0
3cf60c4
04be72f
02fb21b
604e3f3
22020c5
14c9b33
cb1fcd6
59530da
83d8715
5f1aaf1
f4b1e8d
f338c14
e82dd17
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,7 +8,7 @@ | |
| from starlette.applications import Starlette | ||
| from starlette.middleware import Middleware | ||
| from starlette.requests import Request | ||
| from starlette.responses import FileResponse, Response | ||
| from starlette.responses import FileResponse, Response, StreamingResponse | ||
| from starlette.routing import Route | ||
| from starlette.types import ExceptionHandler, Lifespan, Receive, Scope, Send | ||
|
|
||
|
|
@@ -19,9 +19,7 @@ | |
| AgentCard, | ||
| AgentInterface, | ||
| AgentProvider, | ||
| SendMessageResponse, | ||
| Skill, | ||
| UnsupportedOperationError, | ||
| a2a_request_ta, | ||
| a2a_response_ta, | ||
| agent_card_ta, | ||
|
|
@@ -104,7 +102,7 @@ async def _agent_card_endpoint(self, request: Request) -> Response: | |
| skills=self.skills, | ||
| default_input_modes=self.default_input_modes, | ||
| default_output_modes=self.default_output_modes, | ||
| capabilities=AgentCapabilities(streaming=False, push_notifications=False), | ||
| capabilities=AgentCapabilities(streaming=True, push_notifications=False), | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🚩 AgentCard streaming capability is now unconditionally True At Was this helpful? React with 👍 or 👎 to provide feedback. |
||
| ) | ||
| if self.provider is not None: | ||
| agent_card['provider'] = self.provider | ||
|
|
@@ -150,16 +148,14 @@ async def _agent_run_endpoint(self, request: Request) -> Response: | |
| elif a2a_request['method'] == 'tasks/list': | ||
| jsonrpc_response = await self.task_manager.list_tasks(a2a_request) | ||
| elif a2a_request['method'] == 'message/stream': | ||
| jsonrpc_response = SendMessageResponse( | ||
| jsonrpc='2.0', | ||
| id=a2a_request['id'], | ||
| error=UnsupportedOperationError(code=-32004, message='This operation is not supported'), | ||
| return StreamingResponse( | ||
| self.task_manager.stream_message(a2a_request), | ||
| media_type='text/event-stream', | ||
| ) | ||
| elif a2a_request['method'] == 'tasks/resubscribe': | ||
| jsonrpc_response = SendMessageResponse( | ||
| jsonrpc='2.0', | ||
| id=a2a_request['id'], | ||
| error=UnsupportedOperationError(code=-32004, message='This operation is not supported'), | ||
| return StreamingResponse( | ||
| self.task_manager.resubscribe_task(a2a_request), | ||
| media_type='text/event-stream', | ||
| ) | ||
| else: | ||
| raise NotImplementedError(f'Method {a2a_request["method"]} not implemented.') | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,6 +1,7 @@ | ||||||||||||||||||||||||||||||||
| from __future__ import annotations as _annotations | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| import uuid | ||||||||||||||||||||||||||||||||
| from collections.abc import AsyncIterator | ||||||||||||||||||||||||||||||||
| from typing import Any | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| import pydantic | ||||||||||||||||||||||||||||||||
|
|
@@ -13,9 +14,13 @@ | |||||||||||||||||||||||||||||||
| MessageSendParams, | ||||||||||||||||||||||||||||||||
| SendMessageRequest, | ||||||||||||||||||||||||||||||||
| SendMessageResponse, | ||||||||||||||||||||||||||||||||
| StreamMessageRequest, | ||||||||||||||||||||||||||||||||
| StreamMessageResponse, | ||||||||||||||||||||||||||||||||
| a2a_request_ta, | ||||||||||||||||||||||||||||||||
| send_message_request_ta, | ||||||||||||||||||||||||||||||||
| send_message_response_ta, | ||||||||||||||||||||||||||||||||
| stream_message_request_ta, | ||||||||||||||||||||||||||||||||
| stream_message_response_ta, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| get_task_response_ta = pydantic.TypeAdapter(GetTaskResponse) | ||||||||||||||||||||||||||||||||
|
|
@@ -63,6 +68,34 @@ async def send_message( | |||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| return send_message_response_ta.validate_json(response.content) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| async def stream_message( | ||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||
| message: Message, | ||||||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||||||
| metadata: dict[str, Any] | None = None, | ||||||||||||||||||||||||||||||||
| configuration: MessageSendConfiguration | None = None, | ||||||||||||||||||||||||||||||||
| ) -> AsyncIterator[StreamMessageResponse]: | ||||||||||||||||||||||||||||||||
| """Stream a message using SSE. | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Yields StreamMessageResponse objects as they arrive. | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| params = MessageSendParams(message=message) | ||||||||||||||||||||||||||||||||
| if metadata is not None: | ||||||||||||||||||||||||||||||||
| params['metadata'] = metadata | ||||||||||||||||||||||||||||||||
| if configuration is not None: | ||||||||||||||||||||||||||||||||
| params['configuration'] = configuration | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| request_id = str(uuid.uuid4()) | ||||||||||||||||||||||||||||||||
| payload = StreamMessageRequest(jsonrpc='2.0', id=request_id, method='message/stream', params=params) | ||||||||||||||||||||||||||||||||
| content = stream_message_request_ta.dump_json(payload, by_alias=True) | ||||||||||||||||||||||||||||||||
| async with self.http_client.stream( | ||||||||||||||||||||||||||||||||
| 'POST', '/', content=content, headers={'Content-Type': 'application/json'} | ||||||||||||||||||||||||||||||||
| ) as response: | ||||||||||||||||||||||||||||||||
| async for line in response.aiter_lines(): | ||||||||||||||||||||||||||||||||
| if line.startswith('data: '): | ||||||||||||||||||||||||||||||||
| data = line[6:] | ||||||||||||||||||||||||||||||||
| yield stream_message_response_ta.validate_json(data) | ||||||||||||||||||||||||||||||||
|
Comment on lines
+91
to
+97
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 Client The
Suggested change
Was this helpful? React with 👍 or 👎 to provide feedback. |
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| async def get_task(self, task_id: str) -> GetTaskResponse: | ||||||||||||||||||||||||||||||||
| payload = GetTaskRequest(jsonrpc='2.0', id=None, method='tasks/get', params={'id': task_id}) | ||||||||||||||||||||||||||||||||
| content = a2a_request_ta.dump_json(payload, by_alias=True) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,70 @@ | ||||||||||||||||||||
| """Event bus for streaming task updates to SSE connections.""" | ||||||||||||||||||||
|
|
||||||||||||||||||||
| from __future__ import annotations as _annotations | ||||||||||||||||||||
|
|
||||||||||||||||||||
| from abc import ABC, abstractmethod | ||||||||||||||||||||
| from collections import defaultdict | ||||||||||||||||||||
| from collections.abc import AsyncIterator | ||||||||||||||||||||
| from contextlib import asynccontextmanager | ||||||||||||||||||||
|
|
||||||||||||||||||||
| import anyio | ||||||||||||||||||||
| import anyio.abc | ||||||||||||||||||||
|
|
||||||||||||||||||||
| from .schema import StreamResponse | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| class EventBus(ABC): | ||||||||||||||||||||
| """A pub/sub event bus for streaming task events. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Allows workers to emit events that are delivered to SSE connections. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
|
|
||||||||||||||||||||
| @abstractmethod | ||||||||||||||||||||
| @asynccontextmanager | ||||||||||||||||||||
| async def subscribe(self, task_id: str) -> AsyncIterator[anyio.abc.ObjectReceiveStream[StreamResponse]]: | ||||||||||||||||||||
| """Subscribe to events for a task. Yields a receive stream.""" | ||||||||||||||||||||
| yield # type: ignore[misc] | ||||||||||||||||||||
|
|
||||||||||||||||||||
| @abstractmethod | ||||||||||||||||||||
| async def emit(self, task_id: str, event: StreamResponse) -> None: | ||||||||||||||||||||
| """Emit an event to all subscribers for a task.""" | ||||||||||||||||||||
|
|
||||||||||||||||||||
| @abstractmethod | ||||||||||||||||||||
| async def close(self, task_id: str) -> None: | ||||||||||||||||||||
| """Close all subscriber streams for a task, signaling end of SSE.""" | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| class InMemoryEventBus(EventBus): | ||||||||||||||||||||
| """An in-memory event bus using anyio memory streams.""" | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def __init__(self) -> None: | ||||||||||||||||||||
| self._subscribers: dict[str, list[anyio.abc.ObjectSendStream[StreamResponse]]] = defaultdict(list) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| @asynccontextmanager | ||||||||||||||||||||
| async def subscribe(self, task_id: str) -> AsyncIterator[anyio.abc.ObjectReceiveStream[StreamResponse]]: | ||||||||||||||||||||
| """Subscribe to events for a task. Yields a receive stream.""" | ||||||||||||||||||||
| send_stream, receive_stream = anyio.create_memory_object_stream[StreamResponse]() | ||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🚩 unbounded memory object streams could accumulate events if consumer is slow In Was this helpful? React with 👍 or 👎 to provide feedback. |
||||||||||||||||||||
| self._subscribers[task_id].append(send_stream) | ||||||||||||||||||||
| try: | ||||||||||||||||||||
| yield receive_stream | ||||||||||||||||||||
| finally: | ||||||||||||||||||||
| subscribers = self._subscribers.get(task_id) | ||||||||||||||||||||
| if subscribers is not None: | ||||||||||||||||||||
| try: | ||||||||||||||||||||
| subscribers.remove(send_stream) | ||||||||||||||||||||
| except ValueError: | ||||||||||||||||||||
| pass | ||||||||||||||||||||
| if not subscribers: | ||||||||||||||||||||
| del self._subscribers[task_id] | ||||||||||||||||||||
| await send_stream.aclose() | ||||||||||||||||||||
| await receive_stream.aclose() | ||||||||||||||||||||
|
|
||||||||||||||||||||
| async def emit(self, task_id: str, event: StreamResponse) -> None: | ||||||||||||||||||||
| """Emit an event to all subscribers for a task.""" | ||||||||||||||||||||
| for send_stream in self._subscribers.get(task_id, []): | ||||||||||||||||||||
| await send_stream.send(event) | ||||||||||||||||||||
|
Comment on lines
+62
to
+65
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔴 Iterating over live subscriber list in In
Suggested change
Was this helpful? React with 👍 or 👎 to provide feedback.
Comment on lines
+62
to
+65
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🚩 InMemoryEventBus.emit iterates list while awaiting — mutation risk In Was this helpful? React with 👍 or 👎 to provide feedback. |
||||||||||||||||||||
|
|
||||||||||||||||||||
| async def close(self, task_id: str) -> None: | ||||||||||||||||||||
| """Close all subscriber streams for a task, signaling end of SSE.""" | ||||||||||||||||||||
| for send_stream in self._subscribers.pop(task_id, []): | ||||||||||||||||||||
| await send_stream.aclose() | ||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -61,6 +61,7 @@ | |
| from __future__ import annotations as _annotations | ||
|
|
||
| import uuid | ||
| from collections.abc import AsyncIterator | ||
| from contextlib import AsyncExitStack | ||
| from dataclasses import dataclass, field | ||
| from typing import Any | ||
|
|
@@ -80,14 +81,19 @@ | |
| ListTasksRequest, | ||
| ListTasksResponse, | ||
| PushNotificationNotSupportedError, | ||
| ResubscribeTaskRequest, | ||
| SendMessageRequest, | ||
| SendMessageResponse, | ||
| SendMessageResult, | ||
| SetTaskPushNotificationRequest, | ||
| SetTaskPushNotificationResponse, | ||
| StreamMessageRequest, | ||
| StreamMessageResponse, | ||
| StreamResponse, | ||
| TaskNotFoundError, | ||
| TaskSendParams, | ||
| UnsupportedOperationError, | ||
| stream_message_response_ta, | ||
| ) | ||
| from .storage import Storage | ||
|
|
||
|
|
@@ -119,7 +125,7 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): | |
| self._aexit_stack = None | ||
|
|
||
| async def send_message(self, request: SendMessageRequest) -> SendMessageResponse: | ||
| """Send a message using the A2A v0.3.0 protocol.""" | ||
| """Send a message using the A2A protocol.""" | ||
| request_id = request['id'] | ||
| message = request['params']['message'] | ||
| context_id = message.get('context_id', str(uuid.uuid4())) | ||
|
|
@@ -163,6 +169,67 @@ async def cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: | |
| ) | ||
| return CancelTaskResponse(jsonrpc='2.0', id=request['id'], result=task) | ||
|
|
||
| async def stream_message(self, request: StreamMessageRequest) -> AsyncIterator[bytes]: | ||
| """Stream a message response as SSE events.""" | ||
| request_id = request['id'] | ||
| message = request['params']['message'] | ||
| context_id = message.get('context_id', str(uuid.uuid4())) | ||
|
|
||
| task = await self.storage.submit_task(context_id, message) | ||
| task_id = task['id'] | ||
|
|
||
| broker_params: TaskSendParams = {'id': task_id, 'context_id': context_id, 'message': message} | ||
| config = request['params'].get('configuration', {}) | ||
| history_length = config.get('history_length') | ||
| if history_length is not None: | ||
| broker_params['history_length'] = history_length | ||
|
|
||
| async with self.broker.event_bus.subscribe(task_id) as receive_stream: | ||
| await self.broker.run_task(broker_params) | ||
|
|
||
| # Send initial task state | ||
| initial_response = StreamMessageResponse(jsonrpc='2.0', id=request_id, result=StreamResponse(task=task)) | ||
| yield self._format_sse_event(initial_response) | ||
|
|
||
| async for event in receive_stream: | ||
| response = StreamMessageResponse(jsonrpc='2.0', id=request_id, result=event) | ||
| yield self._format_sse_event(response) | ||
|
|
||
| async def resubscribe_task(self, request: ResubscribeTaskRequest) -> AsyncIterator[bytes]: | ||
| """Resubscribe to an existing task's event stream.""" | ||
| request_id = request['id'] | ||
| task_id = request['params']['id'] | ||
|
|
||
| task = await self.storage.load_task(task_id) | ||
| if task is None: | ||
| error_response = StreamMessageResponse( | ||
| jsonrpc='2.0', | ||
| id=request_id, | ||
| error=TaskNotFoundError(code=-32001, message='Task not found'), | ||
| ) | ||
| yield self._format_sse_event(error_response) | ||
| return | ||
|
|
||
| # Send current task state | ||
| initial_response = StreamMessageResponse(jsonrpc='2.0', id=request_id, result=StreamResponse(task=task)) | ||
| yield self._format_sse_event(initial_response) | ||
|
|
||
| # If task is already in a terminal state, no need to subscribe | ||
| terminal_states = {'completed', 'canceled', 'failed', 'rejected'} | ||
| if task['status']['state'] in terminal_states: | ||
| return | ||
|
|
||
| async with self.broker.event_bus.subscribe(task_id) as receive_stream: | ||
| async for event in receive_stream: | ||
| response = StreamMessageResponse(jsonrpc='2.0', id=request_id, result=event) | ||
| yield self._format_sse_event(response) | ||
|
Comment on lines
+198
to
+225
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🚩 Race condition in resubscribe_task between terminal state check and subscribe is safe for InMemory but fragile for distributed implementations In Was this helpful? React with 👍 or 👎 to provide feedback.
Comment on lines
+217
to
+225
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔴 TOCTOU race in In Why InMemoryStorage accidentally masks this bugWith Was this helpful? React with 👍 or 👎 to provide feedback. |
||
|
|
||
| @staticmethod | ||
| def _format_sse_event(response: StreamMessageResponse) -> bytes: | ||
| """Format a StreamMessageResponse as an SSE event.""" | ||
| data = stream_message_response_ta.dump_json(response, by_alias=True) | ||
| return b'data: ' + data + b'\n\n' | ||
|
|
||
| async def set_task_push_notification( | ||
| self, request: SetTaskPushNotificationRequest | ||
| ) -> SetTaskPushNotificationResponse: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -54,7 +54,21 @@ async def _handle_task_operation(self, task_operation: TaskOperation) -> None: | |
| else: | ||
| assert_never(task_operation) | ||
| except Exception: | ||
| await self.storage.update_task(task_operation['params']['id'], state='failed') | ||
| task_id = task_operation['params']['id'] | ||
| task = await self.storage.update_task(task_id, state='failed') | ||
| from .schema import StreamResponse, TaskStatus, TaskStatusUpdateEvent | ||
|
|
||
| await self.broker.event_bus.emit( | ||
| task_id, | ||
| StreamResponse( | ||
| status_update=TaskStatusUpdateEvent( | ||
| task_id=task_id, | ||
| context_id=task['context_id'], | ||
| status=TaskStatus(state='failed'), | ||
| ) | ||
| ), | ||
| ) | ||
| await self.broker.event_bus.close(task_id) | ||
|
|
||
| @abstractmethod | ||
| async def run_task(self, params: TaskSendParams) -> None: ... | ||
|
Comment on lines
73
to
74
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🚩 Worker implementations must emit their own streaming events for intermediate updates The Was this helpful? React with 👍 or 👎 to provide feedback. |
||
|
|
||
|
devin-ai-integration[bot] marked this conversation as resolved.
|
Uh oh!
There was an error while loading. Please reload this page.