Skip to content

Commit fa007f9

Browse files
author
Chojan Shang
committed
feat: intro task module for better handle
Signed-off-by: Chojan Shang <chojan.shang@vesoft.com>
1 parent 76fcfc6 commit fa007f9

7 files changed

Lines changed: 545 additions & 56 deletions

File tree

src/acp/connection.py

Lines changed: 109 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,26 @@
55
import json
66
import logging
77
from collections.abc import Awaitable, Callable
8-
from dataclasses import dataclass
98
from typing import Any
109

1110
from pydantic import BaseModel, ValidationError
1211

1312
from .exceptions import RequestError
13+
from .task import (
14+
DefaultMessageDispatcher,
15+
InMemoryMessageQueue,
16+
InMemoryMessageStateStore,
17+
MessageDispatcher,
18+
MessageQueue,
19+
MessageSender,
20+
MessageStateStore,
21+
NotificationRunner,
22+
RequestRunner,
23+
RpcTask,
24+
RpcTaskKind,
25+
SenderFactory,
26+
TaskSupervisor,
27+
)
1428

1529
JsonValue = Any
1630
MethodHandler = Callable[[str, JsonValue | None, bool], Awaitable[JsonValue | None]]
@@ -19,9 +33,10 @@
1933
__all__ = ["Connection", "JsonValue", "MethodHandler"]
2034

2135

22-
@dataclass(slots=True)
23-
class _Pending:
24-
future: asyncio.Future[Any]
36+
DispatcherFactory = Callable[
37+
[MessageQueue, TaskSupervisor, MessageStateStore, RequestRunner, NotificationRunner],
38+
MessageDispatcher,
39+
]
2540

2641

2742
class Connection:
@@ -32,42 +47,64 @@ def __init__(
3247
handler: MethodHandler,
3348
writer: asyncio.StreamWriter,
3449
reader: asyncio.StreamReader,
50+
*,
51+
queue: MessageQueue | None = None,
52+
state_store: MessageStateStore | None = None,
53+
dispatcher_factory: DispatcherFactory | None = None,
54+
sender_factory: SenderFactory | None = None,
3555
) -> None:
3656
self._handler = handler
3757
self._writer = writer
3858
self._reader = reader
3959
self._next_request_id = 0
40-
self._pending: dict[int, _Pending] = {}
41-
self._inflight: set[asyncio.Task[Any]] = set()
42-
self._write_lock = asyncio.Lock()
43-
self._recv_task = asyncio.create_task(self._receive_loop())
60+
self._state = state_store or InMemoryMessageStateStore()
61+
self._tasks = TaskSupervisor(source="acp.Connection")
62+
self._tasks.add_error_handler(self._on_task_error)
63+
self._queue = queue or InMemoryMessageQueue()
64+
self._closed = False
65+
self._sender = (sender_factory or self._default_sender_factory)(self._writer, self._tasks)
66+
self._recv_task = self._tasks.create(
67+
self._receive_loop(),
68+
name="acp.Connection.receive",
69+
on_error=self._on_receive_error,
70+
)
71+
dispatcher_factory = dispatcher_factory or self._default_dispatcher_factory
72+
self._dispatcher = dispatcher_factory(
73+
self._queue,
74+
self._tasks,
75+
self._state,
76+
self._run_request,
77+
self._run_notification,
78+
)
79+
self._dispatcher.start()
4480

4581
async def close(self) -> None:
4682
"""Stop the receive loop and cancel any in-flight handler tasks."""
47-
if not self._recv_task.done():
48-
self._recv_task.cancel()
49-
with contextlib.suppress(asyncio.CancelledError):
50-
await self._recv_task
51-
if self._inflight:
52-
tasks = list(self._inflight)
53-
for task in tasks:
54-
task.cancel()
55-
for task in tasks:
56-
with contextlib.suppress(asyncio.CancelledError):
57-
await task
83+
if self._closed:
84+
return
85+
self._closed = True
86+
await self._dispatcher.stop()
87+
await self._sender.close()
88+
await self._tasks.shutdown()
89+
self._state.reject_all_outgoing(ConnectionError("Connection closed"))
90+
91+
async def __aenter__(self) -> Connection:
92+
return self
93+
94+
async def __aexit__(self, exc_type, exc, tb) -> None:
95+
await self.close()
5896

5997
async def send_request(self, method: str, params: JsonValue | None = None) -> Any:
6098
request_id = self._next_request_id
6199
self._next_request_id += 1
62-
future: asyncio.Future[Any] = asyncio.get_running_loop().create_future()
63-
self._pending[request_id] = _Pending(future)
100+
future = self._state.register_outgoing(request_id, method)
64101
payload = {"jsonrpc": "2.0", "id": request_id, "method": method, "params": params}
65-
await self._send_obj(payload)
102+
await self._sender.send(payload)
66103
return await future
67104

68105
async def send_notification(self, method: str, params: JsonValue | None = None) -> None:
69106
payload = {"jsonrpc": "2.0", "method": method, "params": params}
70-
await self._send_obj(payload)
107+
await self._sender.send(payload)
71108

72109
async def _receive_loop(self) -> None:
73110
try:
@@ -88,71 +125,87 @@ async def _process_message(self, message: dict[str, Any]) -> None:
88125
method = message.get("method")
89126
has_id = "id" in message
90127
if method is not None and has_id:
91-
self._schedule(self._handle_request(message))
128+
await self._queue.publish(RpcTask(RpcTaskKind.REQUEST, message))
92129
return
93130
if method is not None and not has_id:
94-
await self._handle_notification(message)
131+
await self._queue.publish(RpcTask(RpcTaskKind.NOTIFICATION, message))
95132
return
96133
if has_id:
97134
await self._handle_response(message)
98135

99-
def _schedule(self, coroutine: Awaitable[Any]) -> None:
100-
task = asyncio.create_task(coroutine)
101-
self._inflight.add(task)
102-
task.add_done_callback(self._task_done)
103-
104-
def _task_done(self, task: asyncio.Task[Any]) -> None:
105-
self._inflight.discard(task)
106-
if task.cancelled():
107-
return
108-
with contextlib.suppress(Exception):
109-
task.result()
110-
111-
async def _handle_request(self, message: dict[str, Any]) -> None:
136+
async def _run_request(self, message: dict[str, Any]) -> Any:
112137
payload: dict[str, Any] = {"jsonrpc": "2.0", "id": message["id"]}
113138
try:
114139
result = await self._handler(message["method"], message.get("params"), False)
115140
if isinstance(result, BaseModel):
116141
result = result.model_dump()
117142
payload["result"] = result if result is not None else None
143+
await self._sender.send(payload)
144+
return payload.get("result")
118145
except RequestError as exc:
119146
payload["error"] = exc.to_error_obj()
147+
await self._sender.send(payload)
148+
raise
120149
except ValidationError as exc:
121-
payload["error"] = RequestError.invalid_params({"errors": exc.errors()}).to_error_obj()
150+
err = RequestError.invalid_params({"errors": exc.errors()})
151+
payload["error"] = err.to_error_obj()
152+
await self._sender.send(payload)
153+
raise err from None
122154
except Exception as exc:
123155
try:
124156
data = json.loads(str(exc))
125157
except Exception:
126158
data = {"details": str(exc)}
127-
payload["error"] = RequestError.internal_error(data).to_error_obj()
128-
await self._send_obj(payload)
159+
err = RequestError.internal_error(data)
160+
payload["error"] = err.to_error_obj()
161+
await self._sender.send(payload)
162+
raise err from None
129163

130-
async def _handle_notification(self, message: dict[str, Any]) -> None:
164+
async def _run_notification(self, message: dict[str, Any]) -> None:
131165
with contextlib.suppress(Exception):
132166
await self._handler(message["method"], message.get("params"), True)
133167

134168
async def _handle_response(self, message: dict[str, Any]) -> None:
135-
pending = self._pending.pop(message["id"], None)
136-
if pending is None:
137-
return
169+
request_id = message["id"]
170+
result = message.get("result")
138171
if "result" in message:
139-
pending.future.set_result(message.get("result"))
172+
self._state.resolve_outgoing(request_id, result)
140173
return
141174
if "error" in message:
142175
error_obj = message.get("error") or {}
143-
pending.future.set_exception(
176+
self._state.reject_outgoing(
177+
request_id,
144178
RequestError(
145179
error_obj.get("code", -32603),
146180
error_obj.get("message", "Error"),
147181
error_obj.get("data"),
148-
)
182+
),
149183
)
150184
return
151-
pending.future.set_result(None)
152-
153-
async def _send_obj(self, payload: dict[str, Any]) -> None:
154-
data = (json.dumps(payload, separators=(",", ":")) + "\n").encode("utf-8")
155-
async with self._write_lock:
156-
self._writer.write(data)
157-
with contextlib.suppress(ConnectionError, RuntimeError):
158-
await self._writer.drain()
185+
self._state.resolve_outgoing(request_id, None)
186+
187+
def _on_receive_error(self, task: asyncio.Task[Any], exc: BaseException) -> None:
188+
logging.exception("Receive loop failed", exc_info=exc)
189+
self._state.reject_all_outgoing(exc)
190+
191+
def _on_task_error(self, task: asyncio.Task[Any], exc: BaseException) -> None:
192+
logging.exception("Background task failed", exc_info=exc)
193+
194+
def _default_dispatcher_factory(
195+
self,
196+
queue: MessageQueue,
197+
supervisor: TaskSupervisor,
198+
state: MessageStateStore,
199+
request_runner: RequestRunner,
200+
notification_runner: NotificationRunner,
201+
) -> MessageDispatcher:
202+
return DefaultMessageDispatcher(
203+
queue=queue,
204+
supervisor=supervisor,
205+
store=state,
206+
request_runner=request_runner,
207+
notification_runner=notification_runner,
208+
)
209+
210+
def _default_sender_factory(self, writer: asyncio.StreamWriter, supervisor: TaskSupervisor) -> MessageSender:
211+
return MessageSender(writer, supervisor)

src/acp/task/__init__.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from enum import Enum
5+
from typing import Any
6+
7+
__all__ = ["RpcTask", "RpcTaskKind"]
8+
9+
10+
class RpcTaskKind(Enum):
11+
REQUEST = "request"
12+
NOTIFICATION = "notification"
13+
14+
15+
@dataclass(slots=True)
16+
class RpcTask:
17+
kind: RpcTaskKind
18+
message: dict[str, Any]
19+
20+
21+
from .dispatcher import ( # noqa: E402
22+
DefaultMessageDispatcher,
23+
MessageDispatcher,
24+
NotificationRunner,
25+
RequestRunner,
26+
)
27+
from .queue import InMemoryMessageQueue, MessageQueue # noqa: E402
28+
from .sender import MessageSender, SenderFactory # noqa: E402
29+
from .state import InMemoryMessageStateStore, MessageStateStore # noqa: E402
30+
from .supervisor import TaskSupervisor # noqa: E402
31+
32+
__all__ += [
33+
"DefaultMessageDispatcher",
34+
"InMemoryMessageQueue",
35+
"InMemoryMessageStateStore",
36+
"MessageDispatcher",
37+
"MessageQueue",
38+
"MessageSender",
39+
"MessageStateStore",
40+
"NotificationRunner",
41+
"RequestRunner",
42+
"SenderFactory",
43+
"TaskSupervisor",
44+
]

src/acp/task/dispatcher.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from collections.abc import Awaitable, Callable
5+
from contextlib import suppress
6+
from typing import Any, Protocol
7+
8+
from . import RpcTaskKind
9+
from .queue import MessageQueue
10+
from .state import MessageStateStore
11+
from .supervisor import TaskSupervisor
12+
13+
__all__ = [
14+
"DefaultMessageDispatcher",
15+
"MessageDispatcher",
16+
"NotificationRunner",
17+
"RequestRunner",
18+
]
19+
20+
21+
RequestRunner = Callable[[dict[str, Any]], Awaitable[Any]]
22+
NotificationRunner = Callable[[dict[str, Any]], Awaitable[None]]
23+
24+
25+
class MessageDispatcher(Protocol):
26+
def start(self) -> None: ...
27+
28+
async def stop(self) -> None: ...
29+
30+
31+
class DefaultMessageDispatcher(MessageDispatcher):
32+
"""Background worker that consumes RPC tasks from a broker, coordinating with the store."""
33+
34+
def __init__(
35+
self,
36+
*,
37+
queue: MessageQueue,
38+
supervisor: TaskSupervisor,
39+
store: MessageStateStore,
40+
request_runner: RequestRunner,
41+
notification_runner: NotificationRunner,
42+
) -> None:
43+
self._queue = queue
44+
self._supervisor = supervisor
45+
self._store = store
46+
self._request_runner = request_runner
47+
self._notification_runner = notification_runner
48+
self._task: asyncio.Task[None] | None = None
49+
50+
def start(self) -> None:
51+
if self._task is not None:
52+
msg = "dispatcher already started"
53+
raise RuntimeError(msg)
54+
self._task = self._supervisor.create(self._run(), name="acp.Dispatcher.loop")
55+
56+
async def _run(self) -> None:
57+
try:
58+
async for task in self._queue:
59+
try:
60+
if task.kind is RpcTaskKind.REQUEST:
61+
await self._dispatch_request(task.message)
62+
else:
63+
await self._dispatch_notification(task.message)
64+
finally:
65+
self._queue.task_done()
66+
except asyncio.CancelledError:
67+
return
68+
69+
async def stop(self) -> None:
70+
await self._queue.close()
71+
if self._task is not None:
72+
with suppress(asyncio.CancelledError):
73+
await self._task
74+
self._task = None
75+
76+
async def _dispatch_request(self, message: dict[str, Any]) -> None:
77+
record = self._store.begin_incoming(message.get("method", ""), message.get("params"))
78+
79+
async def runner() -> None:
80+
try:
81+
result = await self._request_runner(message)
82+
except Exception as exc:
83+
self._store.fail_incoming(record, exc)
84+
raise
85+
else:
86+
self._store.complete_incoming(record, result)
87+
88+
self._supervisor.create(runner(), name="acp.Dispatcher.request")
89+
90+
async def _dispatch_notification(self, message: dict[str, Any]) -> None:
91+
async def runner() -> None:
92+
await self._notification_runner(message)
93+
94+
self._supervisor.create(runner(), name="acp.Dispatcher.notification")

0 commit comments

Comments
 (0)