From 54886f2b38fde148d98b8d9cb0548f3fc16e5f6b Mon Sep 17 00:00:00 2001 From: Junyi Hou Date: Sun, 12 Oct 2025 05:17:03 +0800 Subject: [PATCH 1/7] fix: CORS --- internal/api/gin.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/api/gin.go b/internal/api/gin.go index c9e0d577..68abbee4 100644 --- a/internal/api/gin.go +++ b/internal/api/gin.go @@ -21,7 +21,7 @@ func NewGinServer(cfg *cfg.Cfg, oauthHandler *auth.OAuthHandler) *GinServer { ginServer := &GinServer{Engine: gin.New(), cfg: cfg} ginServer.Use(ginServer.ginLogMiddleware(), gin.Recovery()) ginServer.Use(cors.New(cors.Config{ - AllowOrigins: []string{"https://overleaf.com", "https://*.overleaf.com", "https://*.paperdebugger.com", "http://localhost:3000", "http://127.0.0.1:3000"}, + AllowOrigins: []string{"*"}, AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}, AllowHeaders: []string{"*"}, ExposeHeaders: []string{"*"}, From bb7223d472f657dd1502f533d648b621356624e0 Mon Sep 17 00:00:00 2001 From: Junyi Hou Date: Tue, 14 Oct 2025 17:37:04 +0800 Subject: [PATCH 2/7] chore: tested any schema --- hack/values-dev.yaml | 4 + internal/services/toolkit/client/client.go | 25 ++- .../toolkit/tools/xtragpt/search_papers.go | 194 ++++++++++++++++++ .../services/toolkit/tools/xtragpt/util.go | 135 ++++++++++++ .../toolkit/tools/xtragpt/util_test.go | 173 ++++++++++++++++ 5 files changed, 523 insertions(+), 8 deletions(-) create mode 100644 hack/values-dev.yaml create mode 100644 internal/services/toolkit/tools/xtragpt/search_papers.go create mode 100644 internal/services/toolkit/tools/xtragpt/util.go create mode 100644 internal/services/toolkit/tools/xtragpt/util_test.go diff --git a/hack/values-dev.yaml b/hack/values-dev.yaml new file mode 100644 index 00000000..5f95b780 --- /dev/null +++ b/hack/values-dev.yaml @@ -0,0 +1,4 @@ +namespace: paperdebugger-dev + +mongo: + in_cluster: false diff --git a/internal/services/toolkit/client/client.go b/internal/services/toolkit/client/client.go index b2ea7acc..487351b0 100644 --- a/internal/services/toolkit/client/client.go +++ b/internal/services/toolkit/client/client.go @@ -9,7 +9,7 @@ import ( "paperdebugger/internal/services" "paperdebugger/internal/services/toolkit/handler" "paperdebugger/internal/services/toolkit/registry" - "paperdebugger/internal/services/toolkit/tools" + "paperdebugger/internal/services/toolkit/tools/xtragpt" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" @@ -42,15 +42,24 @@ func NewAIClient( option.WithAPIKey(cfg.OpenAIAPIKey), ) CheckOpenAIWorks(oaiClient, logger) - - toolPaperScore := tools.NewPaperScoreTool(db, projectService) - toolPaperScoreComment := tools.NewPaperScoreCommentTool(db, projectService, reverseCommentService) + toolSearchPapers := xtragpt.NewSearchPapersTool(db, projectService) + // toolPaperScore := tools.NewPaperScoreTool(db, projectService) + // toolPaperScoreComment := tools.NewPaperScoreCommentTool(db, projectService, reverseCommentService) toolRegistry := registry.NewToolRegistry() - toolRegistry.Register("always_exception", tools.AlwaysExceptionToolDescription, tools.AlwaysExceptionTool) - toolRegistry.Register("greeting", tools.GreetingToolDescription, tools.GreetingTool) - toolRegistry.Register("paper_score", toolPaperScore.Description, toolPaperScore.Call) - toolRegistry.Register("paper_score_comment", toolPaperScoreComment.Description, toolPaperScoreComment.Call) + + // toolRegistry.Register("always_exception", tools.AlwaysExceptionToolDescription, tools.AlwaysExceptionTool) + // toolRegistry.Register("greeting", tools.GreetingToolDescription, tools.GreetingTool) + // toolRegistry.Register("paper_score", toolPaperScore.Description, toolPaperScore.Call) + // toolRegistry.Register("paper_score_comment", toolPaperScoreComment.Description, toolPaperScoreComment.Call) + + // toolRegistry.Register("export_papers") + // toolRegistry.Register("get_conference_papers") + // toolRegistry.Register("get_user_papers") + toolRegistry.Register("search_papers", toolSearchPapers.Description, toolSearchPapers.Call) + // toolRegistry.Register("search_user") + // toolRegistry.Register("identify_improvements") + // toolRegistry.Register("suggest_improvement") toolCallHandler := handler.NewToolCallHandler(toolRegistry) diff --git a/internal/services/toolkit/tools/xtragpt/search_papers.go b/internal/services/toolkit/tools/xtragpt/search_papers.go new file mode 100644 index 00000000..6c9405f0 --- /dev/null +++ b/internal/services/toolkit/tools/xtragpt/search_papers.go @@ -0,0 +1,194 @@ +package xtragpt + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "paperdebugger/internal/libs/db" + "paperdebugger/internal/services" + toolCallRecordDB "paperdebugger/internal/services/toolkit/db" + "strings" + "time" + + "github.com/openai/openai-go/v2" + "github.com/openai/openai-go/v2/packages/param" + "github.com/openai/openai-go/v2/responses" + "github.com/samber/lo" +) + +// MCPRequest represents the JSON-RPC request structure +type MCPRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + ID int `json:"id"` + Params MCPParams `json:"params"` +} + +// MCPParams represents the parameters for the MCP request +type MCPParams struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` +} + +// Venue represents a conference venue with year +type Venue struct { + Venue string `json:"venue"` + Year string `json:"year"` +} +type SearchPapersTool struct { + Description responses.ToolUnionParam + toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB + projectService *services.ProjectService + coolDownTime time.Duration + baseURL string + client *http.Client +} + +var schema map[string]any + +var SearchPapersToolDescription = responses.ToolUnionParam{ + OfFunction: &responses.FunctionToolParam{ + Name: "search_papers", + Description: param.NewOpt("Search for papers by keywords within specific conference venues, with various matching modes."), + Parameters: openai.FunctionParameters(schema), + }, +} + +func NewSearchPapersTool(db *db.DB, projectService *services.ProjectService) *SearchPapersTool { + json.Unmarshal([]byte(`{"properties":{"query":{"description":"Keywords / topics or content to search for (e.g., 'time series token merging', 'neural networks').","title":"Query","type":"string"},"venues":{"description":"List of conference venues and years to search in. Each entry must be a dict with 'venue' (e.g., 'ICLR.cc', 'NeurIPS.cc', 'ICML.cc'; users may omit '.cc') and 'year' (e.g., '2024', '2025').","items":{"additionalProperties":{"type":"string"},"type":"object"},"minItems":1,"title":"Venues","type":"array"},"search_fields":{"default":["title","abstract"],"description":"Fields to search within each paper. Options: 'title', 'abstract', 'authors'.","items":{"enum":["title","abstract","authors"],"type":"string"},"title":"Search Fields","type":"array"},"match_mode":{"default":"majority","description":"Match mode:\n- any: At least one keyword must match\n- all: All keywords must match\n- exact: Match the entire phrase exactly\n- majority: Match majority of keywords (>50%)\n- threshold: Match percentage of terms based on 'match_threshold'.","enum":["any","all","exact","majority","threshold"],"title":"Match Mode","type":"string"},"match_threshold":{"default":0.5,"description":"Minimum fraction (0.0-1.0) of search terms that must match when using 'threshold' mode. Example: 0.5 = 50% of terms must match.","maximum":1,"minimum":0,"title":"Match Threshold","type":"number"},"limit":{"default":10,"description":"Maximum number of results to return (1-16).","maximum":16,"minimum":1,"title":"Limit","type":"integer"},"min_score":{"default":0.6,"description":"Minimum match score (0.0-1.0). Lower values allow looser matches; higher values enforce stricter matches.","maximum":1,"minimum":0,"title":"Min Score","type":"number"}},"required":["query","venues"],"title":"search_papers_toolArguments","type":"object"}`), &schema) + toolCallRecordDB := toolCallRecordDB.NewToolCallRecordDB(db) + return &SearchPapersTool{ + Description: SearchPapersToolDescription, + toolCallRecordDB: toolCallRecordDB, + projectService: projectService, + coolDownTime: 5 * time.Minute, + baseURL: "http://xtragpt-mcp-server:8080/paper-score", + client: &http.Client{}, + } +} + +type SearchPapersToolArgs struct { + Limit int `json:"limit"` + MatchMode string `json:"matchMode"` + MatchThreshold float64 `json:"matchThreshold"` + MinScore float64 `json:"minScore"` + Query string `json:"query"` + Venues []Venue `json:"venues"` + SearchFields []string `json:"searchFields"` +} + +func (t *SearchPapersTool) Call(ctx context.Context, toolCallId string, args json.RawMessage) (string, string, error) { + var argsMap SearchPapersToolArgs + err := json.Unmarshal(args, &argsMap) + if err != nil { + return "", "", err + } + + // Create function call record + record, err := t.toolCallRecordDB.Create(ctx, toolCallId, *t.Description.GetName(), map[string]any{ + "limit": argsMap.Limit, + "matchMode": argsMap.MatchMode, + "matchThreshold": argsMap.MatchThreshold, + "minScore": argsMap.MinScore, + "query": argsMap.Query, + "venues": argsMap.Venues, + "searchFields": argsMap.SearchFields, + }) + if err != nil { + return "", "", err + } + + respStr, err := t.SearchPaper(argsMap.Limit, argsMap.MatchMode, argsMap.MatchThreshold, argsMap.MinScore, argsMap.Query, argsMap.Venues, argsMap.SearchFields) + if err != nil { + err = fmt.Errorf("failed to search paper: %v", err) + t.toolCallRecordDB.OnError(ctx, record, err) + return "", "", err + } + + rawJson, err := json.Marshal(respStr) + if err != nil { + err = fmt.Errorf("failed to marshal paper search result: %v, rawJson: %v", err, string(rawJson)) + t.toolCallRecordDB.OnError(ctx, record, err) + return "", "", err + } + t.toolCallRecordDB.OnSuccess(ctx, record, string(rawJson)) + + return respStr, "", nil +} + +func (t *SearchPapersTool) SearchPaper(limit int, matchMode string, matchThreshold float64, minScore float64, query string, venues []Venue, searchFields []string) (string, error) { + sessionId, err := MCPInitialize(t.baseURL) + if err != nil { + fmt.Printf("Error initializing MCP: %v\n", err) + return "", fmt.Errorf("failed to initialize MCP: %w", err) + } + if sessionId == "" { + return "", fmt.Errorf("failed to initialize MCP") + } + + fmt.Println("sessionId", sessionId) + request := MCPRequest{ + JSONRPC: "2.0", + Method: "tools/call", + ID: 2, + Params: MCPParams{ + Name: "search_papers", + Arguments: map[string]interface{}{ + "limit": limit, + "match_mode": matchMode, + "match_threshold": matchThreshold, + "min_score": minScore, + "query": query, + "search_fields": searchFields, + "venues": venues, + }, + }, + } + + // Marshal request to JSON + jsonData, err := json.Marshal(request) + if err != nil { + return "", fmt.Errorf("failed to marshal MCP request: %w", err) + } + + // Create HTTP request + req, err := http.NewRequest("POST", "http://localhost:8080/mcp", bytes.NewBuffer(jsonData)) + if err != nil { + return "", fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("mcp-session-id", sessionId) + + // Make the request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() + + // Read response + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + fmt.Println("body", string(body)) + // split lines + lines := strings.Split(string(body), "\n") + // keep only the line starts with "data:" + lines = lo.Filter(lines, func(line string, _ int) bool { + return strings.HasPrefix(line, "data:") + }) + if len(lines) == 0 { + return "", fmt.Errorf("no data line found") + } + line := lines[0] + line = strings.TrimPrefix(line, "data: ") + return line, nil +} diff --git a/internal/services/toolkit/tools/xtragpt/util.go b/internal/services/toolkit/tools/xtragpt/util.go new file mode 100644 index 00000000..258aaa18 --- /dev/null +++ b/internal/services/toolkit/tools/xtragpt/util.go @@ -0,0 +1,135 @@ +package xtragpt + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" +) + +type InitializeRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities map[string]interface{} `json:"capabilities"` + ClientInfo struct { + Name string `json:"name"` + Version string `json:"version"` + } `json:"clientInfo"` + } `json:"params"` + ID int `json:"id"` +} + +type NotificationRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` +} + +type ToolsRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + ID int `json:"id"` +} + +func MCPNotificationsInitialized(url string, sessionId string) { + notifyReq := NotificationRequest{ + JSONRPC: "2.0", + Method: "notifications/initialized", + Params: make(map[string]interface{}), + } + + // Marshal notification to JSON + notifyJSON, err := json.Marshal(notifyReq) + if err != nil { + fmt.Printf("Error marshaling notification JSON: %v\n", err) + return + } + + // Create HTTP client and request for notification + client := &http.Client{} + req, err := http.NewRequest("POST", url, bytes.NewBuffer(notifyJSON)) + if err != nil { + fmt.Printf("Error creating request: %v\n", err) + return + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("mcp-session-id", sessionId) + + // Make the notification request + notifyResp, err := client.Do(req) + if err != nil { + fmt.Printf("Error making notification request: %v\n", err) + return + } + defer notifyResp.Body.Close() + +} + +func MCPInitialize(url string) (string, error) { + initReq := InitializeRequest{ + JSONRPC: "2.0", + Method: "initialize", + ID: 1, + } + initReq.Params.ProtocolVersion = "2024-11-05" + initReq.Params.Capabilities = make(map[string]interface{}) + initReq.Params.ClientInfo.Name = "test-client" + initReq.Params.ClientInfo.Version = "1.0.0" + + // Marshal to JSON + jsonData, err := json.Marshal(initReq) + if err != nil { + fmt.Printf("Error marshaling JSON: %v\n", err) + return "", err + } + + // Make the first request + resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData)) + if err != nil { + fmt.Printf("Error making request: %v\n", err) + return "", err + } + defer resp.Body.Close() + fmt.Println("Initialize response", resp.Body, resp.Header) + + // Get session ID from headers + sessionID := resp.Header.Get("mcp-session-id") + fmt.Printf("Session ID: %s\n", sessionID) + + MCPNotificationsInitialized(url, sessionID) + fmt.Println("Initialized") + return sessionID, nil +} + +func MCPListTools(url string) ([]string, error) { + toolsReq := ToolsRequest{ + JSONRPC: "2.0", + Method: "tools/list", + ID: 1, + } + jsonData, err := json.Marshal(toolsReq) + if err != nil { + fmt.Printf("Error marshaling JSON: %v\n", err) + return nil, err + } + resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData)) + if err != nil { + fmt.Printf("Error making request: %v\n", err) + return nil, err + } + defer resp.Body.Close() + fmt.Println("List tools response", resp.Body, resp.Header) + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Printf("Error reading response: %v\n", err) + return nil, err + } + fmt.Println("List tools response", string(body)) + return nil, nil +} diff --git a/internal/services/toolkit/tools/xtragpt/util_test.go b/internal/services/toolkit/tools/xtragpt/util_test.go new file mode 100644 index 00000000..00dd75ca --- /dev/null +++ b/internal/services/toolkit/tools/xtragpt/util_test.go @@ -0,0 +1,173 @@ +package xtragpt_test + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "paperdebugger/internal/services/toolkit/tools/xtragpt" + "testing" +) + +func TestMCPInitialize_Success(t *testing.T) { + expectedSessionID := "test-session-123" + + // Mock server that handles both initialize and notifications/initialized requests + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST request, got %s", r.Method) + } + + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) + } + + // Parse request body to determine which request this is + var reqBody map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + t.Fatalf("Failed to decode request body: %v", err) + } + + method, ok := reqBody["method"].(string) + if !ok { + t.Fatalf("Missing or invalid method field") + } + + switch method { + case "initialize": + // Validate initialize request structure + if reqBody["jsonrpc"] != "2.0" { + t.Errorf("Expected jsonrpc 2.0, got %v", reqBody["jsonrpc"]) + } + + if reqBody["id"] != float64(1) { + t.Errorf("Expected id 1, got %v", reqBody["id"]) + } + + params, ok := reqBody["params"].(map[string]interface{}) + if !ok { + t.Fatalf("Missing or invalid params field") + } + + if params["protocolVersion"] != "2024-11-05" { + t.Errorf("Expected protocolVersion 2024-11-05, got %v", params["protocolVersion"]) + } + + clientInfo, ok := params["clientInfo"].(map[string]interface{}) + if !ok { + t.Fatalf("Missing or invalid clientInfo field") + } + + if clientInfo["name"] != "test-client" { + t.Errorf("Expected client name test-client, got %v", clientInfo["name"]) + } + + if clientInfo["version"] != "1.0.0" { + t.Errorf("Expected client version 1.0.0, got %v", clientInfo["version"]) + } + + // Set session ID header and return success response + w.Header().Set("mcp-session-id", expectedSessionID) + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05"}}`) + + case "notifications/initialized": + // Validate notifications/initialized request + if reqBody["jsonrpc"] != "2.0" { + t.Errorf("Expected jsonrpc 2.0, got %v", reqBody["jsonrpc"]) + } + + // Check session ID header + if r.Header.Get("mcp-session-id") != expectedSessionID { + t.Errorf("Expected session ID %s, got %s", expectedSessionID, r.Header.Get("mcp-session-id")) + } + + if r.Header.Get("Accept") != "application/json, text/event-stream" { + t.Errorf("Expected Accept header 'application/json, text/event-stream', got %s", r.Header.Get("Accept")) + } + + w.WriteHeader(http.StatusOK) + + default: + t.Errorf("Unexpected method: %s", method) + } + })) + defer server.Close() + + sessionID, err := xtragpt.MCPInitialize(server.URL) + + if err != nil { + t.Fatalf("MCPInitialize failed: %v", err) + } + + if sessionID != expectedSessionID { + t.Errorf("Expected session ID %s, got %s", expectedSessionID, sessionID) + } +} + +func TestMCPInitialize_InvalidURL(t *testing.T) { + sessionID, err := xtragpt.MCPInitialize("invalid-url") + + if err == nil { + t.Fatalf("Expected error for invalid URL, but got none") + } + + if sessionID != "" { + t.Errorf("Expected empty session ID on error, got %s", sessionID) + } +} + +func TestMCPNotificationsInitialized_Success(t *testing.T) { + sessionID := "test-session-456" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST request, got %s", r.Method) + } + + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) + } + + if r.Header.Get("Accept") != "application/json, text/event-stream" { + t.Errorf("Expected Accept header 'application/json, text/event-stream', got %s", r.Header.Get("Accept")) + } + + if r.Header.Get("mcp-session-id") != sessionID { + t.Errorf("Expected session ID %s, got %s", sessionID, r.Header.Get("mcp-session-id")) + } + + var reqBody map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + t.Fatalf("Failed to decode request body: %v", err) + } + + if reqBody["jsonrpc"] != "2.0" { + t.Errorf("Expected jsonrpc 2.0, got %v", reqBody["jsonrpc"]) + } + + if reqBody["method"] != "notifications/initialized" { + t.Errorf("Expected method notifications/initialized, got %v", reqBody["method"]) + } + + params, ok := reqBody["params"].(map[string]interface{}) + if !ok { + t.Fatalf("Missing or invalid params field") + } + + if len(params) != 0 { + t.Errorf("Expected empty params, got %v", params) + } + + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // This function doesn't return anything, so we just ensure it doesn't panic + xtragpt.MCPNotificationsInitialized(server.URL, sessionID) +} + +func TestMCPNotificationsInitialized_InvalidURL(t *testing.T) { + // This should not panic even with invalid URL + xtragpt.MCPNotificationsInitialized("invalid-url", "test-session") +} From 5260f8353cfa38fe32370b05a14851930de3ee6e Mon Sep 17 00:00:00 2001 From: 4ndrelim Date: Wed, 22 Oct 2025 05:16:13 +0800 Subject: [PATCH 3/7] Fix search_papers tool --- internal/services/toolkit/client/client.go | 2 +- .../toolkit/tools/xtragpt/search_papers.go | 82 +++++++++---------- 2 files changed, 39 insertions(+), 45 deletions(-) diff --git a/internal/services/toolkit/client/client.go b/internal/services/toolkit/client/client.go index 487351b0..ae397f84 100644 --- a/internal/services/toolkit/client/client.go +++ b/internal/services/toolkit/client/client.go @@ -56,7 +56,7 @@ func NewAIClient( // toolRegistry.Register("export_papers") // toolRegistry.Register("get_conference_papers") // toolRegistry.Register("get_user_papers") - toolRegistry.Register("search_papers", toolSearchPapers.Description, toolSearchPapers.Call) + toolRegistry.Register("search_relevant_papers", toolSearchPapers.Description, toolSearchPapers.Call) // toolRegistry.Register("search_user") // toolRegistry.Register("identify_improvements") // toolRegistry.Register("suggest_improvement") diff --git a/internal/services/toolkit/tools/xtragpt/search_papers.go b/internal/services/toolkit/tools/xtragpt/search_papers.go index 6c9405f0..b846044b 100644 --- a/internal/services/toolkit/tools/xtragpt/search_papers.go +++ b/internal/services/toolkit/tools/xtragpt/search_papers.go @@ -33,11 +33,6 @@ type MCPParams struct { Arguments map[string]interface{} `json:"arguments"` } -// Venue represents a conference venue with year -type Venue struct { - Venue string `json:"venue"` - Year string `json:"year"` -} type SearchPapersTool struct { Description responses.ToolUnionParam toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB @@ -47,37 +42,39 @@ type SearchPapersTool struct { client *http.Client } -var schema map[string]any - -var SearchPapersToolDescription = responses.ToolUnionParam{ - OfFunction: &responses.FunctionToolParam{ - Name: "search_papers", - Description: param.NewOpt("Search for papers by keywords within specific conference venues, with various matching modes."), - Parameters: openai.FunctionParameters(schema), - }, -} - func NewSearchPapersTool(db *db.DB, projectService *services.ProjectService) *SearchPapersTool { - json.Unmarshal([]byte(`{"properties":{"query":{"description":"Keywords / topics or content to search for (e.g., 'time series token merging', 'neural networks').","title":"Query","type":"string"},"venues":{"description":"List of conference venues and years to search in. Each entry must be a dict with 'venue' (e.g., 'ICLR.cc', 'NeurIPS.cc', 'ICML.cc'; users may omit '.cc') and 'year' (e.g., '2024', '2025').","items":{"additionalProperties":{"type":"string"},"type":"object"},"minItems":1,"title":"Venues","type":"array"},"search_fields":{"default":["title","abstract"],"description":"Fields to search within each paper. Options: 'title', 'abstract', 'authors'.","items":{"enum":["title","abstract","authors"],"type":"string"},"title":"Search Fields","type":"array"},"match_mode":{"default":"majority","description":"Match mode:\n- any: At least one keyword must match\n- all: All keywords must match\n- exact: Match the entire phrase exactly\n- majority: Match majority of keywords (>50%)\n- threshold: Match percentage of terms based on 'match_threshold'.","enum":["any","all","exact","majority","threshold"],"title":"Match Mode","type":"string"},"match_threshold":{"default":0.5,"description":"Minimum fraction (0.0-1.0) of search terms that must match when using 'threshold' mode. Example: 0.5 = 50% of terms must match.","maximum":1,"minimum":0,"title":"Match Threshold","type":"number"},"limit":{"default":10,"description":"Maximum number of results to return (1-16).","maximum":16,"minimum":1,"title":"Limit","type":"integer"},"min_score":{"default":0.6,"description":"Minimum match score (0.0-1.0). Lower values allow looser matches; higher values enforce stricter matches.","maximum":1,"minimum":0,"title":"Min Score","type":"number"}},"required":["query","venues"],"title":"search_papers_toolArguments","type":"object"}`), &schema) + // Create and populate schema + var schema map[string]any + json.Unmarshal([]byte(`{"properties":{"query":{"description":"Keywords, topics, content, or a chunk of text to search for.","examples":["time series token merging","neural networks","...when trained on first-order Markov chains, transformers with two or more layers consistently develop an induction head mechanism to estimate the in-context bigram conditional distribution"],"title":"Query","type":"string"},"top_k":{"description":"Number of top relevant or similar papers to return.","title":"Top K","type":"integer"},"date_min":{"description":"Minimum publication date (YYYY-MM-DD) to filter papers.","examples":["2023-01-01","2022-06-25"],"title":"Date Min","type":"string"},"date_max":{"description":"Maximum publication date (YYYY-MM-DD) to filter papers.","examples":["2024-12-31","2023-06-25"],"title":"Date Max","type":"string"},"countries":{"anyOf":[{"items":{"type":"string"},"type":"array"},{"type":"null"}],"description":"List of country codes in ISO ALPHA-3 format to filter papers by author affiliations.","examples":[["USA","CHN","SGP","GBR","DEU","KOR","JPN"]],"title":"Countries"},"min_similarity":{"description":"Minimum similarity score (0.0-1.0) for returned papers. Higher values yield more relevant results but fewer papers.","examples":[0.3,0.5,0.7,0.9],"title":"Min Similarity","type":"number"}},"required":["query","top_k","countries","min_similarity"],"title":"search_papers_toolArguments","type":"object"}`), &schema) + + // Create tool description with populated schema + description := responses.ToolUnionParam{ + OfFunction: &responses.FunctionToolParam{ + Name: "search_relevant_papers", + Description: param.NewOpt("Search for similar or relevant papers by keywords against the local database of academic papers. This tool uses semantic search with vector embeddings to find the most relevant results. It is the default and recommended tool for paper searches."), + Parameters: openai.FunctionParameters(schema), + }, + } + toolCallRecordDB := toolCallRecordDB.NewToolCallRecordDB(db) return &SearchPapersTool{ - Description: SearchPapersToolDescription, + Description: description, toolCallRecordDB: toolCallRecordDB, projectService: projectService, coolDownTime: 5 * time.Minute, - baseURL: "http://xtragpt-mcp-server:8080/paper-score", + // baseURL: "http://xtragpt-mcp-server:8080/mcp", + baseURL: "http://localhost:8080/mcp", // For local development client: &http.Client{}, } } type SearchPapersToolArgs struct { - Limit int `json:"limit"` - MatchMode string `json:"matchMode"` - MatchThreshold float64 `json:"matchThreshold"` - MinScore float64 `json:"minScore"` - Query string `json:"query"` - Venues []Venue `json:"venues"` - SearchFields []string `json:"searchFields"` + Query string `json:"query"` + TopK int `json:"top_k"` + DateMin *string `json:"date_min,omitempty"` + DateMax *string `json:"date_max,omitempty"` + Countries []string `json:"countries"` + MinSimilarity float64 `json:"min_similarity"` } func (t *SearchPapersTool) Call(ctx context.Context, toolCallId string, args json.RawMessage) (string, string, error) { @@ -89,19 +86,18 @@ func (t *SearchPapersTool) Call(ctx context.Context, toolCallId string, args jso // Create function call record record, err := t.toolCallRecordDB.Create(ctx, toolCallId, *t.Description.GetName(), map[string]any{ - "limit": argsMap.Limit, - "matchMode": argsMap.MatchMode, - "matchThreshold": argsMap.MatchThreshold, - "minScore": argsMap.MinScore, "query": argsMap.Query, - "venues": argsMap.Venues, - "searchFields": argsMap.SearchFields, + "top_k": argsMap.TopK, + "date_min": argsMap.DateMin, + "date_max": argsMap.DateMax, + "countries": argsMap.Countries, + "min_similarity": argsMap.MinSimilarity, }) if err != nil { return "", "", err } - respStr, err := t.SearchPaper(argsMap.Limit, argsMap.MatchMode, argsMap.MatchThreshold, argsMap.MinScore, argsMap.Query, argsMap.Venues, argsMap.SearchFields) + respStr, err := t.SearchPaper(argsMap.Query, argsMap.TopK, argsMap.DateMin, argsMap.DateMax, argsMap.Countries, argsMap.MinSimilarity) if err != nil { err = fmt.Errorf("failed to search paper: %v", err) t.toolCallRecordDB.OnError(ctx, record, err) @@ -119,7 +115,7 @@ func (t *SearchPapersTool) Call(ctx context.Context, toolCallId string, args jso return respStr, "", nil } -func (t *SearchPapersTool) SearchPaper(limit int, matchMode string, matchThreshold float64, minScore float64, query string, venues []Venue, searchFields []string) (string, error) { +func (t *SearchPapersTool) SearchPaper(query string, topK int, dateMin *string, dateMax *string, countries []string, minSimilarity float64) (string, error) { sessionId, err := MCPInitialize(t.baseURL) if err != nil { fmt.Printf("Error initializing MCP: %v\n", err) @@ -135,15 +131,14 @@ func (t *SearchPapersTool) SearchPaper(limit int, matchMode string, matchThresho Method: "tools/call", ID: 2, Params: MCPParams{ - Name: "search_papers", + Name: "search_relevant_papers", Arguments: map[string]interface{}{ - "limit": limit, - "match_mode": matchMode, - "match_threshold": matchThreshold, - "min_score": minScore, - "query": query, - "search_fields": searchFields, - "venues": venues, + "query": query, + "top_k": topK, + "date_min": dateMin, + "date_max": dateMax, + "countries": countries, + "min_similarity": minSimilarity, }, }, } @@ -155,7 +150,7 @@ func (t *SearchPapersTool) SearchPaper(limit int, matchMode string, matchThresho } // Create HTTP request - req, err := http.NewRequest("POST", "http://localhost:8080/mcp", bytes.NewBuffer(jsonData)) + req, err := http.NewRequest("POST", t.baseURL, bytes.NewBuffer(jsonData)) if err != nil { return "", fmt.Errorf("failed to create HTTP request: %w", err) } @@ -166,8 +161,7 @@ func (t *SearchPapersTool) SearchPaper(limit int, matchMode string, matchThresho req.Header.Set("mcp-session-id", sessionId) // Make the request - client := &http.Client{} - resp, err := client.Do(req) + resp, err := t.client.Do(req) if err != nil { return "", fmt.Errorf("failed to make request: %w", err) } From 622ed605babc72997cbf32a2cbc04663a3923d04 Mon Sep 17 00:00:00 2001 From: 4ndrelim Date: Wed, 22 Oct 2025 06:39:12 +0800 Subject: [PATCH 4/7] feat: Add XtraMCP loader to handle init and ack --- .../services/toolkit/tools/xtramcp/loader.go | 206 ++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 internal/services/toolkit/tools/xtramcp/loader.go diff --git a/internal/services/toolkit/tools/xtramcp/loader.go b/internal/services/toolkit/tools/xtramcp/loader.go new file mode 100644 index 00000000..70dfc185 --- /dev/null +++ b/internal/services/toolkit/tools/xtramcp/loader.go @@ -0,0 +1,206 @@ +package xtramcp + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "paperdebugger/internal/libs/db" + "paperdebugger/internal/services" + "paperdebugger/internal/services/toolkit/registry" +) + +// MCPToolsResponse represents the JSON-RPC response from your backend +type MCPToolsResponse struct { + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Result struct { + Tools []ToolSchema `json:"tools"` + } `json:"result"` +} + +// loads tools dynamically from backend +type XtraMCPLoader struct { + db *db.DB + projectService *services.ProjectService + baseURL string + client *http.Client + sessionID string // Store the MCP session ID after initialization for re-use +} + +// NewXtraMCPLoader creates a new dynamic XtraMCP loader +func NewXtraMCPLoader(db *db.DB, projectService *services.ProjectService, baseURL string) *XtraMCPLoader { + return &XtraMCPLoader{ + db: db, + projectService: projectService, + baseURL: baseURL, + client: &http.Client{}, + } +} + +// LoadToolsFromBackend fetches tool schemas from backend and registers them +func (loader *XtraMCPLoader) LoadToolsFromBackend(toolRegistry *registry.ToolRegistry) error { + // Initialize MCP session ONCE + sessionID, err := loader.initializeMCP() + if err != nil { + return fmt.Errorf("failed to initialize MCP: %w", err) + } + loader.sessionID = sessionID + + // Fetch tools from backend using the session (currently returns mock data) + toolSchemas, err := loader.fetchAvailableTools() + if err != nil { + return fmt.Errorf("failed to fetch tools from backend: %w", err) + } + + // Register each tool dynamically, passing the session ID + for _, toolSchema := range toolSchemas { + dynamicTool := NewDynamicTool(loader.db, loader.projectService, toolSchema, loader.baseURL, loader.sessionID) + + // Register the tool with the registry + toolRegistry.Register(toolSchema.Name, dynamicTool.Description, dynamicTool.Call) + + fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name) + } + + return nil +} + +// initializeMCP performs the full MCP initialization handshake +func (loader *XtraMCPLoader) initializeMCP() (string, error) { + // Step 1: Initialize + sessionID, err := loader.performInitialize() + if err != nil { + return "", fmt.Errorf("step 1 - initialize failed: %w", err) + } + + // Step 2: Send notifications/initialized + err = loader.sendInitializedNotification(sessionID) + if err != nil { + return "", fmt.Errorf("step 2 - notifications/initialized failed: %w", err) + } + + return sessionID, nil +} + +// performInitialize performs MCP initialization (1. establish connection) +func (loader *XtraMCPLoader) performInitialize() (string, error) { + initReq := map[string]interface{}{ + "jsonrpc": "2.0", + "method": "initialize", + "id": 1, + "params": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "clientInfo": map[string]interface{}{ + "name": "paperdebugger-client", + "version": "1.0.0", + }, + }, + } + + jsonData, err := json.Marshal(initReq) + if err != nil { + return "", fmt.Errorf("failed to marshal initialize request: %w", err) + } + + req, err := http.NewRequest("POST", loader.baseURL, bytes.NewBuffer(jsonData)) + if err != nil { + return "", fmt.Errorf("failed to create initialize request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + + resp, err := loader.client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to make initialize request: %w", err) + } + defer resp.Body.Close() + + // Extract session ID from response headers + sessionID := resp.Header.Get("mcp-session-id") + if sessionID == "" { + return "", fmt.Errorf("no session ID returned from initialize") + } + + return sessionID, nil +} + +// sendInitializedNotification completes MCP initialization (acknowledges initialization) +func (loader *XtraMCPLoader) sendInitializedNotification(sessionID string) error { + notifyReq := map[string]interface{}{ + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": map[string]interface{}{}, + } + + jsonData, err := json.Marshal(notifyReq) + if err != nil { + return fmt.Errorf("failed to marshal notification: %w", err) + } + + req, err := http.NewRequest("POST", loader.baseURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create notification request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("mcp-session-id", sessionID) + + resp, err := loader.client.Do(req) + if err != nil { + return fmt.Errorf("failed to send notification: %w", err) + } + defer resp.Body.Close() + + return nil +} + +// fetchAvailableTools makes a request to get available tools from backend +func (loader *XtraMCPLoader) fetchAvailableTools() ([]ToolSchema, error) { + // List all tools using the established session + requestBody := map[string]interface{}{ + "jsonrpc": "2.0", + "method": "tools/list", + "params": map[string]interface{}{}, + "id": 2, + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", loader.baseURL, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("mcp-session-id", loader.sessionID) + + resp, err := loader.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() + + // Parse response + var mcpResponse MCPToolsResponse + err = json.NewDecoder(resp.Body).Decode(&mcpResponse) + if err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return mcpResponse.Result.Tools, nil + + // mock data; return hardcoded tool schemas for testing + // mockToolsJSON := `[{"name":"get_user_papers","description":"Fetch all papers published by a specific user identified by email. Supports 'summary' (abstract truncated to 150 words) and 'detailed' (full abstract).","inputSchema":{"properties":{"email":{"description":"Email address of the user whose papers to fetch. Must be a valid email string.","examples":["alice@example.com","bob@university.edu"],"title":"Email","type":"string"},"format":{"default":"detailed","description":"Format of the response. 'summary' shows title, venue, authors, URL, and the first 150 words of the abstract (default). 'detailed' shows the full abstract.","enum":["summary","detailed"],"examples":["summary","detailed"],"title":"Format","type":"string"}},"required":["email"],"title":"get_user_papers_toolArguments","type":"object"},"outputSchema":{"properties":{"result":{"title":"Result","type":"string"}},"required":["result"],"title":"get_user_papers_toolOutput","type":"object"}},{"name":"search_papers_on_openreview","description":"Search for academic papers on OpenReview by keywords within specific conference venues. This tool supports various matching modes and is ideal for discovering recent or broader papers beyond those available in the local database. Use this tool when results from search_relevant_papers are insufficient.","inputSchema":{"properties":{"query":{"description":"Keywords, topics, content, or a chunk of text to search for.","examples":["time series token merging","neural networks"],"title":"Query","type":"string"},"venues":{"description":"List of conference venues and years to search in. Each entry must be a dict with 'venue' and 'year'.","examples":[{"venue":"ICLR.cc","year":"2024"},{"venue":"ICML","year":"2024"},{"venue":"NeurIPS.cc","year":"2023"},{"venue":"NeurIPS.cc","year":"2022"}],"items":{"additionalProperties":{"type":"string"},"type":"object"},"minItems":1,"title":"Venues","type":"array"},"search_fields":{"default":["title","abstract"],"description":"Fields to search within each paper. Options: 'title', 'abstract', 'authors'.","items":{"enum":["title","abstract","authors"],"type":"string"},"title":"Search Fields","type":"array"},"match_mode":{"default":"majority","description":"Match mode:\n- any: At least one keyword must match\n- all: All keywords must match\n- exact: Match the entire phrase exactly\n- majority: Match majority of keywords (>50%)\n- threshold: Match percentage of terms based on 'match_threshold'.","enum":["any","all","exact","majority","threshold"],"title":"Match Mode","type":"string"},"match_threshold":{"default":0.5,"description":"Minimum fraction (0.0-1.0) of search terms that must match when using 'threshold' mode. Example: 0.5 = 50% of terms must match.","maximum":1,"minimum":0,"title":"Match Threshold","type":"number"},"limit":{"default":10,"description":"Maximum number of results to return (1-16).","maximum":16,"minimum":1,"title":"Limit","type":"integer"},"min_score":{"default":0.6,"description":"Minimum match score (0.0-1.0). Lower values allow looser matches; higher values enforce stricter matches.","maximum":1,"minimum":0,"title":"Min Score","type":"number"}},"required":["query","venues"],"title":"search_papers_openreview_toolArguments","type":"object"},"outputSchema":{"properties":{"result":{"title":"Result","type":"string"}},"required":["result"],"title":"search_papers_openreview_toolOutput","type":"object"}},{"name":"search_relevant_papers","description":"Search for similar or relevant papers by keywords against the local database of academic papers. This tool uses semantic search with vector embeddings to find the most relevant results. It is the default and recommended tool for paper searches.","inputSchema":{"properties":{"query":{"description":"Keywords, topics, content, or a chunk of text to search for.","examples":["time series token merging","neural networks","...when trained on first-order Markov chains, transformers with two or more layers consistently develop an induction head mechanism to estimate the in-context bigram conditional distribution"],"title":"Query","type":"string"},"top_k":{"description":"Number of top relevant or similar papers to return.","title":"Top K","type":"integer"},"date_min":{"description":"Minimum publication date (YYYY-MM-DD) to filter papers.","examples":["2023-01-01","2022-06-25"],"title":"Date Min","type":"string"},"date_max":{"description":"Maximum publication date (YYYY-MM-DD) to filter papers.","examples":["2024-12-31","2023-06-25"],"title":"Date Max","type":"string"},"countries":{"anyOf":[{"items":{"type":"string"},"type":"array"},{"type":"null"}],"description":"List of country codes in ISO ALPHA-3 format to filter papers by author affiliations.","examples":[["USA","CHN","SGP","GBR","DEU","KOR","JPN"]],"title":"Countries"},"min_similarity":{"description":"Minimum similarity score (0.0-1.0) for returned papers. Higher values yield more relevant results but fewer papers.","examples":[0.3,0.5,0.7,0.9],"title":"Min Similarity","type":"number"}},"required":["query","top_k","countries","min_similarity"],"title":"search_papers_toolArguments","type":"object"},"outputSchema":{"properties":{"result":{"title":"Result","type":"string"}},"required":["result"],"title":"search_papers_toolOutput","type":"object"}},{"name":"identify_improvements","description":"Analyzes a draft academic paper against the standards of top-tier ML conferences (ICLR, ICML, NeurIPS). Identifies issues in structure, completeness, clarity, and argumentation, then provides prioritized, actionable suggestions.","inputSchema":{"properties":{"paper_content":{"description":"The full text content of the academic paper draft. Paper content should not be truncated.","title":"Paper Content","type":"string"},"target_venue":{"default":"NeurIPS","description":"The target top-tier conference to tailor the feedback for.","enum":["ICLR","ICML","NeurIPS"],"title":"Target Venue","type":"string"},"focus_areas":{"anyOf":[{"items":{"enum":["Structure","Clarity","Evidence","Positioning","Style","Completeness","Soundness","Limitations"],"type":"string"},"type":"array"},{"type":"null"}],"default":null,"description":"List of specific areas to focus the analysis on. If empty, default areas are: {DEFAULT_FOCUS_AREAS}.","title":"Focus Areas"},"severity_threshold":{"default":"major","description":"The minimum severity level to report. 'major' will show blockers and major issues.","enum":["blocker","major","minor","nit"],"title":"Severity Threshold","type":"string"}},"required":["paper_content"],"title":"identify_improvementsArguments","type":"object"},"outputSchema":{"properties":{"result":{"title":"Result","type":"string"}},"required":["result"],"title":"identify_improvementsOutput","type":"object"}},{"name":"enhance_academic_writing","description":"Suggest context-aware academic paper writing enhancements for selected text.","inputSchema":{"properties":{"full_paper_content":{"description":"Surrounding context from the manuscript (e.g., abstract, background, or several sections). This need not be the entire paper; providing a substantial excerpt helps tailor the tone, terminology, and level of detail to academic venues (journals and conferences).","examples":["In reinforcement learning, one could structure these metrics (previously for evaluation) as rewards that could be boosted during training (Sharma et al., 2021; Yadav et al., 2021; Deng et al., 2022; Liu et al., 2023a; Xu et al., 2024; Wang et al., 2024b), to optimize complex objective functions even at testing time (OpenAI, 2024). However, when reward weights remain static, the weakest metric (the 'short-board') becomes a bottleneck that restricts overall LLM effectiveness, which introduces the short-board effect in multi-reward optimization. For example, in Figure 2, when the scaled reward itself (or its growth trend) has not yet reached saturation, its update magnitude should accordingly be increased."],"title":"Full Paper Content","type":"string"},"selected_content":{"description":"The specific text excerpt selected for improvement from the paper.","examples":["...when the scaled reward itself (or its growth trend) has not yet reached saturation, its update magnitude should accordingly be increased..."],"title":"Selected Content","type":"string"}},"required":["full_paper_content","selected_content"],"title":"improve_academic_passage_toolArguments","type":"object"},"outputSchema":{"properties":{"result":{"title":"Result","type":"string"}},"required":["result"],"title":"improve_academic_passage_toolOutput","type":"object"}}]` + + // var mockTools []ToolSchema + // err := json.Unmarshal([]byte(mockToolsJSON), &mockTools) +} From 774dba6e22da97769e08cf5ca00532d92e3272fd Mon Sep 17 00:00:00 2001 From: 4ndrelim Date: Wed, 22 Oct 2025 06:39:44 +0800 Subject: [PATCH 5/7] feat: Add DynamicTool to represent generic XtraMCP tool --- internal/services/toolkit/client/client.go | 20 +- .../services/toolkit/tools/xtramcp/tool.go | 173 ++++++++++++++++++ 2 files changed, 183 insertions(+), 10 deletions(-) create mode 100644 internal/services/toolkit/tools/xtramcp/tool.go diff --git a/internal/services/toolkit/client/client.go b/internal/services/toolkit/client/client.go index ae397f84..6fa38197 100644 --- a/internal/services/toolkit/client/client.go +++ b/internal/services/toolkit/client/client.go @@ -9,7 +9,7 @@ import ( "paperdebugger/internal/services" "paperdebugger/internal/services/toolkit/handler" "paperdebugger/internal/services/toolkit/registry" - "paperdebugger/internal/services/toolkit/tools/xtragpt" + "paperdebugger/internal/services/toolkit/tools/xtramcp" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" @@ -42,7 +42,6 @@ func NewAIClient( option.WithAPIKey(cfg.OpenAIAPIKey), ) CheckOpenAIWorks(oaiClient, logger) - toolSearchPapers := xtragpt.NewSearchPapersTool(db, projectService) // toolPaperScore := tools.NewPaperScoreTool(db, projectService) // toolPaperScoreComment := tools.NewPaperScoreCommentTool(db, projectService, reverseCommentService) @@ -53,16 +52,17 @@ func NewAIClient( // toolRegistry.Register("paper_score", toolPaperScore.Description, toolPaperScore.Call) // toolRegistry.Register("paper_score_comment", toolPaperScoreComment.Description, toolPaperScoreComment.Call) - // toolRegistry.Register("export_papers") - // toolRegistry.Register("get_conference_papers") - // toolRegistry.Register("get_user_papers") - toolRegistry.Register("search_relevant_papers", toolSearchPapers.Description, toolSearchPapers.Call) - // toolRegistry.Register("search_user") - // toolRegistry.Register("identify_improvements") - // toolRegistry.Register("suggest_improvement") + // Load tools dynamically from backend (TODO: Make URL configurable / Xtramcp url) + xtraMCPLoader := xtramcp.NewXtraMCPLoader(db, projectService, "http://localhost:8080/mcp") + err := xtraMCPLoader.LoadToolsFromBackend(toolRegistry) + if err != nil { + logger.Errorf("[AI Client] Failed to load XtraMCP tools: %v", err) + // Fallback to static tools or return error based on your preference + } else { + logger.Info("[AI Client] Successfully loaded XtraMCP tools") + } toolCallHandler := handler.NewToolCallHandler(toolRegistry) - client := &AIClient{ openaiClient: &oaiClient, toolCallHandler: toolCallHandler, diff --git a/internal/services/toolkit/tools/xtramcp/tool.go b/internal/services/toolkit/tools/xtramcp/tool.go new file mode 100644 index 00000000..ba8cde7d --- /dev/null +++ b/internal/services/toolkit/tools/xtramcp/tool.go @@ -0,0 +1,173 @@ +package xtramcp + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "paperdebugger/internal/libs/db" + "paperdebugger/internal/services" + toolCallRecordDB "paperdebugger/internal/services/toolkit/db" + "strings" + "time" + + "github.com/openai/openai-go/v2" + "github.com/openai/openai-go/v2/packages/param" + "github.com/openai/openai-go/v2/responses" + "github.com/samber/lo" +) + +// ToolSchema represents the schema from your backend +type ToolSchema struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]interface{} `json:"inputSchema"` + OutputSchema map[string]interface{} `json:"outputSchema"` +} + +// MCPRequest represents the JSON-RPC request structure +type MCPRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + ID int `json:"id"` + Params MCPParams `json:"params"` +} + +// MCPParams represents the parameters for the MCP request +type MCPParams struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` +} + +// DynamicTool represents a generic tool that can handle any schema +type DynamicTool struct { + Name string + Description responses.ToolUnionParam + toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB + projectService *services.ProjectService + coolDownTime time.Duration + baseURL string + client *http.Client + schema map[string]interface{} + sessionID string // Reuse the session ID from initialization +} + +// NewDynamicTool creates a new dynamic tool from a schema +func NewDynamicTool(db *db.DB, projectService *services.ProjectService, toolSchema ToolSchema, baseURL string, sessionID string) *DynamicTool { + // Create tool description with the schema + description := responses.ToolUnionParam{ + OfFunction: &responses.FunctionToolParam{ + Name: toolSchema.Name, + Description: param.NewOpt(toolSchema.Description), + Parameters: openai.FunctionParameters(toolSchema.InputSchema), + }, + } + + toolCallRecordDB := toolCallRecordDB.NewToolCallRecordDB(db) + return &DynamicTool{ + Name: toolSchema.Name, + Description: description, + toolCallRecordDB: toolCallRecordDB, + projectService: projectService, + coolDownTime: 5 * time.Minute, + baseURL: baseURL, + client: &http.Client{}, + schema: toolSchema.InputSchema, + sessionID: sessionID, // Store the session ID for reuse + } +} + +// Call handles the tool execution (generic for any tool) +func (t *DynamicTool) Call(ctx context.Context, toolCallId string, args json.RawMessage) (string, string, error) { + // Parse arguments as generic map since we don't know the structure + var argsMap map[string]interface{} + err := json.Unmarshal(args, &argsMap) + if err != nil { + return "", "", err + } + + // Create function call record + record, err := t.toolCallRecordDB.Create(ctx, toolCallId, t.Name, argsMap) + if err != nil { + return "", "", err + } + + // Execute the tool via MCP + respStr, err := t.executeTool(argsMap) + if err != nil { + err = fmt.Errorf("failed to execute tool %s: %v", t.Name, err) + t.toolCallRecordDB.OnError(ctx, record, err) + return "", "", err + } + + rawJson, err := json.Marshal(respStr) + if err != nil { + err = fmt.Errorf("failed to marshal tool result: %v", err) + t.toolCallRecordDB.OnError(ctx, record, err) + return "", "", err + } + t.toolCallRecordDB.OnSuccess(ctx, record, string(rawJson)) + + return respStr, "", nil +} + +// executeTool makes the MCP request (generic for any tool) +func (t *DynamicTool) executeTool(args map[string]interface{}) (string, error) { + // Use the stored session ID - no need to re-initialize! + fmt.Printf("Using existing sessionId for %s: %s\n", t.Name, t.sessionID) + + request := MCPRequest{ + JSONRPC: "2.0", + Method: "tools/call", + ID: int(time.Now().Unix()), // to ensure unique ID; TODO: consider better ID generation + Params: MCPParams{ + Name: t.Name, + Arguments: args, + }, + } + + // Marshal request to JSON + jsonData, err := json.Marshal(request) + if err != nil { + return "", fmt.Errorf("failed to marshal MCP request: %w", err) + } + + // Create HTTP request + req, err := http.NewRequest("POST", t.baseURL, bytes.NewBuffer(jsonData)) + if err != nil { + return "", fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("mcp-session-id", t.sessionID) // Use the stored session ID + + // Make the request + resp, err := t.client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() + + // Read response + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + fmt.Printf("Response body for %s: %s\n", t.Name, string(body)) + + // Parse response (assuming stream format) + lines := strings.Split(string(body), "\n") + lines = lo.Filter(lines, func(line string, _ int) bool { + return strings.HasPrefix(line, "data:") + }) + if len(lines) == 0 { + return "", fmt.Errorf("no data line found") + } + line := lines[0] + line = strings.TrimPrefix(line, "data: ") + return line, nil +} From a05ad8d12761d69fc458f209e6cf4e297a4a787c Mon Sep 17 00:00:00 2001 From: 4ndrelim Date: Wed, 22 Oct 2025 07:18:22 +0800 Subject: [PATCH 6/7] nit: Refactor for better structure --- internal/services/toolkit/client/client.go | 18 +++++-- .../services/toolkit/tools/xtramcp/loader.go | 51 +++++++++++++------ 2 files changed, 50 insertions(+), 19 deletions(-) diff --git a/internal/services/toolkit/client/client.go b/internal/services/toolkit/client/client.go index 6fa38197..2e8b091d 100644 --- a/internal/services/toolkit/client/client.go +++ b/internal/services/toolkit/client/client.go @@ -54,12 +54,22 @@ func NewAIClient( // Load tools dynamically from backend (TODO: Make URL configurable / Xtramcp url) xtraMCPLoader := xtramcp.NewXtraMCPLoader(db, projectService, "http://localhost:8080/mcp") - err := xtraMCPLoader.LoadToolsFromBackend(toolRegistry) + + // initialize MCP session first and log session ID + sessionID, err := xtraMCPLoader.InitializeMCP() if err != nil { - logger.Errorf("[AI Client] Failed to load XtraMCP tools: %v", err) - // Fallback to static tools or return error based on your preference + logger.Errorf("[AI Client] Failed to initialize XtraMCP session: %v", err) + // TODO: Fallback to static tools or exit? } else { - logger.Info("[AI Client] Successfully loaded XtraMCP tools") + logger.Info("[AI Client] XtraMCP session initialized", "sessionID", sessionID) + + // dynamically load all tools from XtraMCP backend + err = xtraMCPLoader.LoadToolsFromBackend(toolRegistry) + if err != nil { + logger.Errorf("[AI Client] Failed to load XtraMCP tools: %v", err) + } else { + logger.Info("[AI Client] Successfully loaded XtraMCP tools") + } } toolCallHandler := handler.NewToolCallHandler(toolRegistry) diff --git a/internal/services/toolkit/tools/xtramcp/loader.go b/internal/services/toolkit/tools/xtramcp/loader.go index 70dfc185..3703e305 100644 --- a/internal/services/toolkit/tools/xtramcp/loader.go +++ b/internal/services/toolkit/tools/xtramcp/loader.go @@ -6,13 +6,14 @@ import ( "fmt" "io" "net/http" + "strings" "paperdebugger/internal/libs/db" "paperdebugger/internal/services" "paperdebugger/internal/services/toolkit/registry" ) -// MCPToolsResponse represents the JSON-RPC response from your backend -type MCPToolsResponse struct { +// MCPListToolsResponse represents the JSON-RPC response from tools/list method +type MCPListToolsResponse struct { JSONRPC string `json:"jsonrpc"` ID int `json:"id"` Result struct { @@ -41,14 +42,11 @@ func NewXtraMCPLoader(db *db.DB, projectService *services.ProjectService, baseUR // LoadToolsFromBackend fetches tool schemas from backend and registers them func (loader *XtraMCPLoader) LoadToolsFromBackend(toolRegistry *registry.ToolRegistry) error { - // Initialize MCP session ONCE - sessionID, err := loader.initializeMCP() - if err != nil { - return fmt.Errorf("failed to initialize MCP: %w", err) + if loader.sessionID == "" { + return fmt.Errorf("MCP session not initialized - call InitializeMCP first") } - loader.sessionID = sessionID - // Fetch tools from backend using the session (currently returns mock data) + // Fetch tools from backend using the established session toolSchemas, err := loader.fetchAvailableTools() if err != nil { return fmt.Errorf("failed to fetch tools from backend: %w", err) @@ -67,8 +65,8 @@ func (loader *XtraMCPLoader) LoadToolsFromBackend(toolRegistry *registry.ToolReg return nil } -// initializeMCP performs the full MCP initialization handshake -func (loader *XtraMCPLoader) initializeMCP() (string, error) { +// InitializeMCP performs the full MCP initialization handshake, stores session ID, and returns it +func (loader *XtraMCPLoader) InitializeMCP() (string, error) { // Step 1: Initialize sessionID, err := loader.performInitialize() if err != nil { @@ -81,6 +79,9 @@ func (loader *XtraMCPLoader) initializeMCP() (string, error) { return "", fmt.Errorf("step 2 - notifications/initialized failed: %w", err) } + // Store session ID for future use and return it + loader.sessionID = sessionID + return sessionID, nil } @@ -189,15 +190,35 @@ func (loader *XtraMCPLoader) fetchAvailableTools() ([]ToolSchema, error) { } defer resp.Body.Close() - // Parse response - var mcpResponse MCPToolsResponse - err = json.NewDecoder(resp.Body).Decode(&mcpResponse) + // Read the raw response body (SSE format) for debugging + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + // Parse SSE format - extract JSON from "data: " lines + lines := strings.Split(string(bodyBytes), "\n") + var extractedJSON string + for _, line := range lines { + if strings.HasPrefix(line, "data: ") { + extractedJSON = strings.TrimPrefix(line, "data: ") + break + } + } + + if extractedJSON == "" { + return nil, fmt.Errorf("no data line found in SSE response") + } + + // Parse the extracted JSON + var mcpResponse MCPListToolsResponse + err = json.Unmarshal([]byte(extractedJSON), &mcpResponse) if err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) + return nil, fmt.Errorf("failed to parse JSON from SSE data: %w. JSON data: %s", err, extractedJSON) } return mcpResponse.Result.Tools, nil - + // mock data; return hardcoded tool schemas for testing // mockToolsJSON := `[{"name":"get_user_papers","description":"Fetch all papers published by a specific user identified by email. Supports 'summary' (abstract truncated to 150 words) and 'detailed' (full abstract).","inputSchema":{"properties":{"email":{"description":"Email address of the user whose papers to fetch. Must be a valid email string.","examples":["alice@example.com","bob@university.edu"],"title":"Email","type":"string"},"format":{"default":"detailed","description":"Format of the response. 'summary' shows title, venue, authors, URL, and the first 150 words of the abstract (default). 'detailed' shows the full abstract.","enum":["summary","detailed"],"examples":["summary","detailed"],"title":"Format","type":"string"}},"required":["email"],"title":"get_user_papers_toolArguments","type":"object"},"outputSchema":{"properties":{"result":{"title":"Result","type":"string"}},"required":["result"],"title":"get_user_papers_toolOutput","type":"object"}},{"name":"search_papers_on_openreview","description":"Search for academic papers on OpenReview by keywords within specific conference venues. This tool supports various matching modes and is ideal for discovering recent or broader papers beyond those available in the local database. Use this tool when results from search_relevant_papers are insufficient.","inputSchema":{"properties":{"query":{"description":"Keywords, topics, content, or a chunk of text to search for.","examples":["time series token merging","neural networks"],"title":"Query","type":"string"},"venues":{"description":"List of conference venues and years to search in. Each entry must be a dict with 'venue' and 'year'.","examples":[{"venue":"ICLR.cc","year":"2024"},{"venue":"ICML","year":"2024"},{"venue":"NeurIPS.cc","year":"2023"},{"venue":"NeurIPS.cc","year":"2022"}],"items":{"additionalProperties":{"type":"string"},"type":"object"},"minItems":1,"title":"Venues","type":"array"},"search_fields":{"default":["title","abstract"],"description":"Fields to search within each paper. Options: 'title', 'abstract', 'authors'.","items":{"enum":["title","abstract","authors"],"type":"string"},"title":"Search Fields","type":"array"},"match_mode":{"default":"majority","description":"Match mode:\n- any: At least one keyword must match\n- all: All keywords must match\n- exact: Match the entire phrase exactly\n- majority: Match majority of keywords (>50%)\n- threshold: Match percentage of terms based on 'match_threshold'.","enum":["any","all","exact","majority","threshold"],"title":"Match Mode","type":"string"},"match_threshold":{"default":0.5,"description":"Minimum fraction (0.0-1.0) of search terms that must match when using 'threshold' mode. Example: 0.5 = 50% of terms must match.","maximum":1,"minimum":0,"title":"Match Threshold","type":"number"},"limit":{"default":10,"description":"Maximum number of results to return (1-16).","maximum":16,"minimum":1,"title":"Limit","type":"integer"},"min_score":{"default":0.6,"description":"Minimum match score (0.0-1.0). Lower values allow looser matches; higher values enforce stricter matches.","maximum":1,"minimum":0,"title":"Min Score","type":"number"}},"required":["query","venues"],"title":"search_papers_openreview_toolArguments","type":"object"},"outputSchema":{"properties":{"result":{"title":"Result","type":"string"}},"required":["result"],"title":"search_papers_openreview_toolOutput","type":"object"}},{"name":"search_relevant_papers","description":"Search for similar or relevant papers by keywords against the local database of academic papers. This tool uses semantic search with vector embeddings to find the most relevant results. It is the default and recommended tool for paper searches.","inputSchema":{"properties":{"query":{"description":"Keywords, topics, content, or a chunk of text to search for.","examples":["time series token merging","neural networks","...when trained on first-order Markov chains, transformers with two or more layers consistently develop an induction head mechanism to estimate the in-context bigram conditional distribution"],"title":"Query","type":"string"},"top_k":{"description":"Number of top relevant or similar papers to return.","title":"Top K","type":"integer"},"date_min":{"description":"Minimum publication date (YYYY-MM-DD) to filter papers.","examples":["2023-01-01","2022-06-25"],"title":"Date Min","type":"string"},"date_max":{"description":"Maximum publication date (YYYY-MM-DD) to filter papers.","examples":["2024-12-31","2023-06-25"],"title":"Date Max","type":"string"},"countries":{"anyOf":[{"items":{"type":"string"},"type":"array"},{"type":"null"}],"description":"List of country codes in ISO ALPHA-3 format to filter papers by author affiliations.","examples":[["USA","CHN","SGP","GBR","DEU","KOR","JPN"]],"title":"Countries"},"min_similarity":{"description":"Minimum similarity score (0.0-1.0) for returned papers. Higher values yield more relevant results but fewer papers.","examples":[0.3,0.5,0.7,0.9],"title":"Min Similarity","type":"number"}},"required":["query","top_k","countries","min_similarity"],"title":"search_papers_toolArguments","type":"object"},"outputSchema":{"properties":{"result":{"title":"Result","type":"string"}},"required":["result"],"title":"search_papers_toolOutput","type":"object"}},{"name":"identify_improvements","description":"Analyzes a draft academic paper against the standards of top-tier ML conferences (ICLR, ICML, NeurIPS). Identifies issues in structure, completeness, clarity, and argumentation, then provides prioritized, actionable suggestions.","inputSchema":{"properties":{"paper_content":{"description":"The full text content of the academic paper draft. Paper content should not be truncated.","title":"Paper Content","type":"string"},"target_venue":{"default":"NeurIPS","description":"The target top-tier conference to tailor the feedback for.","enum":["ICLR","ICML","NeurIPS"],"title":"Target Venue","type":"string"},"focus_areas":{"anyOf":[{"items":{"enum":["Structure","Clarity","Evidence","Positioning","Style","Completeness","Soundness","Limitations"],"type":"string"},"type":"array"},{"type":"null"}],"default":null,"description":"List of specific areas to focus the analysis on. If empty, default areas are: {DEFAULT_FOCUS_AREAS}.","title":"Focus Areas"},"severity_threshold":{"default":"major","description":"The minimum severity level to report. 'major' will show blockers and major issues.","enum":["blocker","major","minor","nit"],"title":"Severity Threshold","type":"string"}},"required":["paper_content"],"title":"identify_improvementsArguments","type":"object"},"outputSchema":{"properties":{"result":{"title":"Result","type":"string"}},"required":["result"],"title":"identify_improvementsOutput","type":"object"}},{"name":"enhance_academic_writing","description":"Suggest context-aware academic paper writing enhancements for selected text.","inputSchema":{"properties":{"full_paper_content":{"description":"Surrounding context from the manuscript (e.g., abstract, background, or several sections). This need not be the entire paper; providing a substantial excerpt helps tailor the tone, terminology, and level of detail to academic venues (journals and conferences).","examples":["In reinforcement learning, one could structure these metrics (previously for evaluation) as rewards that could be boosted during training (Sharma et al., 2021; Yadav et al., 2021; Deng et al., 2022; Liu et al., 2023a; Xu et al., 2024; Wang et al., 2024b), to optimize complex objective functions even at testing time (OpenAI, 2024). However, when reward weights remain static, the weakest metric (the 'short-board') becomes a bottleneck that restricts overall LLM effectiveness, which introduces the short-board effect in multi-reward optimization. For example, in Figure 2, when the scaled reward itself (or its growth trend) has not yet reached saturation, its update magnitude should accordingly be increased."],"title":"Full Paper Content","type":"string"},"selected_content":{"description":"The specific text excerpt selected for improvement from the paper.","examples":["...when the scaled reward itself (or its growth trend) has not yet reached saturation, its update magnitude should accordingly be increased..."],"title":"Selected Content","type":"string"}},"required":["full_paper_content","selected_content"],"title":"improve_academic_passage_toolArguments","type":"object"},"outputSchema":{"properties":{"result":{"title":"Result","type":"string"}},"required":["result"],"title":"improve_academic_passage_toolOutput","type":"object"}}]` From ca37a0a95d8ec3f2dcbe7591cc47a95a46bb0dbd Mon Sep 17 00:00:00 2001 From: 4ndrelim Date: Wed, 22 Oct 2025 07:23:59 +0800 Subject: [PATCH 7/7] nit: remove unnecessary files and prints --- .../toolkit/tools/xtragpt/search_papers.go | 188 ------------------ .../services/toolkit/tools/xtragpt/util.go | 135 ------------- .../toolkit/tools/xtragpt/util_test.go | 173 ---------------- .../services/toolkit/tools/xtramcp/loader.go | 6 - .../services/toolkit/tools/xtramcp/tool.go | 5 +- 5 files changed, 1 insertion(+), 506 deletions(-) delete mode 100644 internal/services/toolkit/tools/xtragpt/search_papers.go delete mode 100644 internal/services/toolkit/tools/xtragpt/util.go delete mode 100644 internal/services/toolkit/tools/xtragpt/util_test.go diff --git a/internal/services/toolkit/tools/xtragpt/search_papers.go b/internal/services/toolkit/tools/xtragpt/search_papers.go deleted file mode 100644 index b846044b..00000000 --- a/internal/services/toolkit/tools/xtragpt/search_papers.go +++ /dev/null @@ -1,188 +0,0 @@ -package xtragpt - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "paperdebugger/internal/libs/db" - "paperdebugger/internal/services" - toolCallRecordDB "paperdebugger/internal/services/toolkit/db" - "strings" - "time" - - "github.com/openai/openai-go/v2" - "github.com/openai/openai-go/v2/packages/param" - "github.com/openai/openai-go/v2/responses" - "github.com/samber/lo" -) - -// MCPRequest represents the JSON-RPC request structure -type MCPRequest struct { - JSONRPC string `json:"jsonrpc"` - Method string `json:"method"` - ID int `json:"id"` - Params MCPParams `json:"params"` -} - -// MCPParams represents the parameters for the MCP request -type MCPParams struct { - Name string `json:"name"` - Arguments map[string]interface{} `json:"arguments"` -} - -type SearchPapersTool struct { - Description responses.ToolUnionParam - toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB - projectService *services.ProjectService - coolDownTime time.Duration - baseURL string - client *http.Client -} - -func NewSearchPapersTool(db *db.DB, projectService *services.ProjectService) *SearchPapersTool { - // Create and populate schema - var schema map[string]any - json.Unmarshal([]byte(`{"properties":{"query":{"description":"Keywords, topics, content, or a chunk of text to search for.","examples":["time series token merging","neural networks","...when trained on first-order Markov chains, transformers with two or more layers consistently develop an induction head mechanism to estimate the in-context bigram conditional distribution"],"title":"Query","type":"string"},"top_k":{"description":"Number of top relevant or similar papers to return.","title":"Top K","type":"integer"},"date_min":{"description":"Minimum publication date (YYYY-MM-DD) to filter papers.","examples":["2023-01-01","2022-06-25"],"title":"Date Min","type":"string"},"date_max":{"description":"Maximum publication date (YYYY-MM-DD) to filter papers.","examples":["2024-12-31","2023-06-25"],"title":"Date Max","type":"string"},"countries":{"anyOf":[{"items":{"type":"string"},"type":"array"},{"type":"null"}],"description":"List of country codes in ISO ALPHA-3 format to filter papers by author affiliations.","examples":[["USA","CHN","SGP","GBR","DEU","KOR","JPN"]],"title":"Countries"},"min_similarity":{"description":"Minimum similarity score (0.0-1.0) for returned papers. Higher values yield more relevant results but fewer papers.","examples":[0.3,0.5,0.7,0.9],"title":"Min Similarity","type":"number"}},"required":["query","top_k","countries","min_similarity"],"title":"search_papers_toolArguments","type":"object"}`), &schema) - - // Create tool description with populated schema - description := responses.ToolUnionParam{ - OfFunction: &responses.FunctionToolParam{ - Name: "search_relevant_papers", - Description: param.NewOpt("Search for similar or relevant papers by keywords against the local database of academic papers. This tool uses semantic search with vector embeddings to find the most relevant results. It is the default and recommended tool for paper searches."), - Parameters: openai.FunctionParameters(schema), - }, - } - - toolCallRecordDB := toolCallRecordDB.NewToolCallRecordDB(db) - return &SearchPapersTool{ - Description: description, - toolCallRecordDB: toolCallRecordDB, - projectService: projectService, - coolDownTime: 5 * time.Minute, - // baseURL: "http://xtragpt-mcp-server:8080/mcp", - baseURL: "http://localhost:8080/mcp", // For local development - client: &http.Client{}, - } -} - -type SearchPapersToolArgs struct { - Query string `json:"query"` - TopK int `json:"top_k"` - DateMin *string `json:"date_min,omitempty"` - DateMax *string `json:"date_max,omitempty"` - Countries []string `json:"countries"` - MinSimilarity float64 `json:"min_similarity"` -} - -func (t *SearchPapersTool) Call(ctx context.Context, toolCallId string, args json.RawMessage) (string, string, error) { - var argsMap SearchPapersToolArgs - err := json.Unmarshal(args, &argsMap) - if err != nil { - return "", "", err - } - - // Create function call record - record, err := t.toolCallRecordDB.Create(ctx, toolCallId, *t.Description.GetName(), map[string]any{ - "query": argsMap.Query, - "top_k": argsMap.TopK, - "date_min": argsMap.DateMin, - "date_max": argsMap.DateMax, - "countries": argsMap.Countries, - "min_similarity": argsMap.MinSimilarity, - }) - if err != nil { - return "", "", err - } - - respStr, err := t.SearchPaper(argsMap.Query, argsMap.TopK, argsMap.DateMin, argsMap.DateMax, argsMap.Countries, argsMap.MinSimilarity) - if err != nil { - err = fmt.Errorf("failed to search paper: %v", err) - t.toolCallRecordDB.OnError(ctx, record, err) - return "", "", err - } - - rawJson, err := json.Marshal(respStr) - if err != nil { - err = fmt.Errorf("failed to marshal paper search result: %v, rawJson: %v", err, string(rawJson)) - t.toolCallRecordDB.OnError(ctx, record, err) - return "", "", err - } - t.toolCallRecordDB.OnSuccess(ctx, record, string(rawJson)) - - return respStr, "", nil -} - -func (t *SearchPapersTool) SearchPaper(query string, topK int, dateMin *string, dateMax *string, countries []string, minSimilarity float64) (string, error) { - sessionId, err := MCPInitialize(t.baseURL) - if err != nil { - fmt.Printf("Error initializing MCP: %v\n", err) - return "", fmt.Errorf("failed to initialize MCP: %w", err) - } - if sessionId == "" { - return "", fmt.Errorf("failed to initialize MCP") - } - - fmt.Println("sessionId", sessionId) - request := MCPRequest{ - JSONRPC: "2.0", - Method: "tools/call", - ID: 2, - Params: MCPParams{ - Name: "search_relevant_papers", - Arguments: map[string]interface{}{ - "query": query, - "top_k": topK, - "date_min": dateMin, - "date_max": dateMax, - "countries": countries, - "min_similarity": minSimilarity, - }, - }, - } - - // Marshal request to JSON - jsonData, err := json.Marshal(request) - if err != nil { - return "", fmt.Errorf("failed to marshal MCP request: %w", err) - } - - // Create HTTP request - req, err := http.NewRequest("POST", t.baseURL, bytes.NewBuffer(jsonData)) - if err != nil { - return "", fmt.Errorf("failed to create HTTP request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - req.Header.Set("mcp-session-id", sessionId) - - // Make the request - resp, err := t.client.Do(req) - if err != nil { - return "", fmt.Errorf("failed to make request: %w", err) - } - defer resp.Body.Close() - - // Read response - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("failed to read response: %w", err) - } - fmt.Println("body", string(body)) - // split lines - lines := strings.Split(string(body), "\n") - // keep only the line starts with "data:" - lines = lo.Filter(lines, func(line string, _ int) bool { - return strings.HasPrefix(line, "data:") - }) - if len(lines) == 0 { - return "", fmt.Errorf("no data line found") - } - line := lines[0] - line = strings.TrimPrefix(line, "data: ") - return line, nil -} diff --git a/internal/services/toolkit/tools/xtragpt/util.go b/internal/services/toolkit/tools/xtragpt/util.go deleted file mode 100644 index 258aaa18..00000000 --- a/internal/services/toolkit/tools/xtragpt/util.go +++ /dev/null @@ -1,135 +0,0 @@ -package xtragpt - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" -) - -type InitializeRequest struct { - JSONRPC string `json:"jsonrpc"` - Method string `json:"method"` - Params struct { - ProtocolVersion string `json:"protocolVersion"` - Capabilities map[string]interface{} `json:"capabilities"` - ClientInfo struct { - Name string `json:"name"` - Version string `json:"version"` - } `json:"clientInfo"` - } `json:"params"` - ID int `json:"id"` -} - -type NotificationRequest struct { - JSONRPC string `json:"jsonrpc"` - Method string `json:"method"` - Params map[string]interface{} `json:"params"` -} - -type ToolsRequest struct { - JSONRPC string `json:"jsonrpc"` - Method string `json:"method"` - ID int `json:"id"` -} - -func MCPNotificationsInitialized(url string, sessionId string) { - notifyReq := NotificationRequest{ - JSONRPC: "2.0", - Method: "notifications/initialized", - Params: make(map[string]interface{}), - } - - // Marshal notification to JSON - notifyJSON, err := json.Marshal(notifyReq) - if err != nil { - fmt.Printf("Error marshaling notification JSON: %v\n", err) - return - } - - // Create HTTP client and request for notification - client := &http.Client{} - req, err := http.NewRequest("POST", url, bytes.NewBuffer(notifyJSON)) - if err != nil { - fmt.Printf("Error creating request: %v\n", err) - return - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - req.Header.Set("mcp-session-id", sessionId) - - // Make the notification request - notifyResp, err := client.Do(req) - if err != nil { - fmt.Printf("Error making notification request: %v\n", err) - return - } - defer notifyResp.Body.Close() - -} - -func MCPInitialize(url string) (string, error) { - initReq := InitializeRequest{ - JSONRPC: "2.0", - Method: "initialize", - ID: 1, - } - initReq.Params.ProtocolVersion = "2024-11-05" - initReq.Params.Capabilities = make(map[string]interface{}) - initReq.Params.ClientInfo.Name = "test-client" - initReq.Params.ClientInfo.Version = "1.0.0" - - // Marshal to JSON - jsonData, err := json.Marshal(initReq) - if err != nil { - fmt.Printf("Error marshaling JSON: %v\n", err) - return "", err - } - - // Make the first request - resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData)) - if err != nil { - fmt.Printf("Error making request: %v\n", err) - return "", err - } - defer resp.Body.Close() - fmt.Println("Initialize response", resp.Body, resp.Header) - - // Get session ID from headers - sessionID := resp.Header.Get("mcp-session-id") - fmt.Printf("Session ID: %s\n", sessionID) - - MCPNotificationsInitialized(url, sessionID) - fmt.Println("Initialized") - return sessionID, nil -} - -func MCPListTools(url string) ([]string, error) { - toolsReq := ToolsRequest{ - JSONRPC: "2.0", - Method: "tools/list", - ID: 1, - } - jsonData, err := json.Marshal(toolsReq) - if err != nil { - fmt.Printf("Error marshaling JSON: %v\n", err) - return nil, err - } - resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData)) - if err != nil { - fmt.Printf("Error making request: %v\n", err) - return nil, err - } - defer resp.Body.Close() - fmt.Println("List tools response", resp.Body, resp.Header) - body, err := io.ReadAll(resp.Body) - if err != nil { - fmt.Printf("Error reading response: %v\n", err) - return nil, err - } - fmt.Println("List tools response", string(body)) - return nil, nil -} diff --git a/internal/services/toolkit/tools/xtragpt/util_test.go b/internal/services/toolkit/tools/xtragpt/util_test.go deleted file mode 100644 index 00dd75ca..00000000 --- a/internal/services/toolkit/tools/xtragpt/util_test.go +++ /dev/null @@ -1,173 +0,0 @@ -package xtragpt_test - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "paperdebugger/internal/services/toolkit/tools/xtragpt" - "testing" -) - -func TestMCPInitialize_Success(t *testing.T) { - expectedSessionID := "test-session-123" - - // Mock server that handles both initialize and notifications/initialized requests - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "POST" { - t.Errorf("Expected POST request, got %s", r.Method) - } - - if r.Header.Get("Content-Type") != "application/json" { - t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) - } - - // Parse request body to determine which request this is - var reqBody map[string]interface{} - if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { - t.Fatalf("Failed to decode request body: %v", err) - } - - method, ok := reqBody["method"].(string) - if !ok { - t.Fatalf("Missing or invalid method field") - } - - switch method { - case "initialize": - // Validate initialize request structure - if reqBody["jsonrpc"] != "2.0" { - t.Errorf("Expected jsonrpc 2.0, got %v", reqBody["jsonrpc"]) - } - - if reqBody["id"] != float64(1) { - t.Errorf("Expected id 1, got %v", reqBody["id"]) - } - - params, ok := reqBody["params"].(map[string]interface{}) - if !ok { - t.Fatalf("Missing or invalid params field") - } - - if params["protocolVersion"] != "2024-11-05" { - t.Errorf("Expected protocolVersion 2024-11-05, got %v", params["protocolVersion"]) - } - - clientInfo, ok := params["clientInfo"].(map[string]interface{}) - if !ok { - t.Fatalf("Missing or invalid clientInfo field") - } - - if clientInfo["name"] != "test-client" { - t.Errorf("Expected client name test-client, got %v", clientInfo["name"]) - } - - if clientInfo["version"] != "1.0.0" { - t.Errorf("Expected client version 1.0.0, got %v", clientInfo["version"]) - } - - // Set session ID header and return success response - w.Header().Set("mcp-session-id", expectedSessionID) - w.WriteHeader(http.StatusOK) - fmt.Fprint(w, `{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05"}}`) - - case "notifications/initialized": - // Validate notifications/initialized request - if reqBody["jsonrpc"] != "2.0" { - t.Errorf("Expected jsonrpc 2.0, got %v", reqBody["jsonrpc"]) - } - - // Check session ID header - if r.Header.Get("mcp-session-id") != expectedSessionID { - t.Errorf("Expected session ID %s, got %s", expectedSessionID, r.Header.Get("mcp-session-id")) - } - - if r.Header.Get("Accept") != "application/json, text/event-stream" { - t.Errorf("Expected Accept header 'application/json, text/event-stream', got %s", r.Header.Get("Accept")) - } - - w.WriteHeader(http.StatusOK) - - default: - t.Errorf("Unexpected method: %s", method) - } - })) - defer server.Close() - - sessionID, err := xtragpt.MCPInitialize(server.URL) - - if err != nil { - t.Fatalf("MCPInitialize failed: %v", err) - } - - if sessionID != expectedSessionID { - t.Errorf("Expected session ID %s, got %s", expectedSessionID, sessionID) - } -} - -func TestMCPInitialize_InvalidURL(t *testing.T) { - sessionID, err := xtragpt.MCPInitialize("invalid-url") - - if err == nil { - t.Fatalf("Expected error for invalid URL, but got none") - } - - if sessionID != "" { - t.Errorf("Expected empty session ID on error, got %s", sessionID) - } -} - -func TestMCPNotificationsInitialized_Success(t *testing.T) { - sessionID := "test-session-456" - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "POST" { - t.Errorf("Expected POST request, got %s", r.Method) - } - - if r.Header.Get("Content-Type") != "application/json" { - t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) - } - - if r.Header.Get("Accept") != "application/json, text/event-stream" { - t.Errorf("Expected Accept header 'application/json, text/event-stream', got %s", r.Header.Get("Accept")) - } - - if r.Header.Get("mcp-session-id") != sessionID { - t.Errorf("Expected session ID %s, got %s", sessionID, r.Header.Get("mcp-session-id")) - } - - var reqBody map[string]interface{} - if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { - t.Fatalf("Failed to decode request body: %v", err) - } - - if reqBody["jsonrpc"] != "2.0" { - t.Errorf("Expected jsonrpc 2.0, got %v", reqBody["jsonrpc"]) - } - - if reqBody["method"] != "notifications/initialized" { - t.Errorf("Expected method notifications/initialized, got %v", reqBody["method"]) - } - - params, ok := reqBody["params"].(map[string]interface{}) - if !ok { - t.Fatalf("Missing or invalid params field") - } - - if len(params) != 0 { - t.Errorf("Expected empty params, got %v", params) - } - - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - // This function doesn't return anything, so we just ensure it doesn't panic - xtragpt.MCPNotificationsInitialized(server.URL, sessionID) -} - -func TestMCPNotificationsInitialized_InvalidURL(t *testing.T) { - // This should not panic even with invalid URL - xtragpt.MCPNotificationsInitialized("invalid-url", "test-session") -} diff --git a/internal/services/toolkit/tools/xtramcp/loader.go b/internal/services/toolkit/tools/xtramcp/loader.go index 3703e305..974ff71b 100644 --- a/internal/services/toolkit/tools/xtramcp/loader.go +++ b/internal/services/toolkit/tools/xtramcp/loader.go @@ -218,10 +218,4 @@ func (loader *XtraMCPLoader) fetchAvailableTools() ([]ToolSchema, error) { } return mcpResponse.Result.Tools, nil - - // mock data; return hardcoded tool schemas for testing - // mockToolsJSON := `[{"name":"get_user_papers","description":"Fetch all papers published by a specific user identified by email. Supports 'summary' (abstract truncated to 150 words) and 'detailed' (full abstract).","inputSchema":{"properties":{"email":{"description":"Email address of the user whose papers to fetch. Must be a valid email string.","examples":["alice@example.com","bob@university.edu"],"title":"Email","type":"string"},"format":{"default":"detailed","description":"Format of the response. 'summary' shows title, venue, authors, URL, and the first 150 words of the abstract (default). 'detailed' shows the full abstract.","enum":["summary","detailed"],"examples":["summary","detailed"],"title":"Format","type":"string"}},"required":["email"],"title":"get_user_papers_toolArguments","type":"object"},"outputSchema":{"properties":{"result":{"title":"Result","type":"string"}},"required":["result"],"title":"get_user_papers_toolOutput","type":"object"}},{"name":"search_papers_on_openreview","description":"Search for academic papers on OpenReview by keywords within specific conference venues. This tool supports various matching modes and is ideal for discovering recent or broader papers beyond those available in the local database. Use this tool when results from search_relevant_papers are insufficient.","inputSchema":{"properties":{"query":{"description":"Keywords, topics, content, or a chunk of text to search for.","examples":["time series token merging","neural networks"],"title":"Query","type":"string"},"venues":{"description":"List of conference venues and years to search in. Each entry must be a dict with 'venue' and 'year'.","examples":[{"venue":"ICLR.cc","year":"2024"},{"venue":"ICML","year":"2024"},{"venue":"NeurIPS.cc","year":"2023"},{"venue":"NeurIPS.cc","year":"2022"}],"items":{"additionalProperties":{"type":"string"},"type":"object"},"minItems":1,"title":"Venues","type":"array"},"search_fields":{"default":["title","abstract"],"description":"Fields to search within each paper. Options: 'title', 'abstract', 'authors'.","items":{"enum":["title","abstract","authors"],"type":"string"},"title":"Search Fields","type":"array"},"match_mode":{"default":"majority","description":"Match mode:\n- any: At least one keyword must match\n- all: All keywords must match\n- exact: Match the entire phrase exactly\n- majority: Match majority of keywords (>50%)\n- threshold: Match percentage of terms based on 'match_threshold'.","enum":["any","all","exact","majority","threshold"],"title":"Match Mode","type":"string"},"match_threshold":{"default":0.5,"description":"Minimum fraction (0.0-1.0) of search terms that must match when using 'threshold' mode. Example: 0.5 = 50% of terms must match.","maximum":1,"minimum":0,"title":"Match Threshold","type":"number"},"limit":{"default":10,"description":"Maximum number of results to return (1-16).","maximum":16,"minimum":1,"title":"Limit","type":"integer"},"min_score":{"default":0.6,"description":"Minimum match score (0.0-1.0). Lower values allow looser matches; higher values enforce stricter matches.","maximum":1,"minimum":0,"title":"Min Score","type":"number"}},"required":["query","venues"],"title":"search_papers_openreview_toolArguments","type":"object"},"outputSchema":{"properties":{"result":{"title":"Result","type":"string"}},"required":["result"],"title":"search_papers_openreview_toolOutput","type":"object"}},{"name":"search_relevant_papers","description":"Search for similar or relevant papers by keywords against the local database of academic papers. This tool uses semantic search with vector embeddings to find the most relevant results. It is the default and recommended tool for paper searches.","inputSchema":{"properties":{"query":{"description":"Keywords, topics, content, or a chunk of text to search for.","examples":["time series token merging","neural networks","...when trained on first-order Markov chains, transformers with two or more layers consistently develop an induction head mechanism to estimate the in-context bigram conditional distribution"],"title":"Query","type":"string"},"top_k":{"description":"Number of top relevant or similar papers to return.","title":"Top K","type":"integer"},"date_min":{"description":"Minimum publication date (YYYY-MM-DD) to filter papers.","examples":["2023-01-01","2022-06-25"],"title":"Date Min","type":"string"},"date_max":{"description":"Maximum publication date (YYYY-MM-DD) to filter papers.","examples":["2024-12-31","2023-06-25"],"title":"Date Max","type":"string"},"countries":{"anyOf":[{"items":{"type":"string"},"type":"array"},{"type":"null"}],"description":"List of country codes in ISO ALPHA-3 format to filter papers by author affiliations.","examples":[["USA","CHN","SGP","GBR","DEU","KOR","JPN"]],"title":"Countries"},"min_similarity":{"description":"Minimum similarity score (0.0-1.0) for returned papers. Higher values yield more relevant results but fewer papers.","examples":[0.3,0.5,0.7,0.9],"title":"Min Similarity","type":"number"}},"required":["query","top_k","countries","min_similarity"],"title":"search_papers_toolArguments","type":"object"},"outputSchema":{"properties":{"result":{"title":"Result","type":"string"}},"required":["result"],"title":"search_papers_toolOutput","type":"object"}},{"name":"identify_improvements","description":"Analyzes a draft academic paper against the standards of top-tier ML conferences (ICLR, ICML, NeurIPS). Identifies issues in structure, completeness, clarity, and argumentation, then provides prioritized, actionable suggestions.","inputSchema":{"properties":{"paper_content":{"description":"The full text content of the academic paper draft. Paper content should not be truncated.","title":"Paper Content","type":"string"},"target_venue":{"default":"NeurIPS","description":"The target top-tier conference to tailor the feedback for.","enum":["ICLR","ICML","NeurIPS"],"title":"Target Venue","type":"string"},"focus_areas":{"anyOf":[{"items":{"enum":["Structure","Clarity","Evidence","Positioning","Style","Completeness","Soundness","Limitations"],"type":"string"},"type":"array"},{"type":"null"}],"default":null,"description":"List of specific areas to focus the analysis on. If empty, default areas are: {DEFAULT_FOCUS_AREAS}.","title":"Focus Areas"},"severity_threshold":{"default":"major","description":"The minimum severity level to report. 'major' will show blockers and major issues.","enum":["blocker","major","minor","nit"],"title":"Severity Threshold","type":"string"}},"required":["paper_content"],"title":"identify_improvementsArguments","type":"object"},"outputSchema":{"properties":{"result":{"title":"Result","type":"string"}},"required":["result"],"title":"identify_improvementsOutput","type":"object"}},{"name":"enhance_academic_writing","description":"Suggest context-aware academic paper writing enhancements for selected text.","inputSchema":{"properties":{"full_paper_content":{"description":"Surrounding context from the manuscript (e.g., abstract, background, or several sections). This need not be the entire paper; providing a substantial excerpt helps tailor the tone, terminology, and level of detail to academic venues (journals and conferences).","examples":["In reinforcement learning, one could structure these metrics (previously for evaluation) as rewards that could be boosted during training (Sharma et al., 2021; Yadav et al., 2021; Deng et al., 2022; Liu et al., 2023a; Xu et al., 2024; Wang et al., 2024b), to optimize complex objective functions even at testing time (OpenAI, 2024). However, when reward weights remain static, the weakest metric (the 'short-board') becomes a bottleneck that restricts overall LLM effectiveness, which introduces the short-board effect in multi-reward optimization. For example, in Figure 2, when the scaled reward itself (or its growth trend) has not yet reached saturation, its update magnitude should accordingly be increased."],"title":"Full Paper Content","type":"string"},"selected_content":{"description":"The specific text excerpt selected for improvement from the paper.","examples":["...when the scaled reward itself (or its growth trend) has not yet reached saturation, its update magnitude should accordingly be increased..."],"title":"Selected Content","type":"string"}},"required":["full_paper_content","selected_content"],"title":"improve_academic_passage_toolArguments","type":"object"},"outputSchema":{"properties":{"result":{"title":"Result","type":"string"}},"required":["result"],"title":"improve_academic_passage_toolOutput","type":"object"}}]` - - // var mockTools []ToolSchema - // err := json.Unmarshal([]byte(mockToolsJSON), &mockTools) } diff --git a/internal/services/toolkit/tools/xtramcp/tool.go b/internal/services/toolkit/tools/xtramcp/tool.go index ba8cde7d..12156837 100644 --- a/internal/services/toolkit/tools/xtramcp/tool.go +++ b/internal/services/toolkit/tools/xtramcp/tool.go @@ -115,9 +115,7 @@ func (t *DynamicTool) Call(ctx context.Context, toolCallId string, args json.Raw // executeTool makes the MCP request (generic for any tool) func (t *DynamicTool) executeTool(args map[string]interface{}) (string, error) { - // Use the stored session ID - no need to re-initialize! - fmt.Printf("Using existing sessionId for %s: %s\n", t.Name, t.sessionID) - + request := MCPRequest{ JSONRPC: "2.0", Method: "tools/call", @@ -157,7 +155,6 @@ func (t *DynamicTool) executeTool(args map[string]interface{}) (string, error) { if err != nil { return "", fmt.Errorf("failed to read response: %w", err) } - fmt.Printf("Response body for %s: %s\n", t.Name, string(body)) // Parse response (assuming stream format) lines := strings.Split(string(body), "\n")