55import json
66import logging
77from collections .abc import Awaitable , Callable
8- from dataclasses import dataclass
98from typing import Any
109
1110from pydantic import BaseModel , ValidationError
1211
1312from .exceptions import RequestError
13+ from .task import (
14+ DefaultMessageDispatcher ,
15+ InMemoryMessageQueue ,
16+ InMemoryMessageStateStore ,
17+ MessageDispatcher ,
18+ MessageQueue ,
19+ MessageSender ,
20+ MessageStateStore ,
21+ NotificationRunner ,
22+ RequestRunner ,
23+ RpcTask ,
24+ RpcTaskKind ,
25+ SenderFactory ,
26+ TaskSupervisor ,
27+ )
1428
1529JsonValue = Any
1630MethodHandler = Callable [[str , JsonValue | None , bool ], Awaitable [JsonValue | None ]]
1933__all__ = ["Connection" , "JsonValue" , "MethodHandler" ]
2034
2135
22- @dataclass (slots = True )
23- class _Pending :
24- future : asyncio .Future [Any ]
36+ DispatcherFactory = Callable [
37+ [MessageQueue , TaskSupervisor , MessageStateStore , RequestRunner , NotificationRunner ],
38+ MessageDispatcher ,
39+ ]
2540
2641
2742class Connection :
@@ -32,42 +47,64 @@ def __init__(
3247 handler : MethodHandler ,
3348 writer : asyncio .StreamWriter ,
3449 reader : asyncio .StreamReader ,
50+ * ,
51+ queue : MessageQueue | None = None ,
52+ state_store : MessageStateStore | None = None ,
53+ dispatcher_factory : DispatcherFactory | None = None ,
54+ sender_factory : SenderFactory | None = None ,
3555 ) -> None :
3656 self ._handler = handler
3757 self ._writer = writer
3858 self ._reader = reader
3959 self ._next_request_id = 0
40- self ._pending : dict [int , _Pending ] = {}
41- self ._inflight : set [asyncio .Task [Any ]] = set ()
42- self ._write_lock = asyncio .Lock ()
43- self ._recv_task = asyncio .create_task (self ._receive_loop ())
60+ self ._state = state_store or InMemoryMessageStateStore ()
61+ self ._tasks = TaskSupervisor (source = "acp.Connection" )
62+ self ._tasks .add_error_handler (self ._on_task_error )
63+ self ._queue = queue or InMemoryMessageQueue ()
64+ self ._closed = False
65+ self ._sender = (sender_factory or self ._default_sender_factory )(self ._writer , self ._tasks )
66+ self ._recv_task = self ._tasks .create (
67+ self ._receive_loop (),
68+ name = "acp.Connection.receive" ,
69+ on_error = self ._on_receive_error ,
70+ )
71+ dispatcher_factory = dispatcher_factory or self ._default_dispatcher_factory
72+ self ._dispatcher = dispatcher_factory (
73+ self ._queue ,
74+ self ._tasks ,
75+ self ._state ,
76+ self ._run_request ,
77+ self ._run_notification ,
78+ )
79+ self ._dispatcher .start ()
4480
4581 async def close (self ) -> None :
4682 """Stop the receive loop and cancel any in-flight handler tasks."""
47- if not self ._recv_task .done ():
48- self ._recv_task .cancel ()
49- with contextlib .suppress (asyncio .CancelledError ):
50- await self ._recv_task
51- if self ._inflight :
52- tasks = list (self ._inflight )
53- for task in tasks :
54- task .cancel ()
55- for task in tasks :
56- with contextlib .suppress (asyncio .CancelledError ):
57- await task
83+ if self ._closed :
84+ return
85+ self ._closed = True
86+ await self ._dispatcher .stop ()
87+ await self ._sender .close ()
88+ await self ._tasks .shutdown ()
89+ self ._state .reject_all_outgoing (ConnectionError ("Connection closed" ))
90+
91+ async def __aenter__ (self ) -> Connection :
92+ return self
93+
94+ async def __aexit__ (self , exc_type , exc , tb ) -> None :
95+ await self .close ()
5896
5997 async def send_request (self , method : str , params : JsonValue | None = None ) -> Any :
6098 request_id = self ._next_request_id
6199 self ._next_request_id += 1
62- future : asyncio .Future [Any ] = asyncio .get_running_loop ().create_future ()
63- self ._pending [request_id ] = _Pending (future )
100+ future = self ._state .register_outgoing (request_id , method )
64101 payload = {"jsonrpc" : "2.0" , "id" : request_id , "method" : method , "params" : params }
65- await self ._send_obj (payload )
102+ await self ._sender . send (payload )
66103 return await future
67104
68105 async def send_notification (self , method : str , params : JsonValue | None = None ) -> None :
69106 payload = {"jsonrpc" : "2.0" , "method" : method , "params" : params }
70- await self ._send_obj (payload )
107+ await self ._sender . send (payload )
71108
72109 async def _receive_loop (self ) -> None :
73110 try :
@@ -88,71 +125,87 @@ async def _process_message(self, message: dict[str, Any]) -> None:
88125 method = message .get ("method" )
89126 has_id = "id" in message
90127 if method is not None and has_id :
91- self ._schedule ( self . _handle_request ( message ))
128+ await self ._queue . publish ( RpcTask ( RpcTaskKind . REQUEST , message ))
92129 return
93130 if method is not None and not has_id :
94- await self ._handle_notification ( message )
131+ await self ._queue . publish ( RpcTask ( RpcTaskKind . NOTIFICATION , message ) )
95132 return
96133 if has_id :
97134 await self ._handle_response (message )
98135
99- def _schedule (self , coroutine : Awaitable [Any ]) -> None :
100- task = asyncio .create_task (coroutine )
101- self ._inflight .add (task )
102- task .add_done_callback (self ._task_done )
103-
104- def _task_done (self , task : asyncio .Task [Any ]) -> None :
105- self ._inflight .discard (task )
106- if task .cancelled ():
107- return
108- with contextlib .suppress (Exception ):
109- task .result ()
110-
111- async def _handle_request (self , message : dict [str , Any ]) -> None :
136+ async def _run_request (self , message : dict [str , Any ]) -> Any :
112137 payload : dict [str , Any ] = {"jsonrpc" : "2.0" , "id" : message ["id" ]}
113138 try :
114139 result = await self ._handler (message ["method" ], message .get ("params" ), False )
115140 if isinstance (result , BaseModel ):
116141 result = result .model_dump ()
117142 payload ["result" ] = result if result is not None else None
143+ await self ._sender .send (payload )
144+ return payload .get ("result" )
118145 except RequestError as exc :
119146 payload ["error" ] = exc .to_error_obj ()
147+ await self ._sender .send (payload )
148+ raise
120149 except ValidationError as exc :
121- payload ["error" ] = RequestError .invalid_params ({"errors" : exc .errors ()}).to_error_obj ()
150+ err = RequestError .invalid_params ({"errors" : exc .errors ()})
151+ payload ["error" ] = err .to_error_obj ()
152+ await self ._sender .send (payload )
153+ raise err from None
122154 except Exception as exc :
123155 try :
124156 data = json .loads (str (exc ))
125157 except Exception :
126158 data = {"details" : str (exc )}
127- payload ["error" ] = RequestError .internal_error (data ).to_error_obj ()
128- await self ._send_obj (payload )
159+ err = RequestError .internal_error (data )
160+ payload ["error" ] = err .to_error_obj ()
161+ await self ._sender .send (payload )
162+ raise err from None
129163
130- async def _handle_notification (self , message : dict [str , Any ]) -> None :
164+ async def _run_notification (self , message : dict [str , Any ]) -> None :
131165 with contextlib .suppress (Exception ):
132166 await self ._handler (message ["method" ], message .get ("params" ), True )
133167
134168 async def _handle_response (self , message : dict [str , Any ]) -> None :
135- pending = self ._pending .pop (message ["id" ], None )
136- if pending is None :
137- return
169+ request_id = message ["id" ]
170+ result = message .get ("result" )
138171 if "result" in message :
139- pending . future . set_result ( message . get ( " result" ) )
172+ self . _state . resolve_outgoing ( request_id , result )
140173 return
141174 if "error" in message :
142175 error_obj = message .get ("error" ) or {}
143- pending .future .set_exception (
176+ self ._state .reject_outgoing (
177+ request_id ,
144178 RequestError (
145179 error_obj .get ("code" , - 32603 ),
146180 error_obj .get ("message" , "Error" ),
147181 error_obj .get ("data" ),
148- )
182+ ),
149183 )
150184 return
151- pending .future .set_result (None )
152-
153- async def _send_obj (self , payload : dict [str , Any ]) -> None :
154- data = (json .dumps (payload , separators = ("," , ":" )) + "\n " ).encode ("utf-8" )
155- async with self ._write_lock :
156- self ._writer .write (data )
157- with contextlib .suppress (ConnectionError , RuntimeError ):
158- await self ._writer .drain ()
185+ self ._state .resolve_outgoing (request_id , None )
186+
187+ def _on_receive_error (self , task : asyncio .Task [Any ], exc : BaseException ) -> None :
188+ logging .exception ("Receive loop failed" , exc_info = exc )
189+ self ._state .reject_all_outgoing (exc )
190+
191+ def _on_task_error (self , task : asyncio .Task [Any ], exc : BaseException ) -> None :
192+ logging .exception ("Background task failed" , exc_info = exc )
193+
194+ def _default_dispatcher_factory (
195+ self ,
196+ queue : MessageQueue ,
197+ supervisor : TaskSupervisor ,
198+ state : MessageStateStore ,
199+ request_runner : RequestRunner ,
200+ notification_runner : NotificationRunner ,
201+ ) -> MessageDispatcher :
202+ return DefaultMessageDispatcher (
203+ queue = queue ,
204+ supervisor = supervisor ,
205+ store = state ,
206+ request_runner = request_runner ,
207+ notification_runner = notification_runner ,
208+ )
209+
210+ def _default_sender_factory (self , writer : asyncio .StreamWriter , supervisor : TaskSupervisor ) -> MessageSender :
211+ return MessageSender (writer , supervisor )
0 commit comments