Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go/canvas_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ func TestCanvasRegisterClientSessionAPIHandlers_RawJSONRoundTrip(t *testing.T) {
_ = serverToClientReader.Close()
})

raw, err := requester.Request("canvas.open", map[string]any{
raw, err := requester.Request(t.Context(), "canvas.open", map[string]any{
"sessionId": "s1",
"extensionId": "ext",
"canvasId": "echo",
Expand Down Expand Up @@ -284,7 +284,7 @@ func TestCanvasRegisterClientSessionAPIHandlers_RawJSONRoundTrip(t *testing.T) {
t.Fatalf("expected status=ready, got %v", decoded["status"])
}

actionRaw, err := requester.Request("canvas.action.invoke", map[string]any{
actionRaw, err := requester.Request(t.Context(), "canvas.action.invoke", map[string]any{
"sessionId": "s1",
"extensionId": "ext",
"canvasId": "echo",
Expand Down
24 changes: 12 additions & 12 deletions go/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,7 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses
}
}

result, err := c.client.RequestWithInlineResponse("session.create", req, inlineCb)
result, err := c.client.RequestWithInlineResponse(ctx, "session.create", req, inlineCb)
if err != nil {
if registeredSessionID != "" {
c.sessionsMux.Lock()
Expand Down Expand Up @@ -1075,7 +1075,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string,
session.clientSessionAPIs.SessionFS = newSessionFSAdapter(provider)
}

result, err := c.client.Request("session.resume", req)
result, err := c.client.Request(ctx, "session.resume", req)
if err != nil {
c.sessionsMux.Lock()
delete(c.sessions, sessionID)
Expand Down Expand Up @@ -1136,7 +1136,7 @@ func (c *Client) ListSessions(ctx context.Context, filter *SessionListFilter) ([
if filter != nil {
params.Filter = filter
}
result, err := c.client.Request("session.list", params)
result, err := c.client.Request(ctx, "session.list", params)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1168,7 +1168,7 @@ func (c *Client) GetSessionMetadata(ctx context.Context, sessionID string) (*Ses
return nil, err
}

result, err := c.client.Request("session.getMetadata", getSessionMetadataRequest{SessionID: sessionID})
result, err := c.client.Request(ctx, "session.getMetadata", getSessionMetadataRequest{SessionID: sessionID})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1199,7 +1199,7 @@ func (c *Client) DeleteSession(ctx context.Context, sessionID string) error {
return err
}

result, err := c.client.Request("session.delete", deleteSessionRequest{SessionID: sessionID})
result, err := c.client.Request(ctx, "session.delete", deleteSessionRequest{SessionID: sessionID})
if err != nil {
return err
}
Expand Down Expand Up @@ -1246,7 +1246,7 @@ func (c *Client) GetLastSessionID(ctx context.Context) (*string, error) {
return nil, err
}

result, err := c.client.Request("session.getLastId", getLastSessionIDRequest{})
result, err := c.client.Request(ctx, "session.getLastId", getLastSessionIDRequest{})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1278,7 +1278,7 @@ func (c *Client) GetForegroundSessionID(ctx context.Context) (*string, error) {
return nil, err
}

result, err := c.client.Request("session.getForeground", getForegroundSessionRequest{})
result, err := c.client.Request(ctx, "session.getForeground", getForegroundSessionRequest{})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1306,7 +1306,7 @@ func (c *Client) SetForegroundSessionID(ctx context.Context, sessionID string) e
return err
}

result, err := c.client.Request("session.setForeground", setForegroundSessionRequest{SessionID: sessionID})
result, err := c.client.Request(ctx, "session.setForeground", setForegroundSessionRequest{SessionID: sessionID})
if err != nil {
return err
}
Expand Down Expand Up @@ -1446,7 +1446,7 @@ func (c *Client) Ping(ctx context.Context, message string) (*PingResponse, error
return nil, fmt.Errorf("client not connected")
}

result, err := c.client.Request("ping", pingRequest{Message: message})
result, err := c.client.Request(ctx, "ping", pingRequest{Message: message})
if err != nil {
return nil, err
}
Expand All @@ -1464,7 +1464,7 @@ func (c *Client) GetStatus(ctx context.Context) (*GetStatusResponse, error) {
return nil, fmt.Errorf("client not connected")
}

result, err := c.client.Request("status.get", getStatusRequest{})
result, err := c.client.Request(ctx, "status.get", getStatusRequest{})
if err != nil {
return nil, err
}
Expand All @@ -1482,7 +1482,7 @@ func (c *Client) GetAuthStatus(ctx context.Context) (*GetAuthStatusResponse, err
return nil, fmt.Errorf("client not connected")
}

result, err := c.client.Request("auth.getStatus", getAuthStatusRequest{})
result, err := c.client.Request(ctx, "auth.getStatus", getAuthStatusRequest{})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1523,7 +1523,7 @@ func (c *Client) ListModels(ctx context.Context) ([]ModelInfo, error) {
return nil, fmt.Errorf("client not connected")
}
// Cache miss - fetch from backend while holding lock
result, err := c.client.Request("models.list", listModelsRequest{})
result, err := c.client.Request(ctx, "models.list", listModelsRequest{})
if err != nil {
return nil, err
}
Expand Down
59 changes: 45 additions & 14 deletions go/internal/jsonrpc2/jsonrpc2.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jsonrpc2

import (
"context"
"crypto/rand"
"encoding/json"
"errors"
Expand Down Expand Up @@ -202,8 +203,8 @@ func (c *Client) SetRequestHandler(method string, handler RequestHandler) {
}

// Request sends a JSON-RPC request and waits for the response
func (c *Client) Request(method string, params any) (json.RawMessage, error) {
return c.RequestWithInlineResponse(method, params, nil)
func (c *Client) Request(ctx context.Context, method string, params any) (json.RawMessage, error) {
return c.RequestWithInlineResponse(ctx, method, params, nil)
}
Comment thread
qmuntal marked this conversation as resolved.

// RequestWithInlineResponse sends a JSON-RPC request and waits for the response,
Expand All @@ -214,7 +215,13 @@ func (c *Client) Request(method string, params any) (json.RawMessage, error) {
// server in the response) before any subsequent notification on the same
// connection is dispatched. If the callback returns an error, that error is
// returned to the awaiter in place of the response.
func (c *Client) RequestWithInlineResponse(method string, params any, onResponseInline func(json.RawMessage) error) (json.RawMessage, error) {
func (c *Client) RequestWithInlineResponse(ctx context.Context, method string, params any, onResponseInline func(json.RawMessage) error) (json.RawMessage, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
Comment thread
qmuntal marked this conversation as resolved.

requestID := generateUUID()

// Create response channel
Expand All @@ -237,6 +244,8 @@ func (c *Client) RequestWithInlineResponse(method string, params any, onResponse
// Check if process already exited before sending
if c.processDone != nil {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-c.processDone:
if err := c.getProcessError(); err != nil {
return nil, err
Expand Down Expand Up @@ -266,13 +275,18 @@ func (c *Client) RequestWithInlineResponse(method string, params any, onResponse
Params: paramsData,
}

if err := c.sendMessage(request); err != nil {
if err := c.sendMessage(ctx, request); err != nil {
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, ctxErr
}
return nil, fmt.Errorf("failed to send request: %w", err)
}

// Wait for response, also checking for process exit
if c.processDone != nil {
select {
case <-ctx.Done():
return nil, ctx.Err()
case response := <-responseChan:
if response.Error != nil {
return nil, response.Error
Expand All @@ -288,6 +302,8 @@ func (c *Client) RequestWithInlineResponse(method string, params any, onResponse
}
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case response := <-responseChan:
if response.Error != nil {
return nil, response.Error
Expand All @@ -301,13 +317,26 @@ func (c *Client) RequestWithInlineResponse(method string, params any, onResponse
// sendMessage writes a message to the stream.
// Write serialization is achieved via a 1-buffered channel that holds the
// writer when not in use, avoiding the need for a mutex on the write path.
func (c *Client) sendMessage(message any) error {
func (c *Client) sendMessage(ctx context.Context, message any) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}

data, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}

w := <-c.writer
var w *headerWriter
select {
case <-ctx.Done():
return ctx.Err()
case <-c.stopChan:
return fmt.Errorf("client stopped")
case w = <-c.writer:
}
defer func() { c.writer <- w }()
return w.Write(data)
}
Expand Down Expand Up @@ -402,13 +431,15 @@ func (c *Client) handleResponse(response *Response) {
}

func (c *Client) handleRequest(request *Request) {
ctx := context.Background()

c.mu.Lock()
handler := c.requestHandlers[request.Method]
c.mu.Unlock()

if handler == nil {
if request.IsCall() {
c.sendErrorResponse(request.ID, &Error{
c.sendErrorResponse(ctx, request.ID, &Error{
Code: ErrMethodNotFound.Code,
Message: fmt.Sprintf("Method not found: %s", request.Method),
})
Expand All @@ -425,7 +456,7 @@ func (c *Client) handleRequest(request *Request) {
go func() {
defer func() {
if r := recover(); r != nil {
c.sendErrorResponse(request.ID, &Error{
c.sendErrorResponse(ctx, request.ID, &Error{
Code: ErrInternal.Code,
Message: fmt.Sprintf("request handler panic: %v", r),
})
Expand All @@ -434,31 +465,31 @@ func (c *Client) handleRequest(request *Request) {

result, err := handler(request.Params)
if err != nil {
c.sendErrorResponse(request.ID, err)
c.sendErrorResponse(ctx, request.ID, err)
return
}
c.sendResponse(request.ID, result)
c.sendResponse(ctx, request.ID, result)
}()
}

func (c *Client) sendResponse(id json.RawMessage, result json.RawMessage) {
func (c *Client) sendResponse(ctx context.Context, id json.RawMessage, result json.RawMessage) {
response := Response{
JSONRPC: version,
ID: id,
Result: result,
}
if err := c.sendMessage(response); err != nil {
if err := c.sendMessage(ctx, response); err != nil {
fmt.Printf("Failed to send JSON-RPC response: %v\n", err)
}
}

func (c *Client) sendErrorResponse(id json.RawMessage, rpcErr *Error) {
func (c *Client) sendErrorResponse(ctx context.Context, id json.RawMessage, rpcErr *Error) {
response := Response{
JSONRPC: version,
ID: id,
Error: rpcErr,
}
if err := c.sendMessage(response); err != nil {
if err := c.sendMessage(ctx, response); err != nil {
fmt.Printf("Failed to send JSON-RPC error response: %v\n", err)
}
}
Expand Down
Loading
Loading