Skip to content

Commit 76fcfc6

Browse files
author
Chojan Shang
committed
refactor: split into files
Signed-off-by: Chojan Shang <chojan.shang@vesoft.com>
1 parent 3383ef3 commit 76fcfc6

12 files changed

Lines changed: 927 additions & 669 deletions

File tree

src/acp/agent/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .connection import AgentSideConnection
2+
3+
__all__ = ["AgentSideConnection"]

src/acp/agent/connection.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from collections.abc import Callable
5+
from typing import Any
6+
7+
from ..connection import Connection, MethodHandler
8+
from ..interfaces import Agent
9+
from ..meta import CLIENT_METHODS
10+
from ..schema import (
11+
CreateTerminalRequest,
12+
CreateTerminalResponse,
13+
KillTerminalCommandRequest,
14+
KillTerminalCommandResponse,
15+
ReadTextFileRequest,
16+
ReadTextFileResponse,
17+
ReleaseTerminalRequest,
18+
ReleaseTerminalResponse,
19+
RequestPermissionRequest,
20+
RequestPermissionResponse,
21+
SessionNotification,
22+
TerminalOutputRequest,
23+
TerminalOutputResponse,
24+
WaitForTerminalExitRequest,
25+
WaitForTerminalExitResponse,
26+
WriteTextFileRequest,
27+
WriteTextFileResponse,
28+
)
29+
from ..terminal import TerminalHandle
30+
from ..utils import notify_model, request_model, request_optional_model
31+
from .handlers import dispatch_agent_method
32+
33+
__all__ = ["AgentSideConnection"]
34+
35+
_AGENT_CONNECTION_ERROR = "AgentSideConnection requires asyncio StreamWriter/StreamReader"
36+
37+
38+
class AgentSideConnection:
39+
"""Agent-side connection wrapper that dispatches JSON-RPC messages to a Client implementation."""
40+
41+
def __init__(
42+
self,
43+
to_agent: Callable[[AgentSideConnection], Agent],
44+
input_stream: Any,
45+
output_stream: Any,
46+
) -> None:
47+
agent = to_agent(self)
48+
handler = self._create_handler(agent)
49+
50+
if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader):
51+
raise TypeError(_AGENT_CONNECTION_ERROR)
52+
self._conn = Connection(handler, input_stream, output_stream)
53+
54+
def _create_handler(self, agent: Agent) -> MethodHandler:
55+
async def handler(method: str, params: Any | None, is_notification: bool) -> Any:
56+
return await dispatch_agent_method(agent, method, params, is_notification)
57+
58+
return handler
59+
60+
async def sessionUpdate(self, params: SessionNotification) -> None:
61+
await notify_model(self._conn, CLIENT_METHODS["session_update"], params)
62+
63+
async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse:
64+
return await request_model(
65+
self._conn,
66+
CLIENT_METHODS["session_request_permission"],
67+
params,
68+
RequestPermissionResponse,
69+
)
70+
71+
async def readTextFile(self, params: ReadTextFileRequest) -> ReadTextFileResponse:
72+
return await request_model(
73+
self._conn,
74+
CLIENT_METHODS["fs_read_text_file"],
75+
params,
76+
ReadTextFileResponse,
77+
)
78+
79+
async def writeTextFile(self, params: WriteTextFileRequest) -> WriteTextFileResponse | None:
80+
return await request_optional_model(
81+
self._conn,
82+
CLIENT_METHODS["fs_write_text_file"],
83+
params,
84+
WriteTextFileResponse,
85+
)
86+
87+
async def createTerminal(self, params: CreateTerminalRequest) -> TerminalHandle:
88+
create_response = await request_model(
89+
self._conn,
90+
CLIENT_METHODS["terminal_create"],
91+
params,
92+
CreateTerminalResponse,
93+
)
94+
return TerminalHandle(create_response.terminalId, params.sessionId, self._conn)
95+
96+
async def terminalOutput(self, params: TerminalOutputRequest) -> TerminalOutputResponse:
97+
return await request_model(
98+
self._conn,
99+
CLIENT_METHODS["terminal_output"],
100+
params,
101+
TerminalOutputResponse,
102+
)
103+
104+
async def releaseTerminal(self, params: ReleaseTerminalRequest) -> ReleaseTerminalResponse | None:
105+
return await request_optional_model(
106+
self._conn,
107+
CLIENT_METHODS["terminal_release"],
108+
params,
109+
ReleaseTerminalResponse,
110+
)
111+
112+
async def waitForTerminalExit(self, params: WaitForTerminalExitRequest) -> WaitForTerminalExitResponse:
113+
return await request_model(
114+
self._conn,
115+
CLIENT_METHODS["terminal_wait_for_exit"],
116+
params,
117+
WaitForTerminalExitResponse,
118+
)
119+
120+
async def killTerminal(self, params: KillTerminalCommandRequest) -> KillTerminalCommandResponse | None:
121+
return await request_optional_model(
122+
self._conn,
123+
CLIENT_METHODS["terminal_kill"],
124+
params,
125+
KillTerminalCommandResponse,
126+
)
127+
128+
async def extMethod(self, method: str, params: dict[str, Any]) -> dict[str, Any]:
129+
return await self._conn.send_request(f"_{method}", params)
130+
131+
async def extNotification(self, method: str, params: dict[str, Any]) -> None:
132+
await self._conn.send_notification(f"_{method}", params)

src/acp/agent/handlers.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
from ..exceptions import RequestError
6+
from ..interfaces import Agent
7+
from ..meta import AGENT_METHODS
8+
from ..schema import (
9+
AuthenticateRequest,
10+
CancelNotification,
11+
InitializeRequest,
12+
LoadSessionRequest,
13+
NewSessionRequest,
14+
PromptRequest,
15+
SetSessionModelRequest,
16+
SetSessionModeRequest,
17+
)
18+
from ..utils import normalize_result
19+
20+
__all__ = [
21+
"NO_MATCH",
22+
"dispatch_agent_method",
23+
]
24+
25+
26+
class _NoMatch:
27+
"""Sentinel returned by routing helpers when no handler matches."""
28+
29+
30+
NO_MATCH = _NoMatch()
31+
32+
33+
async def _handle_agent_init(agent: Agent, method: str, params: Any | None) -> Any:
34+
if method == AGENT_METHODS["initialize"]:
35+
request = InitializeRequest.model_validate(params)
36+
return await agent.initialize(request)
37+
if method == AGENT_METHODS["session_new"]:
38+
request = NewSessionRequest.model_validate(params)
39+
return await agent.newSession(request)
40+
return NO_MATCH
41+
42+
43+
async def _handle_agent_session(agent: Agent, method: str, params: Any | None) -> Any:
44+
if method == AGENT_METHODS["session_load"]:
45+
if not hasattr(agent, "loadSession"):
46+
raise RequestError.method_not_found(method)
47+
request = LoadSessionRequest.model_validate(params)
48+
result = await agent.loadSession(request)
49+
return normalize_result(result)
50+
if method == AGENT_METHODS["session_set_mode"]:
51+
if not hasattr(agent, "setSessionMode"):
52+
raise RequestError.method_not_found(method)
53+
request = SetSessionModeRequest.model_validate(params)
54+
result = await agent.setSessionMode(request)
55+
return normalize_result(result)
56+
if method == AGENT_METHODS["session_prompt"]:
57+
request = PromptRequest.model_validate(params)
58+
return await agent.prompt(request)
59+
if method == AGENT_METHODS["session_set_model"]:
60+
if not hasattr(agent, "setSessionModel"):
61+
raise RequestError.method_not_found(method)
62+
request = SetSessionModelRequest.model_validate(params)
63+
result = await agent.setSessionModel(request)
64+
return normalize_result(result)
65+
if method == AGENT_METHODS["session_cancel"]:
66+
request = CancelNotification.model_validate(params)
67+
return await agent.cancel(request)
68+
return NO_MATCH
69+
70+
71+
async def _handle_agent_auth(agent: Agent, method: str, params: Any | None) -> Any:
72+
if method == AGENT_METHODS["authenticate"]:
73+
if not hasattr(agent, "authenticate"):
74+
raise RequestError.method_not_found(method)
75+
request = AuthenticateRequest.model_validate(params)
76+
result = await agent.authenticate(request)
77+
return normalize_result(result)
78+
return NO_MATCH
79+
80+
81+
async def _handle_agent_extensions(agent: Agent, method: str, params: Any | None, is_notification: bool) -> Any:
82+
if isinstance(method, str) and method.startswith("_"):
83+
ext_name = method[1:]
84+
if is_notification:
85+
if hasattr(agent, "extNotification"):
86+
await agent.extNotification(ext_name, params or {}) # type: ignore[arg-type]
87+
return None
88+
return None
89+
if hasattr(agent, "extMethod"):
90+
return await agent.extMethod(ext_name, params or {}) # type: ignore[arg-type]
91+
return NO_MATCH
92+
return NO_MATCH
93+
94+
95+
async def dispatch_agent_method(agent: Agent, method: str, params: Any | None, is_notification: bool) -> Any:
96+
"""Dispatch agent-bound methods, mirroring the upstream ACP routing."""
97+
for resolver in (_handle_agent_init, _handle_agent_session, _handle_agent_auth):
98+
result = await resolver(agent, method, params)
99+
if result is not NO_MATCH:
100+
return result
101+
extension_result = await _handle_agent_extensions(agent, method, params, is_notification)
102+
if extension_result is not NO_MATCH:
103+
return extension_result
104+
raise RequestError.method_not_found(method)

src/acp/client/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .connection import ClientSideConnection
2+
3+
__all__ = ["ClientSideConnection"]

src/acp/client/connection.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from collections.abc import Callable
5+
from typing import Any
6+
7+
from ..connection import Connection, MethodHandler
8+
from ..interfaces import Agent, Client
9+
from ..meta import AGENT_METHODS
10+
from ..schema import (
11+
AuthenticateRequest,
12+
AuthenticateResponse,
13+
CancelNotification,
14+
InitializeRequest,
15+
InitializeResponse,
16+
LoadSessionRequest,
17+
LoadSessionResponse,
18+
NewSessionRequest,
19+
NewSessionResponse,
20+
PromptRequest,
21+
PromptResponse,
22+
SetSessionModelRequest,
23+
SetSessionModelResponse,
24+
SetSessionModeRequest,
25+
SetSessionModeResponse,
26+
)
27+
from ..utils import (
28+
notify_model,
29+
request_model,
30+
request_model_from_dict,
31+
)
32+
from .handlers import dispatch_client_method
33+
34+
__all__ = ["ClientSideConnection"]
35+
36+
_CLIENT_CONNECTION_ERROR = "ClientSideConnection requires asyncio StreamWriter/StreamReader"
37+
38+
39+
class ClientSideConnection:
40+
"""Client-side connection wrapper that dispatches JSON-RPC messages to an Agent implementation."""
41+
42+
def __init__(
43+
self,
44+
to_client: Callable[[Agent], Client],
45+
input_stream: Any,
46+
output_stream: Any,
47+
) -> None:
48+
if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader):
49+
raise TypeError(_CLIENT_CONNECTION_ERROR)
50+
51+
client = to_client(self) # type: ignore[arg-type]
52+
handler = self._create_handler(client)
53+
self._conn = Connection(handler, input_stream, output_stream)
54+
55+
def _create_handler(self, client: Client) -> MethodHandler:
56+
async def handler(method: str, params: Any | None, is_notification: bool) -> Any:
57+
return await dispatch_client_method(client, method, params, is_notification)
58+
59+
return handler
60+
61+
async def initialize(self, params: InitializeRequest) -> InitializeResponse:
62+
return await request_model(
63+
self._conn,
64+
AGENT_METHODS["initialize"],
65+
params,
66+
InitializeResponse,
67+
)
68+
69+
async def newSession(self, params: NewSessionRequest) -> NewSessionResponse:
70+
return await request_model(
71+
self._conn,
72+
AGENT_METHODS["session_new"],
73+
params,
74+
NewSessionResponse,
75+
)
76+
77+
async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse:
78+
return await request_model_from_dict(
79+
self._conn,
80+
AGENT_METHODS["session_load"],
81+
params,
82+
LoadSessionResponse,
83+
)
84+
85+
async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse:
86+
return await request_model_from_dict(
87+
self._conn,
88+
AGENT_METHODS["session_set_mode"],
89+
params,
90+
SetSessionModeResponse,
91+
)
92+
93+
async def setSessionModel(self, params: SetSessionModelRequest) -> SetSessionModelResponse:
94+
return await request_model_from_dict(
95+
self._conn,
96+
AGENT_METHODS["session_set_model"],
97+
params,
98+
SetSessionModelResponse,
99+
)
100+
101+
async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse:
102+
return await request_model_from_dict(
103+
self._conn,
104+
AGENT_METHODS["authenticate"],
105+
params,
106+
AuthenticateResponse,
107+
)
108+
109+
async def prompt(self, params: PromptRequest) -> PromptResponse:
110+
return await request_model(
111+
self._conn,
112+
AGENT_METHODS["session_prompt"],
113+
params,
114+
PromptResponse,
115+
)
116+
117+
async def cancel(self, params: CancelNotification) -> None:
118+
await notify_model(self._conn, AGENT_METHODS["session_cancel"], params)
119+
120+
async def extMethod(self, method: str, params: dict[str, Any]) -> dict[str, Any]:
121+
return await self._conn.send_request(f"_{method}", params)
122+
123+
async def extNotification(self, method: str, params: dict[str, Any]) -> None:
124+
await self._conn.send_notification(f"_{method}", params)

0 commit comments

Comments
 (0)