From 00fef77fd9b307a76d1c06fafe9588b12fea0334 Mon Sep 17 00:00:00 2001 From: Simon Zhu Date: Thu, 5 Mar 2026 06:57:51 -0500 Subject: [PATCH] feat: add MCP tool server discovery and invocation endpoint Adds three new MCP tools to the kagent-controller's existing MCP endpoint at :8083/mcp, enabling dynamic discovery and invocation of tools across all tool source types: - list_tool_servers: Lists all tool servers (RemoteMCPServer, Service with kagent.dev/mcp-service=true label, MCPServer CRs) - list_tools: Connects to a tool server and returns its tool catalog - call_tool: Invokes a specific tool on a specific tool server This moves the functionality originally proposed in kagent-dev/kmcp#123 into kagent per reviewer feedback, since kagent already watches all three resource types and has the existing MCP handler infrastructure. Key design decisions: - Unified ref format: Kind/namespace/name (e.g. RemoteMCPServer/default/my-server) - Session caching with evict-and-retry for stale connections - Reuses existing ConvertServiceToRemoteMCPServer and ConvertMCPServerToRemoteMCPServer from the translator package - MCPServer CRD is optional (graceful degradation if not installed) Co-Authored-By: Claude Opus 4.6 Signed-off-by: Simon Zhu --- .../controller/reconciler/reconciler.go | 57 +-- go/core/internal/mcp/mcp_handler.go | 43 +- .../internal/mcp/mcp_tool_server_handler.go | 480 ++++++++++++++++++ .../mcp/mcp_tool_server_handler_test.go | 471 +++++++++++++++++ go/core/internal/mcp/transport.go | 79 +++ 5 files changed, 1072 insertions(+), 58 deletions(-) create mode 100644 go/core/internal/mcp/mcp_tool_server_handler.go create mode 100644 go/core/internal/mcp/mcp_tool_server_handler_test.go create mode 100644 go/core/internal/mcp/transport.go diff --git a/go/core/internal/controller/reconciler/reconciler.go b/go/core/internal/controller/reconciler/reconciler.go index 857f23e77..581f5f206 100644 --- a/go/core/internal/controller/reconciler/reconciler.go +++ b/go/core/internal/controller/reconciler/reconciler.go @@ -6,7 +6,6 @@ import ( "encoding/hex" "errors" "fmt" - "net/http" "reflect" "slices" "strings" @@ -26,6 +25,7 @@ import ( "github.com/kagent-dev/kagent/go/api/v1alpha2" "github.com/kagent-dev/kagent/go/core/internal/controller/provider" agent_translator "github.com/kagent-dev/kagent/go/core/internal/controller/translator/agent" + mcputil "github.com/kagent-dev/kagent/go/core/internal/mcp" "github.com/kagent-dev/kagent/go/core/internal/utils" "github.com/kagent-dev/kagent/go/core/internal/version" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -784,7 +784,7 @@ func (a *kagentReconciler) upsertToolServerForRemoteMCPServer(ctx context.Contex return nil, fmt.Errorf("failed to store toolServer %s: %w", toolServer.Name, err) } - tsp, err := a.createMcpTransport(ctx, remoteMcpServer) + tsp, err := mcputil.CreateMCPTransport(ctx, a.kube, remoteMcpServer) if err != nil { return nil, fmt.Errorf("failed to create client for toolServer %s: %w", toolServer.Name, err) } @@ -809,59 +809,6 @@ func (a *kagentReconciler) isNamespaceWatched(namespace string) bool { return slices.Contains(a.watchedNamespaces, namespace) } -func (a *kagentReconciler) createMcpTransport(ctx context.Context, s *v1alpha2.RemoteMCPServer) (mcp.Transport, error) { - headers, err := s.ResolveHeaders(ctx, a.kube) - if err != nil { - return nil, err - } - - httpClient := newHTTPClient(headers) - - switch s.Spec.Protocol { - case v1alpha2.RemoteMCPServerProtocolSse: - return &mcp.SSEClientTransport{ - Endpoint: s.Spec.URL, - HTTPClient: httpClient, - }, nil - default: - return &mcp.StreamableClientTransport{ - Endpoint: s.Spec.URL, - HTTPClient: httpClient, - }, nil - } -} - -// go-sdk does not have a WithHeaders option when initializing transport -// so we need to create a custom HTTP client that adds headers to all requests. -func newHTTPClient(headers map[string]string) *http.Client { - if len(headers) == 0 { - return http.DefaultClient - } - return &http.Client{ - Transport: &headerTransport{ - headers: headers, - base: http.DefaultTransport, - }, - } -} - -// headerTransport is an http.RoundTripper that adds custom headers to requests. -type headerTransport struct { - headers map[string]string - base http.RoundTripper -} - -func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req = req.Clone(req.Context()) - for k, v := range t.headers { - req.Header.Set(k, v) - } - if t.base == nil { - t.base = http.DefaultTransport - } - return t.base.RoundTrip(req) -} - func (a *kagentReconciler) listTools(ctx context.Context, tsp mcp.Transport, toolServer *database.ToolServer) ([]*v1alpha2.MCPTool, error) { impl := &mcp.Implementation{ Name: "kagent-controller", diff --git a/go/core/internal/mcp/mcp_handler.go b/go/core/internal/mcp/mcp_handler.go index c15798e57..77fd09f36 100644 --- a/go/core/internal/mcp/mcp_handler.go +++ b/go/core/internal/mcp/mcp_handler.go @@ -22,6 +22,7 @@ import ( ) // MCPHandler handles MCP requests and bridges them to A2A endpoints +// and tool server discovery/invocation. type MCPHandler struct { kubeClient client.Client a2aBaseURL string @@ -29,6 +30,7 @@ type MCPHandler struct { httpHandler *mcpsdk.StreamableHTTPHandler server *mcpsdk.Server a2aClients sync.Map + sessions sync.Map // cached MCP client sessions keyed by "Kind/namespace/name" } // Input types for MCP tools @@ -92,6 +94,36 @@ func NewMCPHandler(kubeClient client.Client, a2aBaseURL string, authenticator au handler.handleInvokeAgent, ) + // Add list_tool_servers tool + mcpsdk.AddTool[ListToolServersInput, ListToolServersOutput]( + server, + &mcpsdk.Tool{ + Name: "list_tool_servers", + Description: "List all MCP tool servers in the cluster (RemoteMCPServer, Service, MCPServer)", + }, + handler.handleListToolServers, + ) + + // Add list_tools tool + mcpsdk.AddTool[ListToolsInput, ListToolsOutput]( + server, + &mcpsdk.Tool{ + Name: "list_tools", + Description: "Connect to a tool server and list its available tools", + }, + handler.handleListTools, + ) + + // Add call_tool tool + mcpsdk.AddTool[CallToolInput, CallToolOutput]( + server, + &mcpsdk.Tool{ + Name: "call_tool", + Description: "Invoke a specific tool on a specific tool server", + }, + handler.handleCallTool, + ) + // Create HTTP handler handler.httpHandler = mcpsdk.NewStreamableHTTPHandler( func(*http.Request) *mcpsdk.Server { @@ -309,9 +341,14 @@ func (h *MCPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.httpHandler.ServeHTTP(w, r) } -// Shutdown gracefully shuts down the MCP handler +// Shutdown gracefully shuts down the MCP handler and closes cached sessions. func (h *MCPHandler) Shutdown(ctx context.Context) error { - // The new SDK doesn't have an explicit Shutdown method on StreamableHTTPHandler - // The server will be shut down when the context is cancelled + h.sessions.Range(func(key, value any) bool { + if session, ok := value.(*mcpsdk.ClientSession); ok { + session.Close() + } + h.sessions.Delete(key) + return true + }) return nil } diff --git a/go/core/internal/mcp/mcp_tool_server_handler.go b/go/core/internal/mcp/mcp_tool_server_handler.go new file mode 100644 index 000000000..ddf58afde --- /dev/null +++ b/go/core/internal/mcp/mcp_tool_server_handler.go @@ -0,0 +1,480 @@ +/* +Copyright 2025. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mcp + +import ( + "context" + "fmt" + "strings" + + "github.com/kagent-dev/kagent/go/api/v1alpha2" + agent_translator "github.com/kagent-dev/kagent/go/core/internal/controller/translator/agent" + "github.com/kagent-dev/kagent/go/core/internal/version" + "github.com/kagent-dev/kmcp/api/v1alpha1" + mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + ctrllog "sigs.k8s.io/controller-runtime/pkg/log" +) + +// --- Input/Output types for tool server tools --- + +type ListToolServersInput struct { + Namespace string `json:"namespace,omitempty" jsonschema:"Optional namespace filter"` +} + +type ListToolServersOutput struct { + Servers []ToolServerSummary `json:"servers"` +} + +type ToolServerSummary struct { + Ref string `json:"ref"` // "Kind/namespace/name" + Kind string `json:"kind"` // "RemoteMCPServer", "Service", "MCPServer" + URL string `json:"url"` // resolved endpoint URL + Protocol string `json:"protocol"` // "STREAMABLE_HTTP" or "SSE" + Status string `json:"status,omitempty"` // "Ready" / "NotReady" (MCPServer only) +} + +type ListToolsInput struct { + Server string `json:"server" jsonschema:"Tool server reference in Kind/namespace/name format (e.g. RemoteMCPServer/default/my-server)"` +} + +type ListToolsOutput struct { + Server string `json:"server"` + Tools []ToolInfo `json:"tools"` +} + +type ToolInfo struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema any `json:"inputSchema,omitempty"` +} + +type CallToolInput struct { + Server string `json:"server" jsonschema:"Tool server reference in Kind/namespace/name format"` + Tool string `json:"tool" jsonschema:"Tool name to invoke"` + Arguments map[string]any `json:"arguments,omitempty" jsonschema:"Tool arguments as JSON object"` +} + +type CallToolOutput struct { + Server string `json:"server"` + Tool string `json:"tool"` + Content any `json:"content"` + IsError bool `json:"isError"` +} + +// --- Ref parsing --- + +// parseServerRef parses a "Kind/namespace/name" reference into components. +func parseServerRef(ref string) (kind, namespace, name string, err error) { + parts := strings.SplitN(ref, "/", 3) + if len(parts) != 3 || parts[0] == "" || parts[1] == "" || parts[2] == "" { + return "", "", "", fmt.Errorf("invalid server reference %q: must be Kind/namespace/name (e.g. RemoteMCPServer/default/my-server)", ref) + } + kind, namespace, name = parts[0], parts[1], parts[2] + switch kind { + case "RemoteMCPServer", "Service", "MCPServer": + return kind, namespace, name, nil + default: + return "", "", "", fmt.Errorf("unknown server kind %q: must be RemoteMCPServer, Service, or MCPServer", kind) + } +} + +// resolveToRemoteMCPServer fetches the referenced resource and converts it to a RemoteMCPServer +// so all three source types converge to a single type for transport creation. +func (h *MCPHandler) resolveToRemoteMCPServer(ctx context.Context, kind, namespace, name string) (*v1alpha2.RemoteMCPServer, error) { + key := client.ObjectKey{Namespace: namespace, Name: name} + + switch kind { + case "RemoteMCPServer": + server := &v1alpha2.RemoteMCPServer{} + if err := h.kubeClient.Get(ctx, key, server); err != nil { + return nil, fmt.Errorf("remoteMCPServer %s/%s not found: %w", namespace, name, err) + } + return server, nil + + case "Service": + svc := &corev1.Service{} + if err := h.kubeClient.Get(ctx, key, svc); err != nil { + return nil, fmt.Errorf("service %s/%s not found: %w", namespace, name, err) + } + return agent_translator.ConvertServiceToRemoteMCPServer(svc) + + case "MCPServer": + mcpServer := &v1alpha1.MCPServer{} + if err := h.kubeClient.Get(ctx, key, mcpServer); err != nil { + return nil, fmt.Errorf("mcpServer %s/%s not found: %w", namespace, name, err) + } + return agent_translator.ConvertMCPServerToRemoteMCPServer(mcpServer) + + default: + return nil, fmt.Errorf("unsupported kind: %s", kind) + } +} + +// --- Session caching --- + +// getOrCreateSession returns a cached MCP client session or creates a new one. +func (h *MCPHandler) getOrCreateSession(ctx context.Context, ref string, server *v1alpha2.RemoteMCPServer) (*mcpsdk.ClientSession, error) { + if cached, ok := h.sessions.Load(ref); ok { + if session, ok := cached.(*mcpsdk.ClientSession); ok { + return session, nil + } + } + + transport, err := CreateMCPTransport(ctx, h.kubeClient, server) + if err != nil { + return nil, fmt.Errorf("failed to create transport for %s: %w", ref, err) + } + + impl := &mcpsdk.Implementation{ + Name: "kagent-controller", + Version: version.Version, + } + mcpClient := mcpsdk.NewClient(impl, nil) + + session, err := mcpClient.Connect(ctx, transport, nil) + if err != nil { + return nil, fmt.Errorf("failed to connect to %s at %s: %w", ref, server.Spec.URL, err) + } + + h.sessions.Store(ref, session) + return session, nil +} + +// evictSession closes and removes a cached session so the next call creates a fresh one. +func (h *MCPHandler) evictSession(ref string) { + if cached, ok := h.sessions.LoadAndDelete(ref); ok { + if session, ok := cached.(*mcpsdk.ClientSession); ok { + session.Close() + } + } +} + +// --- Tool handlers --- + +// handleListToolServers lists all MCP tool servers across RemoteMCPServer, Service, and MCPServer. +func (h *MCPHandler) handleListToolServers(ctx context.Context, req *mcpsdk.CallToolRequest, input ListToolServersInput) (*mcpsdk.CallToolResult, ListToolServersOutput, error) { + log := ctrllog.FromContext(ctx).WithName("mcp-handler").WithValues("tool", "list_tool_servers") + + listOpts := []client.ListOption{} + if input.Namespace != "" { + listOpts = append(listOpts, client.InNamespace(input.Namespace)) + } + + var servers []ToolServerSummary + + // 1. RemoteMCPServers + remoteMCPList := &v1alpha2.RemoteMCPServerList{} + if err := h.kubeClient.List(ctx, remoteMCPList, listOpts...); err != nil { + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: fmt.Sprintf("Failed to list RemoteMCPServers: %v", err)}, + }, + IsError: true, + }, ListToolServersOutput{}, nil + } + for _, r := range remoteMCPList.Items { + servers = append(servers, ToolServerSummary{ + Ref: fmt.Sprintf("RemoteMCPServer/%s/%s", r.Namespace, r.Name), + Kind: "RemoteMCPServer", + URL: r.Spec.URL, + Protocol: string(r.Spec.Protocol), + }) + } + + // 2. Services with kagent.dev/mcp-service=true label + svcListOpts := append([]client.ListOption{ + client.MatchingLabels{agent_translator.MCPServiceLabel: "true"}, + }, listOpts...) + svcList := &corev1.ServiceList{} + if err := h.kubeClient.List(ctx, svcList, svcListOpts...); err != nil { + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: fmt.Sprintf("Failed to list Services: %v", err)}, + }, + IsError: true, + }, ListToolServersOutput{}, nil + } + for _, svc := range svcList.Items { + remoteMCP, err := agent_translator.ConvertServiceToRemoteMCPServer(&svc) + if err != nil { + log.V(1).Info("Skipping Service with invalid MCP config", "service", svc.Name, "namespace", svc.Namespace, "error", err) + continue + } + servers = append(servers, ToolServerSummary{ + Ref: fmt.Sprintf("Service/%s/%s", svc.Namespace, svc.Name), + Kind: "Service", + URL: remoteMCP.Spec.URL, + Protocol: string(remoteMCP.Spec.Protocol), + }) + } + + // 3. MCPServers (optional — CRD may not be installed) + mcpServerList := &v1alpha1.MCPServerList{} + if err := h.kubeClient.List(ctx, mcpServerList, listOpts...); err != nil { + // If the CRD isn't installed, treat as empty list + if !meta.IsNoMatchError(err) { + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: fmt.Sprintf("Failed to list MCPServers: %v", err)}, + }, + IsError: true, + }, ListToolServersOutput{}, nil + } + } else { + for _, m := range mcpServerList.Items { + status := "NotReady" + for _, condition := range m.Status.Conditions { + if condition.Type == string(v1alpha1.MCPServerConditionReady) && condition.Status == metav1.ConditionTrue { + status = "Ready" + break + } + } + + remoteMCP, err := agent_translator.ConvertMCPServerToRemoteMCPServer(&m) + if err != nil { + log.V(1).Info("Skipping MCPServer with invalid config", "mcpserver", m.Name, "namespace", m.Namespace, "error", err) + continue + } + + servers = append(servers, ToolServerSummary{ + Ref: fmt.Sprintf("MCPServer/%s/%s", m.Namespace, m.Name), + Kind: "MCPServer", + URL: remoteMCP.Spec.URL, + Protocol: string(remoteMCP.Spec.Protocol), + Status: status, + }) + } + } + + log.Info("Listed tool servers", "count", len(servers)) + + output := ListToolServersOutput{Servers: servers} + + var fallbackText strings.Builder + if len(servers) == 0 { + fallbackText.WriteString("No tool servers found.") + } else { + for i, s := range servers { + if i > 0 { + fallbackText.WriteByte('\n') + } + fmt.Fprintf(&fallbackText, "%s url=%s protocol=%s", s.Ref, s.URL, s.Protocol) + if s.Status != "" { + fmt.Fprintf(&fallbackText, " status=%s", s.Status) + } + } + } + + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: fallbackText.String()}, + }, + }, output, nil +} + +// handleListTools connects to a tool server and returns its available tools. +func (h *MCPHandler) handleListTools(ctx context.Context, req *mcpsdk.CallToolRequest, input ListToolsInput) (*mcpsdk.CallToolResult, ListToolsOutput, error) { + log := ctrllog.FromContext(ctx).WithName("mcp-handler").WithValues("tool", "list_tools", "server", input.Server) + + kind, namespace, name, err := parseServerRef(input.Server) + if err != nil { + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: err.Error()}, + }, + IsError: true, + }, ListToolsOutput{}, nil + } + + remoteMCP, err := h.resolveToRemoteMCPServer(ctx, kind, namespace, name) + if err != nil { + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: err.Error()}, + }, + IsError: true, + }, ListToolsOutput{}, nil + } + + session, err := h.getOrCreateSession(ctx, input.Server, remoteMCP) + if err != nil { + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: fmt.Sprintf("Failed to connect to %s: %v", input.Server, err)}, + }, + IsError: true, + }, ListToolsOutput{}, nil + } + + result, err := session.ListTools(ctx, &mcpsdk.ListToolsParams{}) + if err != nil { + // Connection may be stale; evict and retry once + h.evictSession(input.Server) + session, err = h.getOrCreateSession(ctx, input.Server, remoteMCP) + if err != nil { + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: fmt.Sprintf("Failed to reconnect to %s: %v", input.Server, err)}, + }, + IsError: true, + }, ListToolsOutput{}, nil + } + result, err = session.ListTools(ctx, &mcpsdk.ListToolsParams{}) + if err != nil { + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: fmt.Sprintf("Failed to list tools on %s: %v", input.Server, err)}, + }, + IsError: true, + }, ListToolsOutput{}, nil + } + } + + tools := make([]ToolInfo, 0, len(result.Tools)) + for _, t := range result.Tools { + tools = append(tools, ToolInfo{ + Name: t.Name, + Description: t.Description, + InputSchema: t.InputSchema, + }) + } + + log.Info("Listed tools", "server", input.Server, "count", len(tools)) + + output := ListToolsOutput{ + Server: input.Server, + Tools: tools, + } + + var fallbackText strings.Builder + if len(tools) == 0 { + fmt.Fprintf(&fallbackText, "No tools found on %s.", input.Server) + } else { + fmt.Fprintf(&fallbackText, "Tools on %s:\n", input.Server) + for _, t := range tools { + fmt.Fprintf(&fallbackText, "- %s: %s\n", t.Name, t.Description) + } + } + + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: fallbackText.String()}, + }, + }, output, nil +} + +// handleCallTool invokes a specific tool on a specific tool server. +func (h *MCPHandler) handleCallTool(ctx context.Context, req *mcpsdk.CallToolRequest, input CallToolInput) (*mcpsdk.CallToolResult, CallToolOutput, error) { + log := ctrllog.FromContext(ctx).WithName("mcp-handler").WithValues("tool", "call_tool", "server", input.Server, "targetTool", input.Tool) + + if input.Tool == "" { + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: "tool name is required"}, + }, + IsError: true, + }, CallToolOutput{}, nil + } + + kind, namespace, name, err := parseServerRef(input.Server) + if err != nil { + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: err.Error()}, + }, + IsError: true, + }, CallToolOutput{}, nil + } + + remoteMCP, err := h.resolveToRemoteMCPServer(ctx, kind, namespace, name) + if err != nil { + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: err.Error()}, + }, + IsError: true, + }, CallToolOutput{}, nil + } + + session, err := h.getOrCreateSession(ctx, input.Server, remoteMCP) + if err != nil { + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: fmt.Sprintf("Failed to connect to %s: %v", input.Server, err)}, + }, + IsError: true, + }, CallToolOutput{}, nil + } + + result, err := session.CallTool(ctx, &mcpsdk.CallToolParams{ + Name: input.Tool, + Arguments: input.Arguments, + }) + if err != nil { + // Connection may be stale; evict and retry once + h.evictSession(input.Server) + session, err = h.getOrCreateSession(ctx, input.Server, remoteMCP) + if err != nil { + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: fmt.Sprintf("Failed to reconnect to %s: %v", input.Server, err)}, + }, + IsError: true, + }, CallToolOutput{}, nil + } + result, err = session.CallTool(ctx, &mcpsdk.CallToolParams{ + Name: input.Tool, + Arguments: input.Arguments, + }) + if err != nil { + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: fmt.Sprintf("Failed to call tool %s on %s: %v", input.Tool, input.Server, err)}, + }, + IsError: true, + }, CallToolOutput{}, nil + } + } + + log.Info("Called tool", "server", input.Server, "targetTool", input.Tool, "isError", result.IsError) + + // Extract text content for fallback + var fallbackText strings.Builder + for _, content := range result.Content { + if textContent, ok := content.(*mcpsdk.TextContent); ok { + fallbackText.WriteString(textContent.Text) + } + } + + output := CallToolOutput{ + Server: input.Server, + Tool: input.Tool, + Content: result.StructuredContent, + IsError: result.IsError, + } + + // If no structured content, use the text content + if output.Content == nil { + output.Content = fallbackText.String() + } + + return result, output, nil +} diff --git a/go/core/internal/mcp/mcp_tool_server_handler_test.go b/go/core/internal/mcp/mcp_tool_server_handler_test.go new file mode 100644 index 000000000..7afe5beb9 --- /dev/null +++ b/go/core/internal/mcp/mcp_tool_server_handler_test.go @@ -0,0 +1,471 @@ +/* +Copyright 2025. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mcp + +import ( + "context" + "testing" + + "github.com/kagent-dev/kagent/go/api/v1alpha2" + "github.com/kagent-dev/kmcp/api/v1alpha1" + mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +func setupTestScheme() *runtime.Scheme { + s := runtime.NewScheme() + utilruntime.Must(clientgoscheme.AddToScheme(s)) + utilruntime.Must(v1alpha2.AddToScheme(s)) + utilruntime.Must(v1alpha1.AddToScheme(s)) + return s +} + +func setupToolServerTestHandler(t *testing.T, objects ...client.Object) *MCPHandler { + t.Helper() + kubeClient := fake.NewClientBuilder(). + WithScheme(setupTestScheme()). + WithObjects(objects...). + WithStatusSubresource(&v1alpha2.RemoteMCPServer{}, &v1alpha1.MCPServer{}). + Build() + + return &MCPHandler{ + kubeClient: kubeClient, + } +} + +// --- parseServerRef tests --- + +func TestParseServerRef(t *testing.T) { + tests := []struct { + name string + ref string + wantKind string + wantNS string + wantName string + wantError string + }{ + { + name: "valid RemoteMCPServer ref", + ref: "RemoteMCPServer/default/my-server", + wantKind: "RemoteMCPServer", + wantNS: "default", + wantName: "my-server", + }, + { + name: "valid Service ref", + ref: "Service/tools/prometheus-mcp", + wantKind: "Service", + wantNS: "tools", + wantName: "prometheus-mcp", + }, + { + name: "valid MCPServer ref", + ref: "MCPServer/default/weather", + wantKind: "MCPServer", + wantNS: "default", + wantName: "weather", + }, + { + name: "missing kind - two parts only", + ref: "default/my-server", + wantError: "invalid server reference", + }, + { + name: "no slashes", + ref: "just-a-name", + wantError: "invalid server reference", + }, + { + name: "empty parts", + ref: "RemoteMCPServer//", + wantError: "invalid server reference", + }, + { + name: "unknown kind", + ref: "Deployment/default/my-deploy", + wantError: "unknown server kind", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + kind, ns, name, err := parseServerRef(tt.ref) + if tt.wantError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantKind, kind) + assert.Equal(t, tt.wantNS, ns) + assert.Equal(t, tt.wantName, name) + }) + } +} + +// --- resolveToRemoteMCPServer tests --- + +func TestResolveToRemoteMCPServer(t *testing.T) { + remoteMCP := &v1alpha2.RemoteMCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: "my-remote", + Namespace: "default", + }, + Spec: v1alpha2.RemoteMCPServerSpec{ + URL: "http://my-remote.default:8080/mcp", + Protocol: v1alpha2.RemoteMCPServerProtocolStreamableHttp, + }, + } + + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "my-svc", + Namespace: "tools", + Labels: map[string]string{"kagent.dev/mcp-service": "true"}, + }, + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + {Port: 9090}, + }, + }, + } + + mcpServer := &v1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: "weather", + Namespace: "default", + }, + Spec: v1alpha1.MCPServerSpec{ + TransportType: v1alpha1.TransportTypeStdio, + Deployment: v1alpha1.MCPServerDeployment{ + Port: 3000, + }, + }, + } + + handler := setupToolServerTestHandler(t, remoteMCP, svc, mcpServer) + ctx := context.Background() + + t.Run("resolves RemoteMCPServer", func(t *testing.T) { + result, err := handler.resolveToRemoteMCPServer(ctx, "RemoteMCPServer", "default", "my-remote") + require.NoError(t, err) + assert.Equal(t, "http://my-remote.default:8080/mcp", result.Spec.URL) + }) + + t.Run("resolves Service", func(t *testing.T) { + result, err := handler.resolveToRemoteMCPServer(ctx, "Service", "tools", "my-svc") + require.NoError(t, err) + assert.Contains(t, result.Spec.URL, "my-svc.tools") + assert.Contains(t, result.Spec.URL, "9090") + }) + + t.Run("resolves MCPServer", func(t *testing.T) { + result, err := handler.resolveToRemoteMCPServer(ctx, "MCPServer", "default", "weather") + require.NoError(t, err) + assert.Contains(t, result.Spec.URL, "weather.default") + assert.Contains(t, result.Spec.URL, "3000") + }) + + t.Run("RemoteMCPServer not found", func(t *testing.T) { + _, err := handler.resolveToRemoteMCPServer(ctx, "RemoteMCPServer", "default", "nonexistent") + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("Service not found", func(t *testing.T) { + _, err := handler.resolveToRemoteMCPServer(ctx, "Service", "default", "nonexistent") + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("MCPServer not found", func(t *testing.T) { + _, err := handler.resolveToRemoteMCPServer(ctx, "MCPServer", "default", "nonexistent") + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) +} + +// --- handleListToolServers tests --- + +func TestHandleListToolServers(t *testing.T) { + tests := []struct { + name string + objects []client.Object + input ListToolServersInput + expectedCount int + checkFunc func(t *testing.T, output ListToolServersOutput) + }{ + { + name: "empty cluster returns empty list", + objects: nil, + input: ListToolServersInput{}, + expectedCount: 0, + }, + { + name: "returns RemoteMCPServers", + objects: []client.Object{ + &v1alpha2.RemoteMCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: "remote-1", Namespace: "default"}, + Spec: v1alpha2.RemoteMCPServerSpec{ + URL: "http://remote-1.default:8080/mcp", + Protocol: v1alpha2.RemoteMCPServerProtocolStreamableHttp, + }, + }, + }, + input: ListToolServersInput{}, + expectedCount: 1, + checkFunc: func(t *testing.T, output ListToolServersOutput) { + assert.Equal(t, "RemoteMCPServer/default/remote-1", output.Servers[0].Ref) + assert.Equal(t, "RemoteMCPServer", output.Servers[0].Kind) + assert.Equal(t, "http://remote-1.default:8080/mcp", output.Servers[0].URL) + }, + }, + { + name: "returns Services with MCP label", + objects: []client.Object{ + &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "mcp-svc", + Namespace: "tools", + Labels: map[string]string{"kagent.dev/mcp-service": "true"}, + }, + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{{Port: 9090}}, + }, + }, + // Service without MCP label should be excluded + &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "regular-svc", + Namespace: "tools", + }, + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{{Port: 80}}, + }, + }, + }, + input: ListToolServersInput{}, + expectedCount: 1, + checkFunc: func(t *testing.T, output ListToolServersOutput) { + assert.Equal(t, "Service/tools/mcp-svc", output.Servers[0].Ref) + assert.Equal(t, "Service", output.Servers[0].Kind) + assert.Contains(t, output.Servers[0].URL, "9090") + }, + }, + { + name: "returns MCPServers with status", + objects: []client.Object{ + &v1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: "weather", Namespace: "default"}, + Spec: v1alpha1.MCPServerSpec{ + TransportType: v1alpha1.TransportTypeStdio, + Deployment: v1alpha1.MCPServerDeployment{Port: 3000}, + }, + Status: v1alpha1.MCPServerStatus{ + Conditions: []metav1.Condition{ + { + Type: string(v1alpha1.MCPServerConditionReady), + Status: metav1.ConditionTrue, + }, + }, + }, + }, + }, + input: ListToolServersInput{}, + expectedCount: 1, + checkFunc: func(t *testing.T, output ListToolServersOutput) { + assert.Equal(t, "MCPServer/default/weather", output.Servers[0].Ref) + assert.Equal(t, "MCPServer", output.Servers[0].Kind) + assert.Equal(t, "Ready", output.Servers[0].Status) + }, + }, + { + name: "returns all types combined", + objects: []client.Object{ + &v1alpha2.RemoteMCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: "remote-1", Namespace: "default"}, + Spec: v1alpha2.RemoteMCPServerSpec{ + URL: "http://remote-1.default:8080/mcp", + Protocol: v1alpha2.RemoteMCPServerProtocolStreamableHttp, + }, + }, + &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "mcp-svc", + Namespace: "tools", + Labels: map[string]string{"kagent.dev/mcp-service": "true"}, + }, + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{{Port: 9090}}, + }, + }, + &v1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: "weather", Namespace: "default"}, + Spec: v1alpha1.MCPServerSpec{ + TransportType: v1alpha1.TransportTypeStdio, + Deployment: v1alpha1.MCPServerDeployment{Port: 3000}, + }, + }, + }, + input: ListToolServersInput{}, + expectedCount: 3, + checkFunc: func(t *testing.T, output ListToolServersOutput) { + kinds := make(map[string]bool) + for _, s := range output.Servers { + kinds[s.Kind] = true + } + assert.True(t, kinds["RemoteMCPServer"]) + assert.True(t, kinds["Service"]) + assert.True(t, kinds["MCPServer"]) + }, + }, + { + name: "filters by namespace", + objects: []client.Object{ + &v1alpha2.RemoteMCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: "remote-1", Namespace: "default"}, + Spec: v1alpha2.RemoteMCPServerSpec{ + URL: "http://remote-1.default:8080/mcp", + Protocol: v1alpha2.RemoteMCPServerProtocolStreamableHttp, + }, + }, + &v1alpha2.RemoteMCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: "remote-2", Namespace: "tools"}, + Spec: v1alpha2.RemoteMCPServerSpec{ + URL: "http://remote-2.tools:8080/mcp", + Protocol: v1alpha2.RemoteMCPServerProtocolStreamableHttp, + }, + }, + }, + input: ListToolServersInput{Namespace: "tools"}, + expectedCount: 1, + checkFunc: func(t *testing.T, output ListToolServersOutput) { + assert.Equal(t, "RemoteMCPServer/tools/remote-2", output.Servers[0].Ref) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := setupToolServerTestHandler(t, tt.objects...) + ctx := context.Background() + + result, output, err := handler.handleListToolServers(ctx, &mcpsdk.CallToolRequest{}, tt.input) + require.NoError(t, err) + assert.False(t, result.IsError) + assert.Len(t, output.Servers, tt.expectedCount) + + if tt.checkFunc != nil { + tt.checkFunc(t, output) + } + }) + } +} + +// --- handleListTools validation tests --- + +func TestHandleListToolsValidation(t *testing.T) { + tests := []struct { + name string + input ListToolsInput + wantError string + }{ + { + name: "invalid ref format - two parts only", + input: ListToolsInput{Server: "default/my-server"}, + wantError: "invalid server reference", + }, + { + name: "invalid ref format - no slashes", + input: ListToolsInput{Server: "just-a-name"}, + wantError: "invalid server reference", + }, + { + name: "server not found", + input: ListToolsInput{Server: "RemoteMCPServer/default/nonexistent"}, + wantError: "not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := setupToolServerTestHandler(t) + ctx := context.Background() + + result, _, err := handler.handleListTools(ctx, &mcpsdk.CallToolRequest{}, tt.input) + require.NoError(t, err) // protocol-level error should not occur + assert.True(t, result.IsError) + + for _, content := range result.Content { + if textContent, ok := content.(*mcpsdk.TextContent); ok { + assert.Contains(t, textContent.Text, tt.wantError) + } + } + }) + } +} + +// --- handleCallTool validation tests --- + +func TestHandleCallToolValidation(t *testing.T) { + tests := []struct { + name string + input CallToolInput + wantError string + }{ + { + name: "missing tool name", + input: CallToolInput{Server: "RemoteMCPServer/default/test", Tool: ""}, + wantError: "tool name is required", + }, + { + name: "invalid ref format", + input: CallToolInput{Server: "bad-ref", Tool: "some_tool"}, + wantError: "invalid server reference", + }, + { + name: "server not found", + input: CallToolInput{Server: "RemoteMCPServer/default/nonexistent", Tool: "some_tool"}, + wantError: "not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := setupToolServerTestHandler(t) + ctx := context.Background() + + result, _, err := handler.handleCallTool(ctx, &mcpsdk.CallToolRequest{}, tt.input) + require.NoError(t, err) // protocol-level error should not occur + assert.True(t, result.IsError) + + for _, content := range result.Content { + if textContent, ok := content.(*mcpsdk.TextContent); ok { + assert.Contains(t, textContent.Text, tt.wantError) + } + } + }) + } +} diff --git a/go/core/internal/mcp/transport.go b/go/core/internal/mcp/transport.go new file mode 100644 index 000000000..5a21a0ae8 --- /dev/null +++ b/go/core/internal/mcp/transport.go @@ -0,0 +1,79 @@ +/* +Copyright 2025. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mcp + +import ( + "context" + "net/http" + + "github.com/kagent-dev/kagent/go/api/v1alpha2" + mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +// CreateMCPTransport creates an MCP transport from a RemoteMCPServer spec, +// resolving any headersFrom references using the kube client. +func CreateMCPTransport(ctx context.Context, kubeClient client.Client, server *v1alpha2.RemoteMCPServer) (mcpsdk.Transport, error) { + headers, err := server.ResolveHeaders(ctx, kubeClient) + if err != nil { + return nil, err + } + + httpClient := newHTTPClient(headers) + + switch server.Spec.Protocol { + case v1alpha2.RemoteMCPServerProtocolSse: + return &mcpsdk.SSEClientTransport{ + Endpoint: server.Spec.URL, + HTTPClient: httpClient, + }, nil + default: + return &mcpsdk.StreamableClientTransport{ + Endpoint: server.Spec.URL, + HTTPClient: httpClient, + }, nil + } +} + +// headerTransport is an http.RoundTripper that adds custom headers to requests. +type headerTransport struct { + headers map[string]string + base http.RoundTripper +} + +func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + for k, v := range t.headers { + req.Header.Set(k, v) + } + if t.base == nil { + t.base = http.DefaultTransport + } + return t.base.RoundTrip(req) +} + +func newHTTPClient(headers map[string]string) *http.Client { + if len(headers) == 0 { + return http.DefaultClient + } + return &http.Client{ + Transport: &headerTransport{ + headers: headers, + base: http.DefaultTransport, + }, + } +}