diff --git a/hack/values-dev.yaml b/hack/values-dev.yaml new file mode 100644 index 00000000..5f95b780 --- /dev/null +++ b/hack/values-dev.yaml @@ -0,0 +1,4 @@ +namespace: paperdebugger-dev + +mongo: + in_cluster: false diff --git a/internal/api/gin.go b/internal/api/gin.go index c9e0d577..68abbee4 100644 --- a/internal/api/gin.go +++ b/internal/api/gin.go @@ -21,7 +21,7 @@ func NewGinServer(cfg *cfg.Cfg, oauthHandler *auth.OAuthHandler) *GinServer { ginServer := &GinServer{Engine: gin.New(), cfg: cfg} ginServer.Use(ginServer.ginLogMiddleware(), gin.Recovery()) ginServer.Use(cors.New(cors.Config{ - AllowOrigins: []string{"https://overleaf.com", "https://*.overleaf.com", "https://*.paperdebugger.com", "http://localhost:3000", "http://127.0.0.1:3000"}, + AllowOrigins: []string{"*"}, AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}, AllowHeaders: []string{"*"}, ExposeHeaders: []string{"*"}, diff --git a/internal/services/toolkit/client/client.go b/internal/services/toolkit/client/client.go index b2ea7acc..2e8b091d 100644 --- a/internal/services/toolkit/client/client.go +++ b/internal/services/toolkit/client/client.go @@ -9,7 +9,7 @@ import ( "paperdebugger/internal/services" "paperdebugger/internal/services/toolkit/handler" "paperdebugger/internal/services/toolkit/registry" - "paperdebugger/internal/services/toolkit/tools" + "paperdebugger/internal/services/toolkit/tools/xtramcp" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" @@ -42,18 +42,37 @@ func NewAIClient( option.WithAPIKey(cfg.OpenAIAPIKey), ) CheckOpenAIWorks(oaiClient, logger) - - toolPaperScore := tools.NewPaperScoreTool(db, projectService) - toolPaperScoreComment := tools.NewPaperScoreCommentTool(db, projectService, reverseCommentService) + // toolPaperScore := tools.NewPaperScoreTool(db, projectService) + // toolPaperScoreComment := tools.NewPaperScoreCommentTool(db, projectService, reverseCommentService) toolRegistry := registry.NewToolRegistry() - toolRegistry.Register("always_exception", tools.AlwaysExceptionToolDescription, tools.AlwaysExceptionTool) - toolRegistry.Register("greeting", tools.GreetingToolDescription, tools.GreetingTool) - toolRegistry.Register("paper_score", toolPaperScore.Description, toolPaperScore.Call) - toolRegistry.Register("paper_score_comment", toolPaperScoreComment.Description, toolPaperScoreComment.Call) - toolCallHandler := handler.NewToolCallHandler(toolRegistry) + // toolRegistry.Register("always_exception", tools.AlwaysExceptionToolDescription, tools.AlwaysExceptionTool) + // toolRegistry.Register("greeting", tools.GreetingToolDescription, tools.GreetingTool) + // toolRegistry.Register("paper_score", toolPaperScore.Description, toolPaperScore.Call) + // toolRegistry.Register("paper_score_comment", toolPaperScoreComment.Description, toolPaperScoreComment.Call) + + // Load tools dynamically from backend (TODO: Make URL configurable / Xtramcp url) + xtraMCPLoader := xtramcp.NewXtraMCPLoader(db, projectService, "http://localhost:8080/mcp") + + // initialize MCP session first and log session ID + sessionID, err := xtraMCPLoader.InitializeMCP() + if err != nil { + logger.Errorf("[AI Client] Failed to initialize XtraMCP session: %v", err) + // TODO: Fallback to static tools or exit? + } else { + logger.Info("[AI Client] XtraMCP session initialized", "sessionID", sessionID) + + // dynamically load all tools from XtraMCP backend + err = xtraMCPLoader.LoadToolsFromBackend(toolRegistry) + if err != nil { + logger.Errorf("[AI Client] Failed to load XtraMCP tools: %v", err) + } else { + logger.Info("[AI Client] Successfully loaded XtraMCP tools") + } + } + toolCallHandler := handler.NewToolCallHandler(toolRegistry) client := &AIClient{ openaiClient: &oaiClient, toolCallHandler: toolCallHandler, diff --git a/internal/services/toolkit/tools/xtramcp/loader.go b/internal/services/toolkit/tools/xtramcp/loader.go new file mode 100644 index 00000000..974ff71b --- /dev/null +++ b/internal/services/toolkit/tools/xtramcp/loader.go @@ -0,0 +1,221 @@ +package xtramcp + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "paperdebugger/internal/libs/db" + "paperdebugger/internal/services" + "paperdebugger/internal/services/toolkit/registry" +) + +// MCPListToolsResponse represents the JSON-RPC response from tools/list method +type MCPListToolsResponse struct { + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Result struct { + Tools []ToolSchema `json:"tools"` + } `json:"result"` +} + +// loads tools dynamically from backend +type XtraMCPLoader struct { + db *db.DB + projectService *services.ProjectService + baseURL string + client *http.Client + sessionID string // Store the MCP session ID after initialization for re-use +} + +// NewXtraMCPLoader creates a new dynamic XtraMCP loader +func NewXtraMCPLoader(db *db.DB, projectService *services.ProjectService, baseURL string) *XtraMCPLoader { + return &XtraMCPLoader{ + db: db, + projectService: projectService, + baseURL: baseURL, + client: &http.Client{}, + } +} + +// LoadToolsFromBackend fetches tool schemas from backend and registers them +func (loader *XtraMCPLoader) LoadToolsFromBackend(toolRegistry *registry.ToolRegistry) error { + if loader.sessionID == "" { + return fmt.Errorf("MCP session not initialized - call InitializeMCP first") + } + + // Fetch tools from backend using the established session + toolSchemas, err := loader.fetchAvailableTools() + if err != nil { + return fmt.Errorf("failed to fetch tools from backend: %w", err) + } + + // Register each tool dynamically, passing the session ID + for _, toolSchema := range toolSchemas { + dynamicTool := NewDynamicTool(loader.db, loader.projectService, toolSchema, loader.baseURL, loader.sessionID) + + // Register the tool with the registry + toolRegistry.Register(toolSchema.Name, dynamicTool.Description, dynamicTool.Call) + + fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name) + } + + return nil +} + +// InitializeMCP performs the full MCP initialization handshake, stores session ID, and returns it +func (loader *XtraMCPLoader) InitializeMCP() (string, error) { + // Step 1: Initialize + sessionID, err := loader.performInitialize() + if err != nil { + return "", fmt.Errorf("step 1 - initialize failed: %w", err) + } + + // Step 2: Send notifications/initialized + err = loader.sendInitializedNotification(sessionID) + if err != nil { + return "", fmt.Errorf("step 2 - notifications/initialized failed: %w", err) + } + + // Store session ID for future use and return it + loader.sessionID = sessionID + + return sessionID, nil +} + +// performInitialize performs MCP initialization (1. establish connection) +func (loader *XtraMCPLoader) performInitialize() (string, error) { + initReq := map[string]interface{}{ + "jsonrpc": "2.0", + "method": "initialize", + "id": 1, + "params": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "clientInfo": map[string]interface{}{ + "name": "paperdebugger-client", + "version": "1.0.0", + }, + }, + } + + jsonData, err := json.Marshal(initReq) + if err != nil { + return "", fmt.Errorf("failed to marshal initialize request: %w", err) + } + + req, err := http.NewRequest("POST", loader.baseURL, bytes.NewBuffer(jsonData)) + if err != nil { + return "", fmt.Errorf("failed to create initialize request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + + resp, err := loader.client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to make initialize request: %w", err) + } + defer resp.Body.Close() + + // Extract session ID from response headers + sessionID := resp.Header.Get("mcp-session-id") + if sessionID == "" { + return "", fmt.Errorf("no session ID returned from initialize") + } + + return sessionID, nil +} + +// sendInitializedNotification completes MCP initialization (acknowledges initialization) +func (loader *XtraMCPLoader) sendInitializedNotification(sessionID string) error { + notifyReq := map[string]interface{}{ + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": map[string]interface{}{}, + } + + jsonData, err := json.Marshal(notifyReq) + if err != nil { + return fmt.Errorf("failed to marshal notification: %w", err) + } + + req, err := http.NewRequest("POST", loader.baseURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create notification request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("mcp-session-id", sessionID) + + resp, err := loader.client.Do(req) + if err != nil { + return fmt.Errorf("failed to send notification: %w", err) + } + defer resp.Body.Close() + + return nil +} + +// fetchAvailableTools makes a request to get available tools from backend +func (loader *XtraMCPLoader) fetchAvailableTools() ([]ToolSchema, error) { + // List all tools using the established session + requestBody := map[string]interface{}{ + "jsonrpc": "2.0", + "method": "tools/list", + "params": map[string]interface{}{}, + "id": 2, + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", loader.baseURL, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("mcp-session-id", loader.sessionID) + + resp, err := loader.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() + + // Read the raw response body (SSE format) for debugging + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + // Parse SSE format - extract JSON from "data: " lines + lines := strings.Split(string(bodyBytes), "\n") + var extractedJSON string + for _, line := range lines { + if strings.HasPrefix(line, "data: ") { + extractedJSON = strings.TrimPrefix(line, "data: ") + break + } + } + + if extractedJSON == "" { + return nil, fmt.Errorf("no data line found in SSE response") + } + + // Parse the extracted JSON + var mcpResponse MCPListToolsResponse + err = json.Unmarshal([]byte(extractedJSON), &mcpResponse) + if err != nil { + return nil, fmt.Errorf("failed to parse JSON from SSE data: %w. JSON data: %s", err, extractedJSON) + } + + return mcpResponse.Result.Tools, nil +} diff --git a/internal/services/toolkit/tools/xtramcp/tool.go b/internal/services/toolkit/tools/xtramcp/tool.go new file mode 100644 index 00000000..12156837 --- /dev/null +++ b/internal/services/toolkit/tools/xtramcp/tool.go @@ -0,0 +1,170 @@ +package xtramcp + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "paperdebugger/internal/libs/db" + "paperdebugger/internal/services" + toolCallRecordDB "paperdebugger/internal/services/toolkit/db" + "strings" + "time" + + "github.com/openai/openai-go/v2" + "github.com/openai/openai-go/v2/packages/param" + "github.com/openai/openai-go/v2/responses" + "github.com/samber/lo" +) + +// ToolSchema represents the schema from your backend +type ToolSchema struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]interface{} `json:"inputSchema"` + OutputSchema map[string]interface{} `json:"outputSchema"` +} + +// MCPRequest represents the JSON-RPC request structure +type MCPRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + ID int `json:"id"` + Params MCPParams `json:"params"` +} + +// MCPParams represents the parameters for the MCP request +type MCPParams struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` +} + +// DynamicTool represents a generic tool that can handle any schema +type DynamicTool struct { + Name string + Description responses.ToolUnionParam + toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB + projectService *services.ProjectService + coolDownTime time.Duration + baseURL string + client *http.Client + schema map[string]interface{} + sessionID string // Reuse the session ID from initialization +} + +// NewDynamicTool creates a new dynamic tool from a schema +func NewDynamicTool(db *db.DB, projectService *services.ProjectService, toolSchema ToolSchema, baseURL string, sessionID string) *DynamicTool { + // Create tool description with the schema + description := responses.ToolUnionParam{ + OfFunction: &responses.FunctionToolParam{ + Name: toolSchema.Name, + Description: param.NewOpt(toolSchema.Description), + Parameters: openai.FunctionParameters(toolSchema.InputSchema), + }, + } + + toolCallRecordDB := toolCallRecordDB.NewToolCallRecordDB(db) + return &DynamicTool{ + Name: toolSchema.Name, + Description: description, + toolCallRecordDB: toolCallRecordDB, + projectService: projectService, + coolDownTime: 5 * time.Minute, + baseURL: baseURL, + client: &http.Client{}, + schema: toolSchema.InputSchema, + sessionID: sessionID, // Store the session ID for reuse + } +} + +// Call handles the tool execution (generic for any tool) +func (t *DynamicTool) Call(ctx context.Context, toolCallId string, args json.RawMessage) (string, string, error) { + // Parse arguments as generic map since we don't know the structure + var argsMap map[string]interface{} + err := json.Unmarshal(args, &argsMap) + if err != nil { + return "", "", err + } + + // Create function call record + record, err := t.toolCallRecordDB.Create(ctx, toolCallId, t.Name, argsMap) + if err != nil { + return "", "", err + } + + // Execute the tool via MCP + respStr, err := t.executeTool(argsMap) + if err != nil { + err = fmt.Errorf("failed to execute tool %s: %v", t.Name, err) + t.toolCallRecordDB.OnError(ctx, record, err) + return "", "", err + } + + rawJson, err := json.Marshal(respStr) + if err != nil { + err = fmt.Errorf("failed to marshal tool result: %v", err) + t.toolCallRecordDB.OnError(ctx, record, err) + return "", "", err + } + t.toolCallRecordDB.OnSuccess(ctx, record, string(rawJson)) + + return respStr, "", nil +} + +// executeTool makes the MCP request (generic for any tool) +func (t *DynamicTool) executeTool(args map[string]interface{}) (string, error) { + + request := MCPRequest{ + JSONRPC: "2.0", + Method: "tools/call", + ID: int(time.Now().Unix()), // to ensure unique ID; TODO: consider better ID generation + Params: MCPParams{ + Name: t.Name, + Arguments: args, + }, + } + + // Marshal request to JSON + jsonData, err := json.Marshal(request) + if err != nil { + return "", fmt.Errorf("failed to marshal MCP request: %w", err) + } + + // Create HTTP request + req, err := http.NewRequest("POST", t.baseURL, bytes.NewBuffer(jsonData)) + if err != nil { + return "", fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("mcp-session-id", t.sessionID) // Use the stored session ID + + // Make the request + resp, err := t.client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() + + // Read response + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + // Parse response (assuming stream format) + lines := strings.Split(string(body), "\n") + lines = lo.Filter(lines, func(line string, _ int) bool { + return strings.HasPrefix(line, "data:") + }) + if len(lines) == 0 { + return "", fmt.Errorf("no data line found") + } + line := lines[0] + line = strings.TrimPrefix(line, "data: ") + return line, nil +}