diff --git a/fasta2a/applications.py b/fasta2a/applications.py index 7e96251..89a3cda 100644 --- a/fasta2a/applications.py +++ b/fasta2a/applications.py @@ -8,7 +8,7 @@ from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.requests import Request -from starlette.responses import FileResponse, Response +from starlette.responses import FileResponse, Response, StreamingResponse from starlette.routing import Route from starlette.types import ExceptionHandler, Lifespan, Receive, Scope, Send @@ -19,9 +19,7 @@ AgentCard, AgentInterface, AgentProvider, - SendMessageResponse, Skill, - UnsupportedOperationError, a2a_request_ta, a2a_response_ta, agent_card_ta, @@ -104,7 +102,7 @@ async def _agent_card_endpoint(self, request: Request) -> Response: skills=self.skills, default_input_modes=self.default_input_modes, default_output_modes=self.default_output_modes, - capabilities=AgentCapabilities(streaming=False, push_notifications=False), + capabilities=AgentCapabilities(streaming=True, push_notifications=False), ) if self.provider is not None: agent_card['provider'] = self.provider @@ -150,16 +148,14 @@ async def _agent_run_endpoint(self, request: Request) -> Response: elif a2a_request['method'] == 'tasks/list': jsonrpc_response = await self.task_manager.list_tasks(a2a_request) elif a2a_request['method'] == 'message/stream': - jsonrpc_response = SendMessageResponse( - jsonrpc='2.0', - id=a2a_request['id'], - error=UnsupportedOperationError(code=-32004, message='This operation is not supported'), + return StreamingResponse( + self.task_manager.stream_message(a2a_request), + media_type='text/event-stream', ) elif a2a_request['method'] == 'tasks/resubscribe': - jsonrpc_response = SendMessageResponse( - jsonrpc='2.0', - id=a2a_request['id'], - error=UnsupportedOperationError(code=-32004, message='This operation is not supported'), + return StreamingResponse( + self.task_manager.resubscribe_task(a2a_request), + media_type='text/event-stream', ) else: raise NotImplementedError(f'Method {a2a_request["method"]} not implemented.') diff --git a/fasta2a/broker.py b/fasta2a/broker.py index c84b738..e91ffdd 100644 --- a/fasta2a/broker.py +++ b/fasta2a/broker.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator from contextlib import AsyncExitStack -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Annotated, Any, Generic, Literal, TypeVar import anyio @@ -11,6 +11,7 @@ from pydantic import Discriminator from typing_extensions import Self, TypedDict +from .event_bus import EventBus, InMemoryEventBus from .schema import TaskIdParams, TaskSendParams tracer = get_tracer(__name__) @@ -27,6 +28,8 @@ class Broker(ABC): extended to support remote workers. """ + event_bus: EventBus = field(default_factory=InMemoryEventBus) + @abstractmethod async def run_task(self, params: TaskSendParams) -> None: """Send a task to be executed by the worker.""" diff --git a/fasta2a/client.py b/fasta2a/client.py index cd84499..e059c33 100644 --- a/fasta2a/client.py +++ b/fasta2a/client.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import uuid +from collections.abc import AsyncIterator from typing import Any import pydantic @@ -13,9 +14,13 @@ MessageSendParams, SendMessageRequest, SendMessageResponse, + StreamMessageRequest, + StreamMessageResponse, a2a_request_ta, send_message_request_ta, send_message_response_ta, + stream_message_request_ta, + stream_message_response_ta, ) get_task_response_ta = pydantic.TypeAdapter(GetTaskResponse) @@ -63,6 +68,34 @@ async def send_message( return send_message_response_ta.validate_json(response.content) + async def stream_message( + self, + message: Message, + *, + metadata: dict[str, Any] | None = None, + configuration: MessageSendConfiguration | None = None, + ) -> AsyncIterator[StreamMessageResponse]: + """Stream a message using SSE. + + Yields StreamMessageResponse objects as they arrive. + """ + params = MessageSendParams(message=message) + if metadata is not None: + params['metadata'] = metadata + if configuration is not None: + params['configuration'] = configuration + + request_id = str(uuid.uuid4()) + payload = StreamMessageRequest(jsonrpc='2.0', id=request_id, method='message/stream', params=params) + content = stream_message_request_ta.dump_json(payload, by_alias=True) + async with self.http_client.stream( + 'POST', '/', content=content, headers={'Content-Type': 'application/json'} + ) as response: + async for line in response.aiter_lines(): + if line.startswith('data: '): + data = line[6:] + yield stream_message_response_ta.validate_json(data) + async def get_task(self, task_id: str) -> GetTaskResponse: payload = GetTaskRequest(jsonrpc='2.0', id=None, method='tasks/get', params={'id': task_id}) content = a2a_request_ta.dump_json(payload, by_alias=True) diff --git a/fasta2a/event_bus.py b/fasta2a/event_bus.py new file mode 100644 index 0000000..a47ce31 --- /dev/null +++ b/fasta2a/event_bus.py @@ -0,0 +1,70 @@ +"""Event bus for streaming task updates to SSE connections.""" + +from __future__ import annotations as _annotations + +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +import anyio +import anyio.abc + +from .schema import StreamResponse + + +class EventBus(ABC): + """A pub/sub event bus for streaming task events. + + Allows workers to emit events that are delivered to SSE connections. + """ + + @abstractmethod + @asynccontextmanager + async def subscribe(self, task_id: str) -> AsyncIterator[anyio.abc.ObjectReceiveStream[StreamResponse]]: + """Subscribe to events for a task. Yields a receive stream.""" + yield # type: ignore[misc] + + @abstractmethod + async def emit(self, task_id: str, event: StreamResponse) -> None: + """Emit an event to all subscribers for a task.""" + + @abstractmethod + async def close(self, task_id: str) -> None: + """Close all subscriber streams for a task, signaling end of SSE.""" + + +class InMemoryEventBus(EventBus): + """An in-memory event bus using anyio memory streams.""" + + def __init__(self) -> None: + self._subscribers: dict[str, list[anyio.abc.ObjectSendStream[StreamResponse]]] = defaultdict(list) + + @asynccontextmanager + async def subscribe(self, task_id: str) -> AsyncIterator[anyio.abc.ObjectReceiveStream[StreamResponse]]: + """Subscribe to events for a task. Yields a receive stream.""" + send_stream, receive_stream = anyio.create_memory_object_stream[StreamResponse]() + self._subscribers[task_id].append(send_stream) + try: + yield receive_stream + finally: + subscribers = self._subscribers.get(task_id) + if subscribers is not None: + try: + subscribers.remove(send_stream) + except ValueError: + pass + if not subscribers: + del self._subscribers[task_id] + await send_stream.aclose() + await receive_stream.aclose() + + async def emit(self, task_id: str, event: StreamResponse) -> None: + """Emit an event to all subscribers for a task.""" + for send_stream in self._subscribers.get(task_id, []): + await send_stream.send(event) + + async def close(self, task_id: str) -> None: + """Close all subscriber streams for a task, signaling end of SSE.""" + for send_stream in self._subscribers.pop(task_id, []): + await send_stream.aclose() diff --git a/fasta2a/task_manager.py b/fasta2a/task_manager.py index 0a2fee0..d0fecbc 100644 --- a/fasta2a/task_manager.py +++ b/fasta2a/task_manager.py @@ -61,6 +61,7 @@ from __future__ import annotations as _annotations import uuid +from collections.abc import AsyncIterator from contextlib import AsyncExitStack from dataclasses import dataclass, field from typing import Any @@ -80,14 +81,19 @@ ListTasksRequest, ListTasksResponse, PushNotificationNotSupportedError, + ResubscribeTaskRequest, SendMessageRequest, SendMessageResponse, SendMessageResult, SetTaskPushNotificationRequest, SetTaskPushNotificationResponse, + StreamMessageRequest, + StreamMessageResponse, + StreamResponse, TaskNotFoundError, TaskSendParams, UnsupportedOperationError, + stream_message_response_ta, ) from .storage import Storage @@ -119,7 +125,7 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): self._aexit_stack = None async def send_message(self, request: SendMessageRequest) -> SendMessageResponse: - """Send a message using the A2A v0.3.0 protocol.""" + """Send a message using the A2A protocol.""" request_id = request['id'] message = request['params']['message'] context_id = message.get('context_id', str(uuid.uuid4())) @@ -163,6 +169,67 @@ async def cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: ) return CancelTaskResponse(jsonrpc='2.0', id=request['id'], result=task) + async def stream_message(self, request: StreamMessageRequest) -> AsyncIterator[bytes]: + """Stream a message response as SSE events.""" + request_id = request['id'] + message = request['params']['message'] + context_id = message.get('context_id', str(uuid.uuid4())) + + task = await self.storage.submit_task(context_id, message) + task_id = task['id'] + + broker_params: TaskSendParams = {'id': task_id, 'context_id': context_id, 'message': message} + config = request['params'].get('configuration', {}) + history_length = config.get('history_length') + if history_length is not None: + broker_params['history_length'] = history_length + + async with self.broker.event_bus.subscribe(task_id) as receive_stream: + await self.broker.run_task(broker_params) + + # Send initial task state + initial_response = StreamMessageResponse(jsonrpc='2.0', id=request_id, result=StreamResponse(task=task)) + yield self._format_sse_event(initial_response) + + async for event in receive_stream: + response = StreamMessageResponse(jsonrpc='2.0', id=request_id, result=event) + yield self._format_sse_event(response) + + async def resubscribe_task(self, request: ResubscribeTaskRequest) -> AsyncIterator[bytes]: + """Resubscribe to an existing task's event stream.""" + request_id = request['id'] + task_id = request['params']['id'] + + task = await self.storage.load_task(task_id) + if task is None: + error_response = StreamMessageResponse( + jsonrpc='2.0', + id=request_id, + error=TaskNotFoundError(code=-32001, message='Task not found'), + ) + yield self._format_sse_event(error_response) + return + + # Send current task state + initial_response = StreamMessageResponse(jsonrpc='2.0', id=request_id, result=StreamResponse(task=task)) + yield self._format_sse_event(initial_response) + + # If task is already in a terminal state, no need to subscribe + terminal_states = {'completed', 'canceled', 'failed', 'rejected'} + if task['status']['state'] in terminal_states: + return + + async with self.broker.event_bus.subscribe(task_id) as receive_stream: + async for event in receive_stream: + response = StreamMessageResponse(jsonrpc='2.0', id=request_id, result=event) + yield self._format_sse_event(response) + + @staticmethod + def _format_sse_event(response: StreamMessageResponse) -> bytes: + """Format a StreamMessageResponse as an SSE event.""" + data = stream_message_response_ta.dump_json(response, by_alias=True) + return b'data: ' + data + b'\n\n' + async def set_task_push_notification( self, request: SetTaskPushNotificationRequest ) -> SetTaskPushNotificationResponse: diff --git a/fasta2a/worker.py b/fasta2a/worker.py index bcb0172..0e22a80 100644 --- a/fasta2a/worker.py +++ b/fasta2a/worker.py @@ -54,7 +54,21 @@ async def _handle_task_operation(self, task_operation: TaskOperation) -> None: else: assert_never(task_operation) except Exception: - await self.storage.update_task(task_operation['params']['id'], state='failed') + task_id = task_operation['params']['id'] + task = await self.storage.update_task(task_id, state='failed') + from .schema import StreamResponse, TaskStatus, TaskStatusUpdateEvent + + await self.broker.event_bus.emit( + task_id, + StreamResponse( + status_update=TaskStatusUpdateEvent( + task_id=task_id, + context_id=task['context_id'], + status=TaskStatus(state='failed'), + ) + ), + ) + await self.broker.event_bus.close(task_id) @abstractmethod async def run_task(self, params: TaskSendParams) -> None: ... diff --git a/tests/test_applications.py b/tests/test_applications.py index e5265ba..2385138 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -43,7 +43,7 @@ async def test_agent_card(): 'defaultInputModes': ['application/json'], 'defaultOutputModes': ['application/json'], 'capabilities': { - 'streaming': False, + 'streaming': True, 'pushNotifications': False, }, } diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 0000000..476a4e8 --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,131 @@ +from __future__ import annotations as _annotations + +import uuid +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import httpx +import pytest +from asgi_lifespan import LifespanManager + +from fasta2a.applications import FastA2A +from fasta2a.broker import InMemoryBroker +from fasta2a.client import A2AClient +from fasta2a.schema import ( + Artifact, + Message, + Part, + StreamResponse, + TaskSendParams, + TaskStatus, + TaskStatusUpdateEvent, +) +from fasta2a.storage import InMemoryStorage +from fasta2a.worker import Worker + +pytestmark = pytest.mark.anyio + + +class EchoWorker(Worker[Any]): + """A simple worker that echoes the input message as an artifact.""" + + async def run_task(self, params: TaskSendParams) -> None: + task_id = params['id'] + context_id = params['context_id'] + + # Emit a "working" status update + await self.broker.event_bus.emit( + task_id, + StreamResponse( + status_update=TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + status=TaskStatus(state='working'), + ) + ), + ) + + # Update storage to working + await self.storage.update_task(task_id, state='working') + + # Create an artifact echoing the input + input_parts = params['message']['parts'] + artifact = Artifact(artifact_id=str(uuid.uuid4()), parts=input_parts) + await self.storage.update_task(task_id, state='completed', new_artifacts=[artifact]) + + # Emit completed status + await self.broker.event_bus.emit( + task_id, + StreamResponse( + status_update=TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + status=TaskStatus(state='completed'), + ) + ), + ) + await self.broker.event_bus.close(task_id) + + async def cancel_task(self, params: Any) -> None: + pass + + def build_message_history(self, history: list[Message]) -> list[Any]: + return [] + + def build_artifacts(self, result: Any) -> list[Artifact]: + return [] + + +@asynccontextmanager +async def create_streaming_app() -> AsyncIterator[httpx.AsyncClient]: + broker = InMemoryBroker() + storage = InMemoryStorage() + worker = EchoWorker(broker=broker, storage=storage) + + app = FastA2A(storage=storage, broker=broker) + + @asynccontextmanager + async def lifespan(app: FastA2A) -> AsyncIterator[None]: + async with app.task_manager: + async with worker.run(): + yield + + app = FastA2A(storage=storage, broker=broker, lifespan=lifespan) + + async with LifespanManager(app=app) as manager: + transport = httpx.ASGITransport(app=manager.app) + async with httpx.AsyncClient(transport=transport, base_url='http://testclient') as client: + yield client + + +async def test_stream_message(): + async with create_streaming_app() as http_client: + client = A2AClient(http_client=http_client) + client.http_client.base_url = 'http://testclient' + + message = Message( + role='user', + parts=[Part(text='Hello, world!')], + message_id=str(uuid.uuid4()), + ) + + events: list[StreamResponse] = [] + async for response in client.stream_message(message): + if 'result' in response: + events.append(response['result']) + + # Should have: initial task, working status, completed status + assert len(events) == 3 + + # First event: initial task state (submitted) + assert 'task' in events[0] + assert events[0]['task']['status']['state'] == 'submitted' + + # Second event: working status update + assert 'status_update' in events[1] + assert events[1]['status_update']['status']['state'] == 'working' + + # Third event: completed status update + assert 'status_update' in events[2] + assert events[2]['status_update']['status']['state'] == 'completed'