Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions internal/services/toolkit/client/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,17 @@ func initializeToolkit(
// initialize MCP session first and log session ID
sessionID, err := xtraMCPLoader.InitializeMCP()
if err != nil {
logger.Errorf("[AI Client] Failed to initialize XtraMCP session: %v", err)
logger.Errorf("[XtraMCP Client] Failed to initialize XtraMCP session: %v", err)
// TODO: Fallback to static tools or exit?
} else {
logger.Info("[AI Client] XtraMCP session initialized", "sessionID", sessionID)
logger.Info("[XtraMCP Client] XtraMCP session initialized", "sessionID", sessionID)

// dynamically load all tools from XtraMCP backend
err = xtraMCPLoader.LoadToolsFromBackend(toolRegistry)
if err != nil {
logger.Errorf("[AI Client] Failed to load XtraMCP tools: %v", err)
logger.Errorf("[XtraMCP Client] Failed to load XtraMCP tools: %v", err)
} else {
logger.Info("[AI Client] Successfully loaded XtraMCP tools")
logger.Info("[XtraMCP Client] Successfully loaded XtraMCP tools")
}
}

Expand Down
13 changes: 5 additions & 8 deletions internal/services/toolkit/client/utils_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,23 +144,20 @@ func initializeToolkitV2(

logger.Info("[AI Client V2] Registered static LaTeX tools", "count", 0)

// Load tools dynamically from backend
// // Load tools dynamically from backend
// xtraMCPLoader := xtramcp.NewXtraMCPLoaderV2(db, projectService, cfg.XtraMCPURI)

// initialize MCP session first and log session ID
// // initialize MCP session first and log session ID
// sessionID, err := xtraMCPLoader.InitializeMCP()
// if err != nil {
// logger.Errorf("[AI Client V2] Failed to initialize XtraMCP session: %v", err)
// // TODO: Fallback to static tools or exit?
// logger.Errorf("[XtraMCP Client] Failed to initialize XtraMCP session: %v", err)
// } else {
// logger.Info("[AI Client V2] XtraMCP session initialized", "sessionID", sessionID)
// logger.Info("[XtraMCP Client] XtraMCP session initialized", "sessionID", sessionID)

// // dynamically load all tools from XtraMCP backend
// err = xtraMCPLoader.LoadToolsFromBackend(toolRegistry)
// if err != nil {
// logger.Errorf("[AI Client V2] Failed to load XtraMCP tools: %v", err)
// } else {
// logger.Info("[AI Client V2] Successfully loaded XtraMCP tools")
// logger.Errorf("[XtraMCP Client] Failed to load XtraMCP tools: %v", err)
// }
// }

Expand Down
32 changes: 30 additions & 2 deletions internal/services/toolkit/tools/xtramcp/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,45 @@ func (loader *XtraMCPLoader) LoadToolsFromBackend(toolRegistry *registry.ToolReg

// Register each tool dynamically, passing the session ID
for _, toolSchema := range toolSchemas {
dynamicTool := NewDynamicTool(loader.db, loader.projectService, toolSchema, loader.baseURL, loader.sessionID)
// some tools require secrutiy context injection e.g. user_id to authenticate
requiresInjection := loader.requiresSecurityInjection(toolSchema)

dynamicTool := NewDynamicTool(
loader.db,
loader.projectService,
toolSchema,
loader.baseURL,
loader.sessionID,
requiresInjection,
)

// Register the tool with the registry
toolRegistry.Register(toolSchema.Name, dynamicTool.Description, dynamicTool.Call)

fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name)
if requiresInjection {
fmt.Printf("Registered dynamic tool with security injection: %s\n", toolSchema.Name)
} else {
fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name)
}
}

return nil
}

// checks if a tool schema contains parameters that should be inejected instead of LLM-generated
func (loader *XtraMCPLoader) requiresSecurityInjection(schema ToolSchema) bool {
properties, ok := schema.InputSchema["properties"].(map[string]interface{})
if !ok {
return false
}

// injected parameters
_, hasUserId := properties["user_id"]
_, hasProjectId := properties["project_id"]

return hasUserId || hasProjectId
}

// InitializeMCP performs the full MCP initialization handshake, stores session ID, and returns it
func (loader *XtraMCPLoader) InitializeMCP() (string, error) {
// Step 1: Initialize
Expand Down
32 changes: 30 additions & 2 deletions internal/services/toolkit/tools/xtramcp/loader_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,45 @@ func (loader *XtraMCPLoaderV2) LoadToolsFromBackend(toolRegistry *registry.ToolR

// Register each tool dynamically, passing the session ID
for _, toolSchema := range toolSchemas {
dynamicTool := NewDynamicToolV2(loader.db, loader.projectService, toolSchema, loader.baseURL, loader.sessionID)
// some tools require security context injection e.g. user_id to authenticate
requiresInjection := loader.requiresSecurityInjection(toolSchema)

dynamicTool := NewDynamicToolV2(
loader.db,
loader.projectService,
toolSchema,
loader.baseURL,
loader.sessionID,
requiresInjection,
)

// Register the tool with the registry
toolRegistry.Register(toolSchema.Name, dynamicTool.Description, dynamicTool.Call)

fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name)
if requiresInjection {
fmt.Printf("Registered dynamic tool with security injection: %s\n", toolSchema.Name)
} else {
fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name)
}
}

return nil
}

// checks if a tool schema contains parameters that should be injected instead of LLM-generated
func (loader *XtraMCPLoaderV2) requiresSecurityInjection(schema ToolSchemaV2) bool {
properties, ok := schema.InputSchema["properties"].(map[string]interface{})
if !ok {
return false
}

// injected parameters
_, hasUserId := properties["user_id"]
_, hasProjectId := properties["project_id"]

return hasUserId || hasProjectId
}

// InitializeMCP performs the full MCP initialization handshake, stores session ID, and returns it
func (loader *XtraMCPLoaderV2) InitializeMCP() (string, error) {
// Step 1: Initialize
Expand Down
68 changes: 68 additions & 0 deletions internal/services/toolkit/tools/xtramcp/schema_filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package xtramcp

import "encoding/json"

// parameters that should be injected server-side
var securityParameters = []string{"user_id", "project_id"}

// removes security parameters from schema shown to LLM so LLM does not need to generate / fill
func filterSecurityParameters(schema map[string]interface{}) map[string]interface{} {
filtered := deepCopySchema(schema)

// Remove from properties
if properties, ok := filtered["properties"].(map[string]interface{}); ok {
for _, param := range securityParameters {
delete(properties, param)
}
}

// Remove from required array
if required, ok := filtered["required"].([]interface{}); ok {
filtered["required"] = filterRequiredArray(required, securityParameters)
}

return filtered
}

// creates a deep copy of the schema using JSON marshal/unmarshal
// Uses JSON round-trip because map[string]interface{} contains nested structures
// This ensures modifications to the copy don't affect the original schema.
func deepCopySchema(schema map[string]interface{}) map[string]interface{} {
// Use JSON marshal/unmarshal for deep copy
jsonBytes, err := json.Marshal(schema)
if err != nil {
// Extremely unlikely with valid JSON schemas (MCP schemas are JSON-compatible)
// // If marshaling fails, return original schema
return schema
}

var copy map[string]interface{}
err = json.Unmarshal(jsonBytes, &copy)
if err != nil {
// Should never happen if marshal succeeded
return schema
}

return copy
}

// removes security parameters from the required array
func filterRequiredArray(required []interface{}, toRemove []string) []interface{} {
filtered := []interface{}{}
removeMap := make(map[string]bool)

for _, r := range toRemove {
removeMap[r] = true
}

// filter out security params
for _, item := range required {
if str, ok := item.(string); ok {
if !removeMap[str] {
filtered = append(filtered, item)
}
}
}

return filtered
}
82 changes: 31 additions & 51 deletions internal/services/toolkit/tools/xtramcp/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,74 +41,54 @@ type MCPParams struct {

// DynamicTool represents a generic tool that can handle any schema
type DynamicTool struct {
Name string
Description responses.ToolUnionParam
toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB
projectService *services.ProjectService
coolDownTime time.Duration
baseURL string
client *http.Client
schema map[string]interface{}
sessionID string // Reuse the session ID from initialization
Name string
Description responses.ToolUnionParam
toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB
projectService *services.ProjectService
coolDownTime time.Duration
baseURL string
client *http.Client
schema map[string]interface{}
sessionID string // Reuse the session ID from initialization
requiresInjection bool // Indicates if this tool needs user/project injection
}

// NewDynamicTool creates a new dynamic tool from a schema
func NewDynamicTool(db *db.DB, projectService *services.ProjectService, toolSchema ToolSchema, baseURL string, sessionID string) *DynamicTool {
// Create tool description with the schema
func NewDynamicTool(db *db.DB, projectService *services.ProjectService, toolSchema ToolSchema, baseURL string, sessionID string, requiresInjection bool) *DynamicTool {
// filter schema if injection is required (hide security context like user_id/project_id from LLM)
schemaForLLM := toolSchema.InputSchema
if requiresInjection {
schemaForLLM = filterSecurityParameters(toolSchema.InputSchema)
}

description := responses.ToolUnionParam{
OfFunction: &responses.FunctionToolParam{
Name: toolSchema.Name,
Description: param.NewOpt(toolSchema.Description),
Parameters: openai.FunctionParameters(toolSchema.InputSchema),
Parameters: openai.FunctionParameters(schemaForLLM), // Use filtered schema
},
}

toolCallRecordDB := toolCallRecordDB.NewToolCallRecordDB(db)
//TODO: consider letting llm client know of output schema too
return &DynamicTool{
Name: toolSchema.Name,
Description: description,
toolCallRecordDB: toolCallRecordDB,
projectService: projectService,
coolDownTime: 5 * time.Minute,
baseURL: baseURL,
client: &http.Client{},
schema: toolSchema.InputSchema,
sessionID: sessionID, // Store the session ID for reuse
Name: toolSchema.Name,
Description: description,
toolCallRecordDB: toolCallRecordDB,
projectService: projectService,
coolDownTime: 5 * time.Minute,
baseURL: baseURL,
client: &http.Client{},
schema: toolSchema.InputSchema, // Store original schema for validation
sessionID: sessionID, // Store the session ID for reuse
requiresInjection: requiresInjection,
}
}

// Call handles the tool execution (generic for any tool)
// DEPRECATED: v1 API is no longer supported. This method should not be called.
func (t *DynamicTool) Call(ctx context.Context, toolCallId string, args json.RawMessage) (string, string, error) {
// Parse arguments as generic map since we don't know the structure
var argsMap map[string]interface{}
err := json.Unmarshal(args, &argsMap)
if err != nil {
return "", "", err
}

// Create function call record
record, err := t.toolCallRecordDB.Create(ctx, toolCallId, t.Name, argsMap)
if err != nil {
return "", "", err
}

// Execute the tool via MCP
respStr, err := t.executeTool(argsMap)
if err != nil {
err = fmt.Errorf("failed to execute tool %s: %v", t.Name, err)
t.toolCallRecordDB.OnError(ctx, record, err)
return "", "", err
}

rawJson, err := json.Marshal(respStr)
if err != nil {
err = fmt.Errorf("failed to marshal tool result: %v", err)
t.toolCallRecordDB.OnError(ctx, record, err)
return "", "", err
}
t.toolCallRecordDB.OnSuccess(ctx, record, string(rawJson))

return respStr, "", nil
return "", "", fmt.Errorf("v1 API is deprecated and no longer supported. Please use v2 API instead")
}

// executeTool makes the MCP request (generic for any tool)
Expand Down
Loading