From 3f3dee4c49b84952412fa08256beea77b76386f6 Mon Sep 17 00:00:00 2001 From: Marc Nuri Date: Thu, 19 Feb 2026 19:48:08 +0100 Subject: [PATCH 1/3] feat(logging): centralize K8s error logging in MCP tool handler (#792) * feat(logging): centralize K8s error logging in MCP tool handler Move HandleK8sError calls from individual tool handlers to the central ServerToolToGoSdkTool dispatch point in gosdk.go. This eliminates duplicated logging across ~20 call sites and ensures consistent error categorization for all tools automatically. Add errors.As guard in HandleK8sError fallback to only log actual K8s API errors, silently ignoring non-K8s errors (e.g. access control denials). Extract README logging section into docs/logging.md. Add unit tests for HandleK8sError covering all K8s error types, wrapped error chain traversal, and non-K8s error filtering. Add integration tests verifying non-K8s errors produce no log notifications (pods, helm) and K8s forbidden errors from Helm produce correct notifications. Add RequireNoLogNotification test helper. Signed-off-by: Marc Nuri * test(logging): assert on log level and message in K8s error tests Extract classifyK8sError to make error-to-log mapping directly testable. Unit tests now verify the correct Level and message for each K8s error type instead of only asserting NotPanics. 
Signed-off-by: Marc Nuri --------- Signed-off-by: Marc Nuri --- AGENTS.md | 1 + README.md | 66 +-------- docs/README.md | 1 + docs/logging.md | 83 +++++++++++ internal/test/mcp.go | 27 ++++ pkg/mcp/gosdk.go | 4 + pkg/mcp/helm_test.go | 27 ++++ pkg/mcp/pods_test.go | 8 + pkg/mcplog/k8s.go | 46 ++++-- pkg/mcplog/k8s_test.go | 181 +++++++++++++++++++++++ pkg/toolsets/core/error_handling_test.go | 164 -------------------- pkg/toolsets/core/events.go | 2 - pkg/toolsets/core/namespaces.go | 3 - pkg/toolsets/core/nodes.go | 5 - pkg/toolsets/core/pods.go | 5 - pkg/toolsets/core/resources.go | 6 - pkg/toolsets/helm/helm.go | 4 - 17 files changed, 365 insertions(+), 268 deletions(-) create mode 100644 docs/logging.md create mode 100644 pkg/mcplog/k8s_test.go delete mode 100644 pkg/toolsets/core/error_handling_test.go diff --git a/AGENTS.md b/AGENTS.md index 5012e84d0..45eb4f81e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -304,6 +304,7 @@ The `docs/` directory contains user-facing documentation: - `docs/README.md` – Documentation index and navigation - `docs/configuration.md` – **Complete TOML configuration reference** (all `StaticConfig` options, drop-in configuration, dynamic reload) - `docs/prompts.md` – MCP Prompts configuration guide +- `docs/logging.md` – MCP Logging guide (automatic K8s error logging, secret redaction) - `docs/OTEL.md` – OpenTelemetry observability setup - `docs/KIALI.md` – Kiali toolset configuration - `docs/GETTING_STARTED_KUBERNETES.md` – Kubernetes ServiceAccount setup diff --git a/README.md b/README.md index 1eb1f284f..cb4a64630 100644 --- a/README.md +++ b/README.md @@ -232,70 +232,10 @@ See the **[Configuration Reference](docs/configuration.md)**. ## 📊 MCP Logging The server supports the MCP logging capability, allowing clients to receive debugging information via structured log messages. +Kubernetes API errors are automatically categorized and logged to clients with appropriate severity levels. 
+Sensitive data (tokens, keys, passwords, cloud credentials) is automatically redacted before being sent to clients. -### For Clients - -Clients can control log verbosity by sending a `logging/setLevel` request: - -```json -{ - "method": "logging/setLevel", - "params": { "level": "info" } -} -``` - -**Available log levels** (in order of increasing severity): -- `debug` - Detailed debugging information -- `info` - General informational messages (default) -- `notice` - Normal but significant events -- `warning` - Warning messages -- `error` - Error conditions -- `critical` - Critical conditions -- `alert` - Action must be taken immediately -- `emergency` - System is unusable - -### For Developers - -Toolsets can optionally send debug information to clients using helper functions from the `mcplog` package: - -**Recommended approach for Kubernetes errors** (automatically categorizes errors and sends appropriate messages): - -```go -import "github.com/containers/kubernetes-mcp-server/pkg/mcplog" - -// In your tool handler: -ret, err := client.CoreV1().Pods(namespace).Get(ctx, name, metav1.GetOptions{}) -if err != nil { - mcplog.HandleK8sError(ctx, err, "pod access") - return api.NewToolCallResult("", fmt.Errorf("failed to get pod: %v", err)), nil -} -``` - -**Manual logging** (for custom messages): - -```go -import "github.com/containers/kubernetes-mcp-server/pkg/mcplog" - -// In your tool handler: -if err != nil { - mcplog.SendMCPLog(ctx, "error", "Operation failed - check permissions") - return api.NewToolCallResult("", err) -} -``` - -**Key Points:** -- Logging is **optional** - toolsets work fine without sending MCP logs -- Uses a dedicated named logger (`logger="mcp"`) for complete separation from server logs -- Server logs (klog) remain detailed and unaffected -- Client logs are high-level, helpful hints for debugging -- Authentication failures send generic messages to clients (no security info leaked) -- Sensitive data is automatically redacted with 28 pattern 
types: - - Generic fields (password, token, secret, api_key, etc.) - - Authorization headers (Bearer, Basic) - - Cloud credentials (AWS, GCP, Azure) - - API tokens (GitHub, GitLab, OpenAI, Anthropic) - - Cryptographic keys (JWT, SSH, PGP, RSA) - - Database connection strings (PostgreSQL, MySQL, MongoDB) +See the **[MCP Logging Guide](docs/logging.md)**. ## 🛠️ Tools and Functionalities diff --git a/docs/README.md b/docs/README.md index 40fb93a23..9715e69b5 100644 --- a/docs/README.md +++ b/docs/README.md @@ -26,6 +26,7 @@ Choose the guide that matches your needs: ## Advanced Topics +- **[MCP Logging](logging.md)** - Structured logging to MCP clients with automatic K8s error categorization and secret redaction - **[OpenTelemetry Observability](OTEL.md)** - Distributed tracing and metrics configuration - **[MCP Prompts](prompts.md)** - Custom workflow templates for AI assistants - **[Keycloak OIDC Setup](KEYCLOAK_OIDC_SETUP.md)** - Developer guide for local Keycloak environment and testing with MCP Inspector diff --git a/docs/logging.md b/docs/logging.md new file mode 100644 index 000000000..edb9d7555 --- /dev/null +++ b/docs/logging.md @@ -0,0 +1,83 @@ +# MCP Logging + +The server supports the MCP logging capability, allowing clients to receive debugging information via structured log messages. 
+ +## For Clients + +Clients can control log verbosity by sending a `logging/setLevel` request: + +```json +{ + "method": "logging/setLevel", + "params": { "level": "info" } +} +``` + +**Available log levels** (in order of increasing severity): +- `debug` - Detailed debugging information +- `info` - General informational messages (default) +- `notice` - Normal but significant events +- `warning` - Warning messages +- `error` - Error conditions +- `critical` - Critical conditions +- `alert` - Action must be taken immediately +- `emergency` - System is unusable + +## For Developers + +### Automatic Kubernetes Error Logging + +Kubernetes API errors returned by tool handlers are **automatically logged** to MCP clients. +When a tool handler returns a `ToolCallResult` with a non-nil error that is a Kubernetes API error (`StatusError`), the server categorizes it and sends an appropriate log message. + +This means toolset authors **do not need to call any logging functions** for standard K8s error handling. 
+Simply return the error in the `ToolCallResult` and the server handles the rest: + +```go +ret, err := client.CoreV1().Pods(namespace).Get(ctx, name, metav1.GetOptions{}) +if err != nil { + return api.NewToolCallResult("", fmt.Errorf("failed to get pod: %w", err)), nil +} +``` + +The following Kubernetes error types are automatically categorized: + +| Error Type | Log Level | Message | +|-----------|-----------|---------| +| Not Found | `info` | Resource not found - it may not exist or may have been deleted | +| Forbidden | `error` | Permission denied - check RBAC permissions for {tool} | +| Unauthorized | `error` | Authentication failed - check cluster credentials | +| Already Exists | `warning` | Resource already exists | +| Invalid | `error` | Invalid resource specification - check resource definition | +| Bad Request | `error` | Invalid request - check parameters | +| Conflict | `error` | Resource conflict - resource may have been modified | +| Timeout | `error` | Request timeout - cluster may be slow or overloaded | +| Server Timeout | `error` | Server timeout - cluster may be slow or overloaded | +| Service Unavailable | `error` | Service unavailable - cluster may be unreachable | +| Too Many Requests | `warning` | Rate limited - too many requests to the cluster | +| Other K8s API errors | `error` | Operation failed - cluster may be unreachable or experiencing issues | + +Non-Kubernetes errors (e.g., input validation errors) are **not** logged to MCP clients. 
+ +### Manual Logging + +For custom messages beyond automatic K8s error handling, use `SendMCPLog` directly: + +```go +import "github.com/containers/kubernetes-mcp-server/pkg/mcplog" + +mcplog.SendMCPLog(ctx, mcplog.LevelError, "Operation failed - check permissions") +``` + +## Security + +- Authentication failures send generic messages to clients (no security info leaked) +- Sensitive data is automatically redacted before being sent to clients, covering: + - Generic fields (password, token, secret, api_key, etc.) + - Authorization headers (Bearer, Basic) + - Cloud credentials (AWS, GCP, Azure) + - API tokens (GitHub, GitLab, OpenAI, Anthropic) + - Cryptographic keys (JWT, SSH, PGP, RSA) + - Database connection strings (PostgreSQL, MySQL, MongoDB) +- Uses a dedicated named logger (`logger="mcp"`) for complete separation from server logs +- Server logs (klog) remain detailed and unaffected diff --git a/internal/test/mcp.go b/internal/test/mcp.go index c0b5d4f4b..4c9e32a8e 100644 --- a/internal/test/mcp.go +++ b/internal/test/mcp.go @@ -243,3 +243,30 @@ func (c *NotificationCapture) RequireLogNotification(t *testing.T, timeout time. require.NotNil(t, logNotification, "failed to parse log notification") return logNotification } + +// RequireNoLogNotification asserts that no logging notification is received within the given timeout. +// Use this to verify that non-Kubernetes errors do not produce MCP log notifications. 
+func (c *NotificationCapture) RequireNoLogNotification(t *testing.T, timeout time.Duration) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + for { + c.mu.RLock() + for _, n := range c.notifications { + if n.Method == "notifications/message" { + c.mu.RUnlock() + require.Fail(t, "unexpected log notification received", "notification: %v", n) + return + } + } + c.mu.RUnlock() + + select { + case <-c.signal: + // New notification arrived, check it + case <-ctx.Done(): + // Timeout with no log notification — success + return + } + } +} diff --git a/pkg/mcp/gosdk.go b/pkg/mcp/gosdk.go index 8196488fd..02f94a284 100644 --- a/pkg/mcp/gosdk.go +++ b/pkg/mcp/gosdk.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/containers/kubernetes-mcp-server/pkg/api" + "github.com/containers/kubernetes-mcp-server/pkg/mcplog" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" "k8s.io/utils/ptr" @@ -55,6 +56,9 @@ func ServerToolToGoSdkTool(s *Server, tool api.ServerTool) (*mcp.Tool, mcp.ToolH if err != nil { return nil, err } + if result.Error != nil { + mcplog.HandleK8sError(ctx, result.Error, tool.Tool.Name) + } return NewStructuredResult(result.Content, result.StructuredContent, result.Error), nil } return goSdkTool, goSdkHandler, nil diff --git a/pkg/mcp/helm_test.go b/pkg/mcp/helm_test.go index 2646118d0..7ce42c022 100644 --- a/pkg/mcp/helm_test.go +++ b/pkg/mcp/helm_test.go @@ -10,6 +10,7 @@ import ( "strconv" "strings" "testing" + "time" "github.com/BurntSushi/toml" "github.com/mark3labs/mcp-go/mcp" @@ -95,6 +96,7 @@ func (s *HelmSuite) TestHelmInstallDenied() { `), s.Cfg), "Expected to parse denied resources config") s.InitMcpClient() s.Run("helm_install(chart=helm-chart-secret, denied)", func() { + capture := s.StartCapturingLogNotifications() _, file, _, _ := runtime.Caller(0) chartPath := filepath.Join(filepath.Dir(file), "testdata", "helm-chart-secret") toolResult, err := s.CallTool("helm_install", 
map[string]interface{}{ @@ -111,6 +113,9 @@ func (s *HelmSuite) TestHelmInstallDenied() { expectedMessage := ": resource not allowed: /v1, Kind=Secret" s.Truef(strings.HasSuffix(msg, expectedMessage), "expected descriptive error '%s', got %v", expectedMessage, msg) }) + s.Run("does not send log notification for non-K8s error", func() { + capture.RequireNoLogNotification(s.T(), 500*time.Millisecond) + }) }) } @@ -328,6 +333,28 @@ func (s *HelmSuite) TestHelmUninstallDenied() { }) } +func (s *HelmSuite) TestHelmListForbidden() { + s.InitMcpClient() + defer restoreAuth(s.T().Context()) + client := kubernetes.NewForConfigOrDie(envTestRestConfig) + _ = client.RbacV1().ClusterRoles().Delete(s.T().Context(), "allow-all", metav1.DeleteOptions{}) + + s.Run("helm_list (forbidden)", func() { + capture := s.StartCapturingLogNotifications() + toolResult, _ := s.CallTool("helm_list", map[string]interface{}{}) + s.Run("returns error", func() { + s.Truef(toolResult.IsError, "call tool should fail") + s.Contains(toolResult.Content[0].(mcp.TextContent).Text, "forbidden", + "error message should indicate forbidden") + }) + s.Run("sends log notification", func() { + logNotification := capture.RequireLogNotification(s.T(), 2*time.Second) + s.Equal("error", logNotification.Level, "forbidden errors should log at error level") + s.Contains(logNotification.Data, "Permission denied", "log message should indicate permission denied") + }) + }) +} + func clearHelmReleases(ctx context.Context, kc *kubernetes.Clientset) { secrets, _ := kc.CoreV1().Secrets("default").List(ctx, metav1.ListOptions{}) for _, secret := range secrets.Items { diff --git a/pkg/mcp/pods_test.go b/pkg/mcp/pods_test.go index 918f9811a..f827b7ef8 100644 --- a/pkg/mcp/pods_test.go +++ b/pkg/mcp/pods_test.go @@ -157,6 +157,7 @@ func (s *PodsSuite) TestPodsListDenied() { `), s.Cfg), "Expected to parse denied resources config") s.InitMcpClient() s.Run("pods_list (denied)", func() { + capture := 
s.StartCapturingLogNotifications() podsList, err := s.CallTool("pods_list", map[string]interface{}{}) s.Run("has error", func() { s.Truef(podsList.IsError, "call tool should fail") @@ -169,8 +170,12 @@ func (s *PodsSuite) TestPodsListDenied() { s.Regexpf(expectedMessage, msg, "expected descriptive error '%s', got %v", expectedMessage, msg) }) + s.Run("does not send log notification for non-K8s error", func() { + capture.RequireNoLogNotification(s.T(), 500*time.Millisecond) + }) }) s.Run("pods_list_in_namespace (denied)", func() { + capture := s.StartCapturingLogNotifications() podsListInNamespace, err := s.CallTool("pods_list_in_namespace", map[string]interface{}{"namespace": "ns-1"}) s.Run("has error", func() { s.Truef(podsListInNamespace.IsError, "call tool should fail") @@ -183,6 +188,9 @@ func (s *PodsSuite) TestPodsListDenied() { s.Regexpf(expectedMessage, msg, "expected descriptive error '%s', got %v", expectedMessage, msg) }) + s.Run("does not send log notification for non-K8s error", func() { + capture.RequireNoLogNotification(s.T(), 500*time.Millisecond) + }) }) } diff --git a/pkg/mcplog/k8s.go b/pkg/mcplog/k8s.go index 850dbebe5..59fd60fb7 100644 --- a/pkg/mcplog/k8s.go +++ b/pkg/mcplog/k8s.go @@ -2,40 +2,54 @@ package mcplog import ( "context" + "errors" apierrors "k8s.io/apimachinery/pkg/api/errors" ) -// HandleK8sError sends appropriate MCP log messages based on Kubernetes API error types. -// operation should describe the operation (e.g., "pod access", "deployment deletion"). -func HandleK8sError(ctx context.Context, err error, operation string) { +// classifyK8sError maps a Kubernetes API error to a log level and message. +// Returns the level, message, and true if the error should be logged. +// Returns zero values and false for nil errors or non-Kubernetes errors. 
+func classifyK8sError(err error, operation string) (Level, string, bool) { if err == nil { - return + return 0, "", false } if apierrors.IsNotFound(err) { - SendMCPLog(ctx, LevelInfo, "Resource not found - it may not exist or may have been deleted") + return LevelInfo, "Resource not found - it may not exist or may have been deleted", true } else if apierrors.IsForbidden(err) { - SendMCPLog(ctx, LevelError, "Permission denied - check RBAC permissions for "+operation) + return LevelError, "Permission denied - check RBAC permissions for " + operation, true } else if apierrors.IsUnauthorized(err) { - SendMCPLog(ctx, LevelError, "Authentication failed - check cluster credentials") + return LevelError, "Authentication failed - check cluster credentials", true } else if apierrors.IsAlreadyExists(err) { - SendMCPLog(ctx, LevelWarning, "Resource already exists") + return LevelWarning, "Resource already exists", true } else if apierrors.IsInvalid(err) { - SendMCPLog(ctx, LevelError, "Invalid resource specification - check resource definition") + return LevelError, "Invalid resource specification - check resource definition", true } else if apierrors.IsBadRequest(err) { - SendMCPLog(ctx, LevelError, "Invalid request - check parameters") + return LevelError, "Invalid request - check parameters", true } else if apierrors.IsConflict(err) { - SendMCPLog(ctx, LevelError, "Resource conflict - resource may have been modified") + return LevelError, "Resource conflict - resource may have been modified", true } else if apierrors.IsTimeout(err) { - SendMCPLog(ctx, LevelError, "Request timeout - cluster may be slow or overloaded") + return LevelError, "Request timeout - cluster may be slow or overloaded", true } else if apierrors.IsServerTimeout(err) { - SendMCPLog(ctx, LevelError, "Server timeout - cluster may be slow or overloaded") + return LevelError, "Server timeout - cluster may be slow or overloaded", true } else if apierrors.IsServiceUnavailable(err) { - SendMCPLog(ctx, 
LevelError, "Service unavailable - cluster may be unreachable") + return LevelError, "Service unavailable - cluster may be unreachable", true } else if apierrors.IsTooManyRequests(err) { - SendMCPLog(ctx, LevelWarning, "Rate limited - too many requests to the cluster") + return LevelWarning, "Rate limited - too many requests to the cluster", true } else { - SendMCPLog(ctx, LevelError, "Operation failed - cluster may be unreachable or experiencing issues") + var apiStatus apierrors.APIStatus + if errors.As(err, &apiStatus) { + return LevelError, "Operation failed - cluster may be unreachable or experiencing issues", true + } + } + return 0, "", false +} + +// HandleK8sError sends an MCP log message categorized by Kubernetes API error type; nil and non-Kubernetes errors are ignored. +// operation identifies what failed (the central tool dispatcher passes the tool name, e.g., "pods_list") and is appended to permission-denied messages. +func HandleK8sError(ctx context.Context, err error, operation string) { + if level, message, ok := classifyK8sError(err, operation); ok { + SendMCPLog(ctx, level, message) } } diff --git a/pkg/mcplog/k8s_test.go b/pkg/mcplog/k8s_test.go new file mode 100644 index 000000000..a4ffc8ec3 --- /dev/null +++ b/pkg/mcplog/k8s_test.go @@ -0,0 +1,181 @@ +package mcplog + +import ( + "context" + "fmt" + "testing" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime/schema" + + "github.com/stretchr/testify/suite" +) + +type K8sErrorSuite struct { + suite.Suite +} + +func (s *K8sErrorSuite) TestClassifyK8sError() { + gr := schema.GroupResource{Group: "", Resource: "pods"} + + s.Run("nil error returns false", func() { + _, _, ok := classifyK8sError(nil, "any operation") + s.False(ok) + }) + + s.Run("NotFound returns info level", func() { + level, message, ok := classifyK8sError(apierrors.NewNotFound(gr, "test-pod"), "pod access") + s.True(ok) + s.Equal(LevelInfo, level) + s.Contains(message, "Resource not found") + }) + + s.Run("Forbidden returns error level with operation", func() { + level, message, ok := 
classifyK8sError(apierrors.NewForbidden(gr, "test-pod", nil), "pod access") + s.True(ok) + s.Equal(LevelError, level) + s.Contains(message, "Permission denied") + s.Contains(message, "pod access") + }) + + s.Run("Unauthorized returns error level", func() { + level, message, ok := classifyK8sError(apierrors.NewUnauthorized("unauthorized"), "resource access") + s.True(ok) + s.Equal(LevelError, level) + s.Contains(message, "Authentication failed") + }) + + s.Run("AlreadyExists returns warning level", func() { + level, message, ok := classifyK8sError(apierrors.NewAlreadyExists(gr, "test-pod"), "resource creation") + s.True(ok) + s.Equal(LevelWarning, level) + s.Contains(message, "already exists") + }) + + s.Run("Invalid returns error level", func() { + level, message, ok := classifyK8sError(apierrors.NewInvalid(schema.GroupKind{Group: "", Kind: "Pod"}, "test-pod", nil), "resource update") + s.True(ok) + s.Equal(LevelError, level) + s.Contains(message, "Invalid resource specification") + }) + + s.Run("BadRequest returns error level", func() { + level, message, ok := classifyK8sError(apierrors.NewBadRequest("bad request"), "resource scaling") + s.True(ok) + s.Equal(LevelError, level) + s.Contains(message, "Invalid request") + }) + + s.Run("Conflict returns error level", func() { + level, message, ok := classifyK8sError(apierrors.NewConflict(gr, "test-pod", nil), "resource update") + s.True(ok) + s.Equal(LevelError, level) + s.Contains(message, "Resource conflict") + }) + + s.Run("Timeout returns error level", func() { + level, message, ok := classifyK8sError(apierrors.NewTimeoutError("timeout", 30), "node log access") + s.True(ok) + s.Equal(LevelError, level) + s.Contains(message, "timeout") + }) + + s.Run("ServerTimeout returns error level", func() { + level, message, ok := classifyK8sError(apierrors.NewServerTimeout(gr, "get", 60), "node stats access") + s.True(ok) + s.Equal(LevelError, level) + s.Contains(message, "timeout") + }) + + s.Run("ServiceUnavailable returns 
error level", func() { + level, message, ok := classifyK8sError(apierrors.NewServiceUnavailable("unavailable"), "events listing") + s.True(ok) + s.Equal(LevelError, level) + s.Contains(message, "Service unavailable") + }) + + s.Run("TooManyRequests returns warning level", func() { + level, message, ok := classifyK8sError(apierrors.NewTooManyRequests("rate limited", 10), "namespace listing") + s.True(ok) + s.Equal(LevelWarning, level) + s.Contains(message, "Rate limited") + }) + + s.Run("other K8s API error returns error level", func() { + level, message, ok := classifyK8sError(apierrors.NewInternalError(fmt.Errorf("internal error")), "resource access") + s.True(ok) + s.Equal(LevelError, level) + s.Contains(message, "Operation failed") + }) +} + +func (s *K8sErrorSuite) TestClassifyK8sErrorIgnoresNonK8sErrors() { + s.Run("plain error returns false", func() { + _, _, ok := classifyK8sError(fmt.Errorf("some non-k8s error"), "operation") + s.False(ok) + }) + + s.Run("wrapped non-K8s error returns false", func() { + inner := fmt.Errorf("connection refused") + _, _, ok := classifyK8sError(fmt.Errorf("failed to connect: %w", inner), "operation") + s.False(ok) + }) +} + +func (s *K8sErrorSuite) TestClassifyK8sErrorWithWrappedK8sErrors() { + gr := schema.GroupResource{Group: "", Resource: "secrets"} + + s.Run("wrapped NotFound is detected", func() { + inner := apierrors.NewNotFound(gr, "my-secret") + wrapped := fmt.Errorf("helm operation failed: %w", inner) + level, message, ok := classifyK8sError(wrapped, "helm install") + s.True(ok) + s.Equal(LevelInfo, level) + s.Contains(message, "Resource not found") + }) + + s.Run("wrapped Forbidden is detected", func() { + inner := apierrors.NewForbidden(gr, "my-secret", nil) + wrapped := fmt.Errorf("helm operation failed: %w", inner) + level, message, ok := classifyK8sError(wrapped, "helm install") + s.True(ok) + s.Equal(LevelError, level) + s.Contains(message, "Permission denied") + s.Contains(message, "helm install") + }) + + 
s.Run("wrapped generic K8s API error is detected", func() { + inner := apierrors.NewInternalError(fmt.Errorf("internal")) + wrapped := fmt.Errorf("helm operation failed: %w", inner) + level, message, ok := classifyK8sError(wrapped, "helm install") + s.True(ok) + s.Equal(LevelError, level) + s.Contains(message, "Operation failed") + }) +} + +func (s *K8sErrorSuite) TestHandleK8sErrorDoesNotPanic() { + ctx := context.Background() + + s.Run("nil error", func() { + s.NotPanics(func() { + HandleK8sError(ctx, nil, "any operation") + }) + }) + + s.Run("K8s error without session in context", func() { + s.NotPanics(func() { + HandleK8sError(ctx, apierrors.NewNotFound(schema.GroupResource{Resource: "pods"}, "test"), "pod access") + }) + }) + + s.Run("non-K8s error without session in context", func() { + s.NotPanics(func() { + HandleK8sError(ctx, fmt.Errorf("some error"), "operation") + }) + }) +} + +func TestK8sError(t *testing.T) { + suite.Run(t, new(K8sErrorSuite)) +} diff --git a/pkg/toolsets/core/error_handling_test.go b/pkg/toolsets/core/error_handling_test.go deleted file mode 100644 index 26759b802..000000000 --- a/pkg/toolsets/core/error_handling_test.go +++ /dev/null @@ -1,164 +0,0 @@ -package core - -import ( - "context" - "fmt" - "testing" - - apierrors "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/runtime/schema" - - "github.com/containers/kubernetes-mcp-server/pkg/mcplog" - "github.com/stretchr/testify/suite" -) - -type ErrorHandlingSuite struct { - suite.Suite -} - -func (s *ErrorHandlingSuite) TestHandleK8sErrorIntegration() { - ctx := context.Background() - gr := schema.GroupResource{Group: "v1", Resource: "pods"} - - s.Run("handles NotFound errors", func() { - err := apierrors.NewNotFound(gr, "test-pod") - s.NotPanics(func() { - mcplog.HandleK8sError(ctx, err, "pod access") - }) - }) - - s.Run("handles Forbidden errors", func() { - err := apierrors.NewForbidden(gr, "test-pod", nil) - s.NotPanics(func() { - mcplog.HandleK8sError(ctx, err, 
"pod deletion") - }) - }) - - s.Run("handles Unauthorized errors", func() { - err := apierrors.NewUnauthorized("unauthorized") - s.NotPanics(func() { - mcplog.HandleK8sError(ctx, err, "resource access") - }) - }) - - s.Run("handles AlreadyExists errors", func() { - err := apierrors.NewAlreadyExists(gr, "test-resource") - s.NotPanics(func() { - mcplog.HandleK8sError(ctx, err, "resource creation") - }) - }) - - s.Run("handles Invalid errors", func() { - err := apierrors.NewInvalid(schema.GroupKind{Group: "v1", Kind: "Pod"}, "test-pod", nil) - s.NotPanics(func() { - mcplog.HandleK8sError(ctx, err, "resource creation or update") - }) - }) - - s.Run("handles BadRequest errors", func() { - err := apierrors.NewBadRequest("bad request") - s.NotPanics(func() { - mcplog.HandleK8sError(ctx, err, "resource scaling") - }) - }) - - s.Run("handles Conflict errors", func() { - err := apierrors.NewConflict(gr, "test-resource", nil) - s.NotPanics(func() { - mcplog.HandleK8sError(ctx, err, "resource update") - }) - }) - - s.Run("handles Timeout errors", func() { - err := apierrors.NewTimeoutError("request timeout", 30) - s.NotPanics(func() { - mcplog.HandleK8sError(ctx, err, "node log access") - }) - }) - - s.Run("handles ServerTimeout errors", func() { - err := apierrors.NewServerTimeout(gr, "operation", 60) - s.NotPanics(func() { - mcplog.HandleK8sError(ctx, err, "node stats access") - }) - }) - - s.Run("handles ServiceUnavailable errors", func() { - err := apierrors.NewServiceUnavailable("service unavailable") - s.NotPanics(func() { - mcplog.HandleK8sError(ctx, err, "events listing") - }) - }) - - s.Run("handles TooManyRequests errors", func() { - err := apierrors.NewTooManyRequests("rate limited", 10) - s.NotPanics(func() { - mcplog.HandleK8sError(ctx, err, "namespace listing") - }) - }) - - s.Run("handles generic errors", func() { - err := apierrors.NewInternalError(fmt.Errorf("internal server error")) - s.NotPanics(func() { - mcplog.HandleK8sError(ctx, err, "node metrics 
access") - }) - }) - - s.Run("handles nil error gracefully", func() { - s.NotPanics(func() { - mcplog.HandleK8sError(ctx, nil, "any operation") - }) - }) -} - -func (s *ErrorHandlingSuite) TestErrorHandlingCoverage() { - s.Run("error handling is consistent across handlers", func() { - handlers := []string{ - "podsGet - pod access", - "podsDelete - pod deletion", - "resourcesList - resource listing", - "resourcesGet - resource access", - "resourcesCreateOrUpdate - resource creation or update", - "resourcesDelete - resource deletion", - "resourcesScale - resource scaling", - "nodesLog - node log access", - "nodesStatsSummary - node stats access", - "nodesTop - node metrics access, node listing", - "eventsList - events listing", - "namespacesList - namespace listing", - "projectsList - project listing", - } - - s.GreaterOrEqual(len(handlers), 13, "should document all error handling points") - }) -} - -func (s *ErrorHandlingSuite) TestOperationDescriptions() { - s.Run("operation descriptions follow naming conventions", func() { - validDescriptions := []string{ - "pod access", - "pod deletion", - "resource listing", - "resource access", - "resource creation or update", - "resource deletion", - "resource scaling", - "node log access", - "node stats access", - "node metrics access", - "node listing", - "events listing", - "namespace listing", - "project listing", - } - - for _, desc := range validDescriptions { - s.NotEmpty(desc, "description should not be empty") - s.Equal(desc, desc, "description should be lowercase: %s", desc) - } - }) -} - -func TestErrorHandling(t *testing.T) { - suite.Run(t, new(ErrorHandlingSuite)) -} diff --git a/pkg/toolsets/core/events.go b/pkg/toolsets/core/events.go index 31beb72b7..72af6c950 100644 --- a/pkg/toolsets/core/events.go +++ b/pkg/toolsets/core/events.go @@ -8,7 +8,6 @@ import ( "github.com/containers/kubernetes-mcp-server/pkg/api" "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" - 
"github.com/containers/kubernetes-mcp-server/pkg/mcplog" "github.com/containers/kubernetes-mcp-server/pkg/output" ) @@ -43,7 +42,6 @@ func eventsList(params api.ToolHandlerParams) (*api.ToolCallResult, error) { } eventMap, err := kubernetes.NewCore(params).EventsList(params, namespace.(string)) if err != nil { - mcplog.HandleK8sError(params.Context, err, "events listing") return api.NewToolCallResult("", fmt.Errorf("failed to list events in all namespaces: %w", err)), nil } if len(eventMap) == 0 { diff --git a/pkg/toolsets/core/namespaces.go b/pkg/toolsets/core/namespaces.go index d4851991c..1538cbe0e 100644 --- a/pkg/toolsets/core/namespaces.go +++ b/pkg/toolsets/core/namespaces.go @@ -9,7 +9,6 @@ import ( "github.com/containers/kubernetes-mcp-server/pkg/api" "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" - "github.com/containers/kubernetes-mcp-server/pkg/mcplog" ) func initNamespaces(o api.Openshift) []api.ServerTool { @@ -52,7 +51,6 @@ func initNamespaces(o api.Openshift) []api.ServerTool { func namespacesList(params api.ToolHandlerParams) (*api.ToolCallResult, error) { ret, err := kubernetes.NewCore(params).NamespacesList(params, api.ListOptions{AsTable: params.ListOutput.AsTable()}) if err != nil { - mcplog.HandleK8sError(params.Context, err, "namespace listing") return api.NewToolCallResult("", fmt.Errorf("failed to list namespaces: %w", err)), nil } return api.NewToolCallResult(params.ListOutput.PrintObj(ret)), nil @@ -61,7 +59,6 @@ func namespacesList(params api.ToolHandlerParams) (*api.ToolCallResult, error) { func projectsList(params api.ToolHandlerParams) (*api.ToolCallResult, error) { ret, err := kubernetes.NewCore(params).ProjectsList(params, api.ListOptions{AsTable: params.ListOutput.AsTable()}) if err != nil { - mcplog.HandleK8sError(params.Context, err, "project listing") return api.NewToolCallResult("", fmt.Errorf("failed to list projects: %w", err)), nil } return api.NewToolCallResult(params.ListOutput.PrintObj(ret)), nil diff --git 
a/pkg/toolsets/core/nodes.go b/pkg/toolsets/core/nodes.go index 3f91a5003..4c5a3d99e 100644 --- a/pkg/toolsets/core/nodes.go +++ b/pkg/toolsets/core/nodes.go @@ -14,7 +14,6 @@ import ( "github.com/containers/kubernetes-mcp-server/pkg/api" "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" - "github.com/containers/kubernetes-mcp-server/pkg/mcplog" ) func initNodes() []api.ServerTool { @@ -117,7 +116,6 @@ func nodesLog(params api.ToolHandlerParams) (*api.ToolCallResult, error) { } ret, err := kubernetes.NewCore(params).NodesLog(params, name, query, tailInt) if err != nil { - mcplog.HandleK8sError(params.Context, err, "node log access") return api.NewToolCallResult("", fmt.Errorf("failed to get node log for %s: %w", name, err)), nil } else if ret == "" { ret = fmt.Sprintf("The node %s has not logged any message yet or the log file is empty", name) @@ -132,7 +130,6 @@ func nodesStatsSummary(params api.ToolHandlerParams) (*api.ToolCallResult, error } ret, err := kubernetes.NewCore(params).NodesStatsSummary(params, name) if err != nil { - mcplog.HandleK8sError(params.Context, err, "node stats access") return api.NewToolCallResult("", fmt.Errorf("failed to get node stats summary for %s: %w", name, err)), nil } return api.NewToolCallResult(ret, nil), nil @@ -149,7 +146,6 @@ func nodesTop(params api.ToolHandlerParams) (*api.ToolCallResult, error) { nodeMetrics, err := kubernetes.NewCore(params).NodesTop(params, nodesTopOptions) if err != nil { - mcplog.HandleK8sError(params.Context, err, "node metrics access") return api.NewToolCallResult("", fmt.Errorf("failed to get nodes top: %w", err)), nil } @@ -158,7 +154,6 @@ func nodesTop(params api.ToolHandlerParams) (*api.ToolCallResult, error) { LabelSelector: nodesTopOptions.LabelSelector, }) if err != nil { - mcplog.HandleK8sError(params.Context, err, "node listing") return api.NewToolCallResult("", fmt.Errorf("failed to list nodes: %w", err)), nil } diff --git a/pkg/toolsets/core/pods.go b/pkg/toolsets/core/pods.go 
index a69e8a44b..09e4e959a 100644 --- a/pkg/toolsets/core/pods.go +++ b/pkg/toolsets/core/pods.go @@ -11,7 +11,6 @@ import ( "github.com/containers/kubernetes-mcp-server/pkg/api" "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" - "github.com/containers/kubernetes-mcp-server/pkg/mcplog" "github.com/containers/kubernetes-mcp-server/pkg/output" ) @@ -274,7 +273,6 @@ func podsListInAllNamespaces(params api.ToolHandlerParams) (*api.ToolCallResult, } ret, err := kubernetes.NewCore(params).PodsListInAllNamespaces(params, resourceListOptions) if err != nil { - mcplog.HandleK8sError(params.Context, err, "pod listing") return api.NewToolCallResult("", fmt.Errorf("failed to list pods in all namespaces: %w", err)), nil } return api.NewToolCallResult(params.ListOutput.PrintObj(ret)), nil @@ -298,7 +296,6 @@ func podsListInNamespace(params api.ToolHandlerParams) (*api.ToolCallResult, err } ret, err := kubernetes.NewCore(params).PodsListInNamespace(params, ns.(string), resourceListOptions) if err != nil { - mcplog.HandleK8sError(params.Context, err, "pod listing") return api.NewToolCallResult("", fmt.Errorf("failed to list pods in namespace %s: %w", ns, err)), nil } return api.NewToolCallResult(params.ListOutput.PrintObj(ret)), nil @@ -315,7 +312,6 @@ func podsGet(params api.ToolHandlerParams) (*api.ToolCallResult, error) { } ret, err := kubernetes.NewCore(params).PodsGet(params, ns.(string), name.(string)) if err != nil { - mcplog.HandleK8sError(params.Context, err, "pod access") return api.NewToolCallResult("", fmt.Errorf("failed to get pod %s in namespace %s: %w", name, ns, err)), nil } return api.NewToolCallResult(output.MarshalYaml(ret)), nil @@ -332,7 +328,6 @@ func podsDelete(params api.ToolHandlerParams) (*api.ToolCallResult, error) { } ret, err := kubernetes.NewCore(params).PodsDelete(params, ns.(string), name.(string)) if err != nil { - mcplog.HandleK8sError(params.Context, err, "pod deletion") return api.NewToolCallResult("", fmt.Errorf("failed to delete 
pod %s in namespace %s: %w", name, ns, err)), nil } return api.NewToolCallResult(ret, err), nil diff --git a/pkg/toolsets/core/resources.go b/pkg/toolsets/core/resources.go index 22fe3dc65..cfefdd396 100644 --- a/pkg/toolsets/core/resources.go +++ b/pkg/toolsets/core/resources.go @@ -11,7 +11,6 @@ import ( "github.com/containers/kubernetes-mcp-server/pkg/api" "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" - "github.com/containers/kubernetes-mcp-server/pkg/mcplog" "github.com/containers/kubernetes-mcp-server/pkg/output" ) @@ -224,7 +223,6 @@ func resourcesList(params api.ToolHandlerParams) (*api.ToolCallResult, error) { ret, err := kubernetes.NewCore(params).ResourcesList(params, gvk, ns, resourceListOptions) if err != nil { - mcplog.HandleK8sError(params.Context, err, "resource listing") return api.NewToolCallResult("", fmt.Errorf("failed to list resources: %w", err)), nil } return api.NewToolCallResult(params.ListOutput.PrintObj(ret)), nil @@ -256,7 +254,6 @@ func resourcesGet(params api.ToolHandlerParams) (*api.ToolCallResult, error) { ret, err := kubernetes.NewCore(params).ResourcesGet(params, gvk, ns, n) if err != nil { - mcplog.HandleK8sError(params.Context, err, "resource access") return api.NewToolCallResult("", fmt.Errorf("failed to get resource: %w", err)), nil } return api.NewToolCallResult(output.MarshalYaml(ret)), nil @@ -275,7 +272,6 @@ func resourcesCreateOrUpdate(params api.ToolHandlerParams) (*api.ToolCallResult, resources, err := kubernetes.NewCore(params).ResourcesCreateOrUpdate(params, r) if err != nil { - mcplog.HandleK8sError(params.Context, err, "resource creation or update") return api.NewToolCallResult("", fmt.Errorf("failed to create or update resources: %w", err)), nil } marshalledYaml, err := output.MarshalYaml(resources) @@ -320,7 +316,6 @@ func resourcesDelete(params api.ToolHandlerParams) (*api.ToolCallResult, error) err = kubernetes.NewCore(params).ResourcesDelete(params, gvk, ns, n, gracePeriodSecondsPtr) if err != nil 
{ - mcplog.HandleK8sError(params.Context, err, "resource deletion") return api.NewToolCallResult("", fmt.Errorf("failed to delete resource: %w", err)), nil } return api.NewToolCallResult("Resource deleted successfully", err), nil @@ -363,7 +358,6 @@ func resourcesScale(params api.ToolHandlerParams) (*api.ToolCallResult, error) { scale, err := kubernetes.NewCore(params).ResourcesScale(params.Context, gvk, ns, n, desiredScale, shouldScale) if err != nil { - mcplog.HandleK8sError(params.Context, err, "resource scaling") return api.NewToolCallResult("", fmt.Errorf("failed to get/update resource scale: %w", err)), nil } diff --git a/pkg/toolsets/helm/helm.go b/pkg/toolsets/helm/helm.go index 7f8360d5a..4887c5b93 100644 --- a/pkg/toolsets/helm/helm.go +++ b/pkg/toolsets/helm/helm.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/containers/kubernetes-mcp-server/pkg/helm" - "github.com/containers/kubernetes-mcp-server/pkg/mcplog" "github.com/google/jsonschema-go/jsonschema" "k8s.io/utils/ptr" @@ -116,7 +115,6 @@ func helmInstall(params api.ToolHandlerParams) (*api.ToolCallResult, error) { } ret, err := helm.NewHelm(params).Install(params, chart, values, name, namespace) if err != nil { - mcplog.HandleK8sError(params.Context, err, "helm install") return api.NewToolCallResult("", fmt.Errorf("failed to install helm chart '%s': %w", chart, err)), nil } return api.NewToolCallResult(ret, err), nil @@ -133,7 +131,6 @@ func helmList(params api.ToolHandlerParams) (*api.ToolCallResult, error) { } ret, err := helm.NewHelm(params).List(namespace, allNamespaces) if err != nil { - mcplog.HandleK8sError(params.Context, err, "helm list") return api.NewToolCallResult("", fmt.Errorf("failed to list helm releases in namespace '%s': %w", namespace, err)), nil } return api.NewToolCallResult(ret, err), nil @@ -151,7 +148,6 @@ func helmUninstall(params api.ToolHandlerParams) (*api.ToolCallResult, error) { } ret, err := helm.NewHelm(params).Uninstall(name, namespace) if err != nil { - 
mcplog.HandleK8sError(params.Context, err, "helm uninstall") return api.NewToolCallResult("", fmt.Errorf("failed to uninstall helm chart '%s': %w", name, err)), nil } return api.NewToolCallResult(ret, err), nil From 128b37c195b0bae13fbd46fbad3d69acc44e583d Mon Sep 17 00:00:00 2001 From: Nader Ziada Date: Fri, 20 Feb 2026 00:23:42 -0500 Subject: [PATCH 2/3] fix(ci): resolve race conditions in tests and MCP server (#793) - Add mutex protection to shared state accessed by concurrent goroutines - Fix SIGHUP handler goroutine leak by returning a stop function that properly cleans up signal notification and waits for goroutine exit Signed-off-by: Nader Ziada --- internal/test/mock_server.go | 44 ++++++++++++++++++- pkg/http/authorization_mcp_test.go | 3 +- pkg/http/http_test.go | 3 +- pkg/kubernetes-mcp-server/cmd/root.go | 15 +++++-- .../cmd/root_sighup_test.go | 17 +++++-- pkg/mcp/helm_test.go | 4 +- pkg/mcp/mcp.go | 30 ++++++++++--- pkg/mcp/mcp_middleware_test.go | 3 +- 8 files changed, 97 insertions(+), 22 deletions(-) diff --git a/internal/test/mock_server.go b/internal/test/mock_server.go index 1c5d6ba81..37ceaf143 100644 --- a/internal/test/mock_server.go +++ b/internal/test/mock_server.go @@ -1,6 +1,7 @@ package test import ( + "bytes" "encoding/json" "errors" "io" @@ -8,6 +9,7 @@ import ( "net/http/httptest" "path/filepath" "strings" + "sync" "testing" "github.com/stretchr/testify/require" @@ -24,6 +26,7 @@ import ( ) type MockServer struct { + mu sync.RWMutex server *httptest.Server config *rest.Config restHandlers []http.HandlerFunc @@ -34,7 +37,10 @@ func NewMockServer() *MockServer { scheme := runtime.NewScheme() codecs := serializer.NewCodecFactory(scheme) ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - for _, handler := range ms.restHandlers { + ms.mu.RLock() + handlers := ms.restHandlers + ms.mu.RUnlock() + for _, handler := range handlers { handler(w, req) } })) @@ -58,10 +64,14 @@ func (m *MockServer) 
Close() { } func (m *MockServer) Handle(handler http.Handler) { + m.mu.Lock() + defer m.mu.Unlock() m.restHandlers = append(m.restHandlers, handler.ServeHTTP) } func (m *MockServer) ResetHandlers() { + m.mu.Lock() + defer m.mu.Unlock() m.restHandlers = make([]http.HandlerFunc, 0) } @@ -189,6 +199,7 @@ WaitForStreams: } type DiscoveryClientHandler struct { + mu sync.RWMutex // APIResourceLists defines all API groups and their resources. // The handler automatically generates /api, /apis, and /apis// endpoints. APIResourceLists []metav1.APIResourceList @@ -222,6 +233,9 @@ func NewDiscoveryClientHandler(additionalResources ...metav1.APIResourceList) *D func (h *DiscoveryClientHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { w.Header().Set("Content-Type", "application/json") + h.mu.RLock() + defer h.mu.RUnlock() + // Request Performed by DiscoveryClient to Kube API (Get API Groups legacy -core-) if req.URL.Path == "/api" { WriteObject(w, &metav1.APIVersions{ @@ -289,6 +303,8 @@ func parseGroupVersion(gv string) (group, version string) { // AddAPIResourceList adds an API resource list to the handler. // This is useful for dynamically modifying the handler during tests. func (h *DiscoveryClientHandler) AddAPIResourceList(resourceList metav1.APIResourceList) { + h.mu.Lock() + defer h.mu.Unlock() h.APIResourceLists = append(h.APIResourceLists, resourceList) } @@ -313,3 +329,29 @@ func NewInOpenShiftHandler(additionalResources ...metav1.APIResourceList) *Disco openShiftResources = append(openShiftResources, additionalResources...) return NewDiscoveryClientHandler(openShiftResources...) } + +// SyncBuffer is a thread-safe wrapper around bytes.Buffer. +// Use this for test log buffers to avoid race conditions when multiple +// goroutines write to the logger concurrently. 
+type SyncBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +func (b *SyncBuffer) Write(p []byte) (n int, err error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.Write(p) +} + +func (b *SyncBuffer) String() string { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.String() +} + +func (b *SyncBuffer) Reset() { + b.mu.Lock() + defer b.mu.Unlock() + b.buf.Reset() +} diff --git a/pkg/http/authorization_mcp_test.go b/pkg/http/authorization_mcp_test.go index 96e313136..a57b714a4 100644 --- a/pkg/http/authorization_mcp_test.go +++ b/pkg/http/authorization_mcp_test.go @@ -1,7 +1,6 @@ package http import ( - "bytes" "flag" "fmt" "net/http" @@ -25,7 +24,7 @@ type AuthorizationSuite struct { BaseHttpSuite mcpClient *client.Client klogState klog.State - logBuffer bytes.Buffer + logBuffer test.SyncBuffer } func (s *AuthorizationSuite) SetupTest() { diff --git a/pkg/http/http_test.go b/pkg/http/http_test.go index 1360ecc81..ec356b285 100644 --- a/pkg/http/http_test.go +++ b/pkg/http/http_test.go @@ -1,7 +1,6 @@ package http import ( - "bytes" "context" "crypto/rand" "crypto/rsa" @@ -84,7 +83,7 @@ func (s *BaseHttpSuite) TearDownTest() { type httpContext struct { klogState klog.State mockServer *test.MockServer - LogBuffer bytes.Buffer + LogBuffer test.SyncBuffer HttpAddress string // HTTP server address timeoutCancel context.CancelFunc // Release resources if test completes before the timeout StopServer context.CancelFunc diff --git a/pkg/kubernetes-mcp-server/cmd/root.go b/pkg/kubernetes-mcp-server/cmd/root.go index 247b735bc..046e22b39 100644 --- a/pkg/kubernetes-mcp-server/cmd/root.go +++ b/pkg/kubernetes-mcp-server/cmd/root.go @@ -380,7 +380,7 @@ func (m *MCPServerOptions) Run() error { // Set up SIGHUP handler for configuration reload if m.ConfigPath != "" || m.ConfigDir != "" { - m.setupSIGHUPHandler(mcpServer) + _ = m.setupSIGHUPHandler(mcpServer) } if m.StaticConfig.Port != "" { @@ -397,12 +397,15 @@ func (m *MCPServerOptions) Run() error { } // 
setupSIGHUPHandler sets up a signal handler to reload configuration on SIGHUP. -// This is a blocking call that runs in a separate goroutine. -func (m *MCPServerOptions) setupSIGHUPHandler(mcpServer *mcp.Server) { +// Returns a stop function that should be called to clean up the handler. +// The stop function waits for the handler goroutine to finish. +func (m *MCPServerOptions) setupSIGHUPHandler(mcpServer *mcp.Server) (stop func()) { sigHupCh := make(chan os.Signal, 1) + done := make(chan struct{}) signal.Notify(sigHupCh, syscall.SIGHUP) go func() { + defer close(done) for range sigHupCh { klog.V(1).Info("Received SIGHUP signal, reloading configuration...") @@ -424,4 +427,10 @@ func (m *MCPServerOptions) setupSIGHUPHandler(mcpServer *mcp.Server) { }() klog.V(2).Info("SIGHUP handler registered for configuration reload") + + return func() { + signal.Stop(sigHupCh) + close(sigHupCh) + <-done // Wait for goroutine to finish + } } diff --git a/pkg/kubernetes-mcp-server/cmd/root_sighup_test.go b/pkg/kubernetes-mcp-server/cmd/root_sighup_test.go index a2cf2376c..5d86966ad 100644 --- a/pkg/kubernetes-mcp-server/cmd/root_sighup_test.go +++ b/pkg/kubernetes-mcp-server/cmd/root_sighup_test.go @@ -3,7 +3,6 @@ package cmd import ( - "bytes" "os" "path/filepath" "slices" @@ -28,7 +27,9 @@ type SIGHUPSuite struct { server *mcp.Server tempDir string dropInConfigDir string - logBuffer *bytes.Buffer + logBuffer *test.SyncBuffer + klogState klog.State + stopSIGHUP func() } func (s *SIGHUPSuite) SetupTest() { @@ -38,19 +39,27 @@ func (s *SIGHUPSuite) SetupTest() { s.dropInConfigDir = filepath.Join(s.tempDir, "conf.d") s.Require().NoError(os.Mkdir(s.dropInConfigDir, 0755)) + // Capture klog state so we can restore it after the test + s.klogState = klog.CaptureState() + // Set up klog to write to our buffer so we can verify log messages - s.logBuffer = &bytes.Buffer{} + s.logBuffer = &test.SyncBuffer{} logger := textlogger.NewLogger(textlogger.NewConfig(textlogger.Verbosity(2), 
textlogger.Output(s.logBuffer))) klog.SetLoggerWithOptions(logger) } func (s *SIGHUPSuite) TearDownTest() { + // Stop the SIGHUP handler goroutine before restoring klog + if s.stopSIGHUP != nil { + s.stopSIGHUP() + } if s.server != nil { s.server.Close() } if s.mockServer != nil { s.mockServer.Close() } + s.klogState.Restore() } func (s *SIGHUPSuite) InitServer(configPath, configDir string) { @@ -69,7 +78,7 @@ func (s *SIGHUPSuite) InitServer(configPath, configDir string) { ConfigPath: configPath, ConfigDir: configDir, } - opts.setupSIGHUPHandler(s.server) + s.stopSIGHUP = opts.setupSIGHUPHandler(s.server) } func (s *SIGHUPSuite) TestSIGHUPReloadsConfigFromFile() { diff --git a/pkg/mcp/helm_test.go b/pkg/mcp/helm_test.go index 7ce42c022..04e3a6ecc 100644 --- a/pkg/mcp/helm_test.go +++ b/pkg/mcp/helm_test.go @@ -1,7 +1,6 @@ package mcp import ( - "bytes" "context" "encoding/base64" "flag" @@ -13,6 +12,7 @@ import ( "time" "github.com/BurntSushi/toml" + "github.com/containers/kubernetes-mcp-server/internal/test" "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/suite" corev1 "k8s.io/api/core/v1" @@ -27,7 +27,7 @@ import ( type HelmSuite struct { BaseMcpSuite klogState klog.State - logBuffer bytes.Buffer + logBuffer test.SyncBuffer } func (s *HelmSuite) SetupTest() { diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index 1e5ac1b6d..8a6bbb54b 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -6,6 +6,7 @@ import ( "net/http" "os" "slices" + "sync" "time" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -61,6 +62,7 @@ func (c *Configuration) isToolApplicable(tool api.ServerTool) bool { } type Server struct { + mu sync.RWMutex configuration *Configuration server *mcp.Server enabledTools []string @@ -134,9 +136,15 @@ func (s *Server) reloadToolsets() error { applicableTools := s.collectApplicableTools(targets) applicablePrompts := s.collectApplicablePrompts() - // Reload tools, and track the newly enabled tools so that we can diff on reload to figure out which to 
remove (if any) - s.enabledTools, err = reloadItems( - s.enabledTools, + // Read the previous state with read lock - don't hold lock while calling external code + s.mu.RLock() + previousTools := s.enabledTools + previousPrompts := s.enabledPrompts + s.mu.RUnlock() + + // Reload tools (calls s.server.AddTool/RemoveTools - external code, no lock held) + newTools, err := reloadItems( + previousTools, applicableTools, func(t api.ServerTool) string { return t.Tool.Name }, s.server.RemoveTools, @@ -146,9 +154,9 @@ func (s *Server) reloadToolsets() error { return err } - // Reload prompts, and track the newly enabled prompts so that we can diff on reload to figure out which to remove (if any) - s.enabledPrompts, err = reloadItems( - s.enabledPrompts, + // Reload prompts (calls s.server.AddPrompt/RemovePrompts - external code, no lock held) + newPrompts, err := reloadItems( + previousPrompts, applicablePrompts, func(p api.ServerPrompt) string { return p.Prompt.Name }, s.server.RemovePrompts, @@ -158,6 +166,12 @@ func (s *Server) reloadToolsets() error { return err } + // Only hold write lock for the final assignment + s.mu.Lock() + s.enabledTools = newTools + s.enabledPrompts = newPrompts + s.mu.Unlock() + // Start new watch s.p.WatchTargets(s.reloadToolsets) return nil @@ -315,11 +329,15 @@ func (s *Server) GetTargetParameterName() string { } func (s *Server) GetEnabledTools() []string { + s.mu.RLock() + defer s.mu.RUnlock() return s.enabledTools } // GetEnabledPrompts returns the names of the currently enabled prompts func (s *Server) GetEnabledPrompts() []string { + s.mu.RLock() + defer s.mu.RUnlock() return s.enabledPrompts } diff --git a/pkg/mcp/mcp_middleware_test.go b/pkg/mcp/mcp_middleware_test.go index f5240e5ea..c8722b042 100644 --- a/pkg/mcp/mcp_middleware_test.go +++ b/pkg/mcp/mcp_middleware_test.go @@ -1,7 +1,6 @@ package mcp import ( - "bytes" "context" "flag" "regexp" @@ -22,7 +21,7 @@ import ( type McpLoggingSuite struct { BaseMcpSuite klogState klog.State 
- logBuffer bytes.Buffer + logBuffer test.SyncBuffer } func (s *McpLoggingSuite) SetupTest() { From 9a33b1064cb71de9538ace1f7a5e51811f02c7b3 Mon Sep 17 00:00:00 2001 From: Nader Ziada Date: Mon, 23 Feb 2026 05:19:05 -0500 Subject: [PATCH 3/3] feat(validation): add pre-execution validation layer (#764) * feat(validation): add pre-execution validation layer Add validation middleware that catches errors before they reach the Kubernetes API. Signed-off-by: Nader Ziada * simplify config and merge into AccessControlRoundTripper Signed-off-by: Nader Ziada * remove redundant ResourceValidator and simplify validation; clean up unused func and fields Signed-off-by: Nader Ziada --------- Signed-off-by: Nader Ziada --- README.md | 134 +++++++-------- docs/VALIDATION.md | 120 +++++++++++++ docs/configuration.md | 2 +- pkg/api/config.go | 6 + pkg/api/validation.go | 81 +++++++++ pkg/config/config.go | 9 + pkg/kubernetes/accesscontrol_round_tripper.go | 158 ++++++++++++++++++ .../accesscontrol_round_tripper_test.go | 3 +- pkg/kubernetes/auth.go | 54 ++++++ pkg/kubernetes/kubernetes.go | 19 ++- pkg/kubernetes/rbac_validator.go | 52 ++++++ pkg/kubernetes/rbac_validator_test.go | 155 +++++++++++++++++ pkg/kubernetes/resources.go | 18 +- pkg/kubernetes/schema_validator.go | 130 ++++++++++++++ pkg/kubernetes/validator_registry.go | 41 +++++ pkg/kubernetes/validator_registry_test.go | 49 ++++++ pkg/mcp/mcp.go | 1 + 17 files changed, 940 insertions(+), 92 deletions(-) create mode 100644 docs/VALIDATION.md create mode 100644 pkg/api/validation.go create mode 100644 pkg/kubernetes/auth.go create mode 100644 pkg/kubernetes/rbac_validator.go create mode 100644 pkg/kubernetes/rbac_validator_test.go create mode 100644 pkg/kubernetes/schema_validator.go create mode 100644 pkg/kubernetes/validator_registry.go create mode 100644 pkg/kubernetes/validator_registry_test.go diff --git a/README.md b/README.md index cb4a64630..8623888f8 100644 --- a/README.md +++ b/README.md @@ -251,10 +251,10 @@ The
following sets of tools are available (toolsets marked with ✓ in the Defau | Toolset | Description | Default | |----------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------| +| kiali | Most common tools for managing Kiali, check the [Kiali documentation](https://github.com/containers/kubernetes-mcp-server/blob/main/docs/KIALI.md) for more details. | | | config | View and manage the current local Kubernetes configuration (kubeconfig) | ✓ | | core | Most common tools for Kubernetes management (Pods, Generic Resources, Events, etc.) | ✓ | | kcp | Manage kcp workspaces and multi-tenancy features | | -| kiali | Most common tools for managing Kiali, check the [Kiali documentation](https://github.com/containers/kubernetes-mcp-server/blob/main/docs/KIALI.md) for more details. | | | kubevirt | KubeVirt virtual machine management tools | | | helm | Tools for managing Helm charts and releases | ✓ | @@ -268,6 +268,72 @@ In case multi-cluster support is enabled (default) and you have access to multip
+kiali + +- **kiali_mesh_graph** - Returns the topology of a specific namespaces, health, status of the mesh and namespaces. Includes a mesh health summary overview with aggregated counts of healthy, degraded, and failing apps, workloads, and services. Use this for high-level overviews + - `graphType` (`string`) - Optional type of graph to return: 'versionedApp', 'app', 'service', 'workload', 'mesh' + - `namespace` (`string`) - Optional single namespace to include in the graph (alternative to namespaces) + - `namespaces` (`string`) - Optional comma-separated list of namespaces to include in the graph + - `rateInterval` (`string`) - Optional rate interval for fetching (e.g., '10m', '5m', '1h'). + +- **kiali_manage_istio_config_read** - Lists or gets Istio configuration objects (Gateways, VirtualServices, etc.) + - `action` (`string`) **(required)** - Action to perform: list or get + - `group` (`string`) - API group of the Istio object (e.g., 'networking.istio.io', 'gateway.networking.k8s.io') + - `kind` (`string`) - Kind of the Istio object (e.g., 'DestinationRule', 'VirtualService', 'HTTPRoute', 'Gateway') + - `name` (`string`) - Name of the Istio object + - `namespace` (`string`) - Namespace containing the Istio object + - `version` (`string`) - API version of the Istio object (e.g., 'v1', 'v1beta1') + +- **kiali_manage_istio_config** - Creates, patches, or deletes Istio configuration objects (Gateways, VirtualServices, etc.) 
+ - `action` (`string`) **(required)** - Action to perform: create, patch, or delete + - `group` (`string`) - API group of the Istio object (e.g., 'networking.istio.io', 'gateway.networking.k8s.io') + - `json_data` (`string`) - JSON data to apply or create the object + - `kind` (`string`) - Kind of the Istio object (e.g., 'DestinationRule', 'VirtualService', 'HTTPRoute', 'Gateway') + - `name` (`string`) - Name of the Istio object + - `namespace` (`string`) - Namespace containing the Istio object + - `version` (`string`) - API version of the Istio object (e.g., 'v1', 'v1beta1') + +- **kiali_get_resource_details** - Gets lists or detailed info for Kubernetes resources (services, workloads) within the mesh + - `namespaces` (`string`) - Comma-separated list of namespaces to get services from (e.g. 'bookinfo' or 'bookinfo,default'). If not provided, will list services from all accessible namespaces + - `resource_name` (`string`) - Name of the resource to get details for (optional string - if provided, gets details; if empty, lists all). + - `resource_type` (`string`) - Type of resource to get details for (service, workload) + +- **kiali_get_metrics** - Gets lists or detailed info for Kubernetes resources (services, workloads) within the mesh + - `byLabels` (`string`) - Comma-separated list of labels to group metrics by (e.g., 'source_workload,destination_service'). Optional + - `direction` (`string`) - Traffic direction: 'inbound' or 'outbound'. Optional, defaults to 'outbound' + - `duration` (`string`) - Time range to get metrics for (optional string - if provided, gets metrics (e.g., '1m', '5m', '1h'); if empty, get default 30m). + - `namespace` (`string`) **(required)** - Namespace to get resources from + - `quantiles` (`string`) - Comma-separated list of quantiles for histogram metrics (e.g., '0.5,0.95,0.99'). Optional + - `rateInterval` (`string`) - Rate interval for metrics (e.g., '1m', '5m'). 
Optional, defaults to '10m' + - `reporter` (`string`) - Metrics reporter: 'source', 'destination', or 'both'. Optional, defaults to 'source' + - `requestProtocol` (`string`) - Filter by request protocol (e.g., 'http', 'grpc', 'tcp'). Optional + - `resource_name` (`string`) **(required)** - Name of the resource to get details for (optional string - if provided, gets details; if empty, lists all). + - `resource_type` (`string`) **(required)** - Type of resource to get details for (service, workload) + - `step` (`string`) - Step between data points in seconds (e.g., '15'). Optional, defaults to 15 seconds + +- **kiali_workload_logs** - Get logs for a specific workload's pods in a namespace. Only requires namespace and workload name - automatically discovers pods and containers. Optionally filter by container name, time range, and other parameters. Container is auto-detected if not specified. + - `container` (`string`) - Optional container name to filter logs. If not provided, automatically detects and uses the main application container (excludes istio-proxy and istio-init) + - `namespace` (`string`) **(required)** - Namespace containing the workload + - `since` (`string`) - Time duration to fetch logs from (e.g., '5m', '1h', '30s'). If not provided, returns recent logs + - `tail` (`integer`) - Number of lines to retrieve from the end of logs (default: 100) + - `workload` (`string`) **(required)** - Name of the workload to get logs for + +- **kiali_get_traces** - Gets traces for a specific resource (app, service, workload) in a namespace, or gets detailed information for a specific trace by its ID. If traceId is provided, it returns detailed trace information and other parameters are not required. 
+ - `clusterName` (`string`) - Cluster name for multi-cluster environments (optional, only used when traceId is not provided) + - `endMicros` (`string`) - End time for traces in microseconds since epoch (optional, defaults to 10 minutes after startMicros if not provided, only used when traceId is not provided) + - `limit` (`integer`) - Maximum number of traces to return (default: 100, only used when traceId is not provided) + - `minDuration` (`integer`) - Minimum trace duration in microseconds (optional, only used when traceId is not provided) + - `namespace` (`string`) - Namespace to get resources from. Required if traceId is not provided. + - `resource_name` (`string`) - Name of the resource to get traces for. Required if traceId is not provided. + - `resource_type` (`string`) - Type of resource to get traces for (app, service, workload). Required if traceId is not provided. + - `startMicros` (`string`) - Start time for traces in microseconds since epoch (optional, defaults to 10 minutes before current time if not provided, only used when traceId is not provided) + - `tags` (`string`) - JSON string of tags to filter traces (optional, only used when traceId is not provided) + - `traceId` (`string`) - Unique identifier of the trace to retrieve detailed information for. If provided, this will return detailed trace information and other parameters (resource_type, namespace, resource_name) are not required. + +
+ +
+ config - **configuration_contexts_list** - List all available context names and associated server urls from the kubeconfig file @@ -393,72 +459,6 @@ In case multi-cluster support is enabled (default) and you have access to multip
-kiali - -- **kiali_mesh_graph** - Returns the topology of a specific namespaces, health, status of the mesh and namespaces. Includes a mesh health summary overview with aggregated counts of healthy, degraded, and failing apps, workloads, and services. Use this for high-level overviews - - `graphType` (`string`) - Optional type of graph to return: 'versionedApp', 'app', 'service', 'workload', 'mesh' - - `namespace` (`string`) - Optional single namespace to include in the graph (alternative to namespaces) - - `namespaces` (`string`) - Optional comma-separated list of namespaces to include in the graph - - `rateInterval` (`string`) - Optional rate interval for fetching (e.g., '10m', '5m', '1h'). - -- **kiali_manage_istio_config_read** - Lists or gets Istio configuration objects (Gateways, VirtualServices, etc.) - - `action` (`string`) **(required)** - Action to perform: list or get - - `group` (`string`) - API group of the Istio object (e.g., 'networking.istio.io', 'gateway.networking.k8s.io') - - `kind` (`string`) - Kind of the Istio object (e.g., 'DestinationRule', 'VirtualService', 'HTTPRoute', 'Gateway') - - `name` (`string`) - Name of the Istio object - - `namespace` (`string`) - Namespace containing the Istio object - - `version` (`string`) - API version of the Istio object (e.g., 'v1', 'v1beta1') - -- **kiali_manage_istio_config** - Creates, patches, or deletes Istio configuration objects (Gateways, VirtualServices, etc.) 
- - `action` (`string`) **(required)** - Action to perform: create, patch, or delete - - `group` (`string`) - API group of the Istio object (e.g., 'networking.istio.io', 'gateway.networking.k8s.io') - - `json_data` (`string`) - JSON data to apply or create the object - - `kind` (`string`) - Kind of the Istio object (e.g., 'DestinationRule', 'VirtualService', 'HTTPRoute', 'Gateway') - - `name` (`string`) - Name of the Istio object - - `namespace` (`string`) - Namespace containing the Istio object - - `version` (`string`) - API version of the Istio object (e.g., 'v1', 'v1beta1') - -- **kiali_get_resource_details** - Gets lists or detailed info for Kubernetes resources (services, workloads) within the mesh - - `namespaces` (`string`) - Comma-separated list of namespaces to get services from (e.g. 'bookinfo' or 'bookinfo,default'). If not provided, will list services from all accessible namespaces - - `resource_name` (`string`) - Name of the resource to get details for (optional string - if provided, gets details; if empty, lists all). - - `resource_type` (`string`) - Type of resource to get details for (service, workload) - -- **kiali_get_metrics** - Gets lists or detailed info for Kubernetes resources (services, workloads) within the mesh - - `byLabels` (`string`) - Comma-separated list of labels to group metrics by (e.g., 'source_workload,destination_service'). Optional - - `direction` (`string`) - Traffic direction: 'inbound' or 'outbound'. Optional, defaults to 'outbound' - - `duration` (`string`) - Time range to get metrics for (optional string - if provided, gets metrics (e.g., '1m', '5m', '1h'); if empty, get default 30m). - - `namespace` (`string`) **(required)** - Namespace to get resources from - - `quantiles` (`string`) - Comma-separated list of quantiles for histogram metrics (e.g., '0.5,0.95,0.99'). Optional - - `rateInterval` (`string`) - Rate interval for metrics (e.g., '1m', '5m'). 
Optional, defaults to '10m' - - `reporter` (`string`) - Metrics reporter: 'source', 'destination', or 'both'. Optional, defaults to 'source' - - `requestProtocol` (`string`) - Filter by request protocol (e.g., 'http', 'grpc', 'tcp'). Optional - - `resource_name` (`string`) **(required)** - Name of the resource to get details for (optional string - if provided, gets details; if empty, lists all). - - `resource_type` (`string`) **(required)** - Type of resource to get details for (service, workload) - - `step` (`string`) - Step between data points in seconds (e.g., '15'). Optional, defaults to 15 seconds - -- **kiali_workload_logs** - Get logs for a specific workload's pods in a namespace. Only requires namespace and workload name - automatically discovers pods and containers. Optionally filter by container name, time range, and other parameters. Container is auto-detected if not specified. - - `container` (`string`) - Optional container name to filter logs. If not provided, automatically detects and uses the main application container (excludes istio-proxy and istio-init) - - `namespace` (`string`) **(required)** - Namespace containing the workload - - `since` (`string`) - Time duration to fetch logs from (e.g., '5m', '1h', '30s'). If not provided, returns recent logs - - `tail` (`integer`) - Number of lines to retrieve from the end of logs (default: 100) - - `workload` (`string`) **(required)** - Name of the workload to get logs for - -- **kiali_get_traces** - Gets traces for a specific resource (app, service, workload) in a namespace, or gets detailed information for a specific trace by its ID. If traceId is provided, it returns detailed trace information and other parameters are not required. 
- - `clusterName` (`string`) - Cluster name for multi-cluster environments (optional, only used when traceId is not provided) - - `endMicros` (`string`) - End time for traces in microseconds since epoch (optional, defaults to 10 minutes after startMicros if not provided, only used when traceId is not provided) - - `limit` (`integer`) - Maximum number of traces to return (default: 100, only used when traceId is not provided) - - `minDuration` (`integer`) - Minimum trace duration in microseconds (optional, only used when traceId is not provided) - - `namespace` (`string`) - Namespace to get resources from. Required if traceId is not provided. - - `resource_name` (`string`) - Name of the resource to get traces for. Required if traceId is not provided. - - `resource_type` (`string`) - Type of resource to get traces for (app, service, workload). Required if traceId is not provided. - - `startMicros` (`string`) - Start time for traces in microseconds since epoch (optional, defaults to 10 minutes before current time if not provided, only used when traceId is not provided) - - `tags` (`string`) - JSON string of tags to filter traces (optional, only used when traceId is not provided) - - `traceId` (`string`) - Unique identifier of the trace to retrieve detailed information for. If provided, this will return detailed trace information and other parameters (resource_type, namespace, resource_name) are not required. - -
- -
- kubevirt - **vm_create** - Create a VirtualMachine in the cluster with the specified configuration, automatically resolving instance types, preferences, and container disk images. VM will be created in Halted state by default; use autostart parameter to start it immediately. diff --git a/docs/VALIDATION.md b/docs/VALIDATION.md new file mode 100644 index 000000000..003c289cc --- /dev/null +++ b/docs/VALIDATION.md @@ -0,0 +1,120 @@ +# Pre-Execution Validation + +The kubernetes-mcp-server includes a validation layer that catches errors before they reach the Kubernetes API. This prevents AI hallucinations (like typos in resource names) and permission issues from causing confusing failures. + +## Why Validation? + +When an AI assistant makes a Kubernetes API call with errors, the raw Kubernetes error messages can be cryptic: + +``` +the server doesn't have a resource type "Deploymnt" +``` + +With validation enabled, you get clearer feedback: + +``` +Resource apps/v1/Deploymnt does not exist in the cluster +``` + +The validation layer catches these types of issues: + +1. **Resource Existence** - Catches typos like "Deploymnt" instead of "Deployment" (checked in access control) +2. **Schema Validation** - Catches invalid fields like "spec.replcias" instead of "spec.replicas" +3. **RBAC Validation** - Pre-checks permissions before attempting operations + +## Configuration + +Validation is **disabled by default**. Schema and RBAC validators run together when enabled. Resource existence is always checked as part of access control. + +```toml +# Enable all validation (default: false) +validation_enabled = true +``` + +### Configuration Reference + +| TOML Field | Default | Description | +|------------|---------|-------------| +| `validation_enabled` | `false` | Enable/disable all validators | + +**Note:** The schema validator caches the OpenAPI schema for 15 minutes internally. 
+ +## How It Works + +### Validation Flow + +Validation happens at the HTTP RoundTripper level, intercepting all Kubernetes API calls: + +``` +MCP Tool Call → Kubernetes Client → HTTP RoundTripper → Kubernetes API + ↓ + Access Control + - Check deny list + - Check resource exists + ↓ + Schema Validator (if enabled) + "Are the fields valid?" + ↓ + RBAC Validator (if enabled) + "Does the user have permission?" + ↓ + Forward to K8s API +``` + +This HTTP-layer approach ensures **all** Kubernetes API calls are validated, including those from plugins (KubeVirt, Kiali, Helm, etc.) - not just the core tools. + +If any validator fails, the request is rejected with a clear error message before reaching the Kubernetes API. + +### 1. Resource Existence (Access Control) + +The access control layer validates that the requested resource type exists in the cluster. This check runs regardless of whether validation is enabled. + +**What it catches:** +- Typos in Kind names: "Deploymnt" → should be "Deployment" +- Wrong API versions: "apps/v2" → should be "apps/v1" +- Non-existent custom resources + +**Example error:** +``` +RESOURCE_NOT_FOUND: Resource deployments.apps does not exist in the cluster +``` + +### 2. Schema Validation + +Validates resource manifests against the cluster's OpenAPI schema for create/update operations. + +**What it catches:** +- Invalid field names: "spec.replcias" → should be "spec.replicas" +- Wrong field types: string where integer expected +- Missing required fields + +**Example error:** +``` +INVALID_FIELD: unknown field "spec.replcias" +``` + +**Note:** Schema validation uses kubectl's validation library and caches the OpenAPI schema for 15 minutes. + +### 3. RBAC Validation + +Pre-checks permissions using Kubernetes `SelfSubjectAccessReview` before attempting operations. 
+ +**What it catches:** +- Missing permissions: can't create Deployments in namespace X +- Cluster-scoped vs namespace-scoped mismatches +- Read-only access attempting writes + +**Example error:** +``` +PERMISSION_DENIED: Cannot create deployments.apps in namespace "production" +``` + +**Note:** RBAC validation uses the same credentials as the actual operation - either the server's service account or the user's token (when OAuth is enabled). + +## Error Codes + +| Code | Description | +|------|-------------| +| `RESOURCE_NOT_FOUND` | The requested resource type doesn't exist in the cluster | +| `INVALID_FIELD` | A field in the manifest doesn't exist or has the wrong type | +| `PERMISSION_DENIED` | RBAC denies the requested operation | diff --git a/docs/configuration.md b/docs/configuration.md index 1c24e5608..7e9feb18f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -184,10 +184,10 @@ Toolsets group related tools together. Enable only the toolsets you need to redu | Toolset | Description | Default | |----------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------| +| kiali | Most common tools for managing Kiali, check the [Kiali documentation](https://github.com/containers/kubernetes-mcp-server/blob/main/docs/KIALI.md) for more details. | | | config | View and manage the current local Kubernetes configuration (kubeconfig) | ✓ | | core | Most common tools for Kubernetes management (Pods, Generic Resources, Events, etc.) | ✓ | | kcp | Manage kcp workspaces and multi-tenancy features | | -| kiali | Most common tools for managing Kiali, check the [Kiali documentation](https://github.com/containers/kubernetes-mcp-server/blob/main/docs/KIALI.md) for more details. 
| | | kubevirt | KubeVirt virtual machine management tools | | | helm | Tools for managing Helm charts and releases | ✓ | diff --git a/pkg/api/config.go b/pkg/api/config.go index 85c095cd8..929a6f3e6 100644 --- a/pkg/api/config.go +++ b/pkg/api/config.go @@ -53,10 +53,16 @@ type StsConfigProvider interface { GetStsScopes() []string } +// ValidationEnabledProvider provides access to validation enabled setting. +type ValidationEnabledProvider interface { + IsValidationEnabled() bool +} + type BaseConfig interface { AuthProvider ClusterProvider DeniedResourcesProvider ExtendedConfigProvider StsConfigProvider + ValidationEnabledProvider } diff --git a/pkg/api/validation.go b/pkg/api/validation.go new file mode 100644 index 000000000..36f1dcedd --- /dev/null +++ b/pkg/api/validation.go @@ -0,0 +1,81 @@ +package api + +import ( + "context" + "fmt" + "strings" + + "k8s.io/apimachinery/pkg/runtime/schema" +) + +// HTTPValidationRequest contains info extracted from an HTTP request for validation. +type HTTPValidationRequest struct { + GVR *schema.GroupVersionResource + GVK *schema.GroupVersionKind + HTTPMethod string // GET, POST, PUT, DELETE, PATCH + Verb string // get, list, create, update, delete, patch + Namespace string + ResourceName string + Body []byte // For create/update validation + Path string +} + +// HTTPValidator validates HTTP requests before they reach the K8s API server. +type HTTPValidator interface { + Validate(ctx context.Context, req *HTTPValidationRequest) error + Name() string +} + +// ValidationErrorCode categorizes validation failures. +type ValidationErrorCode string + +const ( + ErrorCodeResourceNotFound ValidationErrorCode = "RESOURCE_NOT_FOUND" + ErrorCodeInvalidField ValidationErrorCode = "INVALID_FIELD" + ErrorCodePermissionDenied ValidationErrorCode = "PERMISSION_DENIED" +) + +// ValidationError provides AI-friendly error information for validation failures. 
+type ValidationError struct { + Code ValidationErrorCode + Message string + Field string // optional, for field-level errors +} + +// Error implements the error interface. +func (e *ValidationError) Error() string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Validation Error [%s]: %s", e.Code, e.Message)) + + if e.Field != "" { + sb.WriteString(fmt.Sprintf("\n Field: %s", e.Field)) + } + + return sb.String() +} + +// NewPermissionDeniedError creates an error for RBAC permission failures. +func NewPermissionDeniedError(verb, resource, namespace string) *ValidationError { + var msg string + if namespace != "" { + msg = fmt.Sprintf("Cannot %s %s in namespace %q", verb, resource, namespace) + } else { + msg = fmt.Sprintf("Cannot %s %s (cluster-scoped)", verb, resource) + } + + return &ValidationError{ + Code: ErrorCodePermissionDenied, + Message: msg, + } +} + +// FormatResourceName creates a human-readable resource identifier from GVR. +func FormatResourceName(gvr *schema.GroupVersionResource) string { + if gvr == nil { + return "unknown" + } + if gvr.Group == "" { + return gvr.Resource + } + return gvr.Resource + "." + gvr.Group +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 668cb6dd2..a4f8e53a1 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -90,6 +90,11 @@ type StaticConfig struct { // These can also be configured via OTEL_* environment variables. Telemetry TelemetryConfig `toml:"telemetry,omitempty"` + // ValidationEnabled enables pre-execution validation of tool calls. + // When enabled, validates resources, schemas, and RBAC before execution. + // Defaults to false. 
+ ValidationEnabled bool `toml:"validation_enabled,omitempty"` + // Internal: parsed provider configs (not exposed to TOML package) parsedClusterProviderConfigs map[string]api.ExtendedConfig // Internal: parsed toolset configs (not exposed to TOML package) @@ -341,3 +346,7 @@ func (c *StaticConfig) GetStsAudience() string { func (c *StaticConfig) GetStsScopes() []string { return c.StsScopes } + +func (c *StaticConfig) IsValidationEnabled() bool { + return c.ValidationEnabled +} diff --git a/pkg/kubernetes/accesscontrol_round_tripper.go b/pkg/kubernetes/accesscontrol_round_tripper.go index 24bc513ee..a97db9adf 100644 --- a/pkg/kubernetes/accesscontrol_round_tripper.go +++ b/pkg/kubernetes/accesscontrol_round_tripper.go @@ -1,19 +1,57 @@ package kubernetes import ( + "bytes" "fmt" + "io" "net/http" "strings" "github.com/containers/kubernetes-mcp-server/pkg/api" "k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/discovery" + authv1client "k8s.io/client-go/kubernetes/typed/authorization/v1" + "k8s.io/klog/v2" ) +// AccessControlRoundTripper intercepts HTTP requests to enforce access control +// and optionally run validators before they reach the Kubernetes API. type AccessControlRoundTripper struct { delegate http.RoundTripper deniedResourcesProvider api.DeniedResourcesProvider restMapperProvider func() meta.RESTMapper + validationEnabled bool + validators []api.HTTPValidator +} + +// AccessControlRoundTripperConfig configures the AccessControlRoundTripper. +type AccessControlRoundTripperConfig struct { + Delegate http.RoundTripper + DeniedResourcesProvider api.DeniedResourcesProvider + RestMapperProvider func() meta.RESTMapper + DiscoveryProvider func() discovery.DiscoveryInterface + AuthClientProvider func() authv1client.AuthorizationV1Interface + ValidationEnabled bool +} + +// NewAccessControlRoundTripper creates a new AccessControlRoundTripper. 
+func NewAccessControlRoundTripper(cfg AccessControlRoundTripperConfig) *AccessControlRoundTripper { + rt := &AccessControlRoundTripper{ + delegate: cfg.Delegate, + deniedResourcesProvider: cfg.DeniedResourcesProvider, + restMapperProvider: cfg.RestMapperProvider, + validationEnabled: cfg.ValidationEnabled, + } + + if cfg.ValidationEnabled { + rt.validators = CreateValidators(ValidatorProviders{ + Discovery: cfg.DiscoveryProvider, + AuthClient: cfg.AuthClientProvider, + }) + } + + return rt } func (rt *AccessControlRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { @@ -33,12 +71,56 @@ func (rt *AccessControlRoundTripper) RoundTrip(req *http.Request) (*http.Respons gvk, err := restMapper.KindFor(gvr) if err != nil { + if meta.IsNoMatchError(err) { + return nil, &api.ValidationError{ + Code: api.ErrorCodeResourceNotFound, + Message: fmt.Sprintf("Resource %s does not exist in the cluster", api.FormatResourceName(&gvr)), + } + } return nil, fmt.Errorf("failed to make request: AccessControlRoundTripper failed to get kind for gvr %v: %w", gvr, err) } if !rt.isAllowed(gvk) { return nil, fmt.Errorf("resource not allowed: %s", gvk.String()) } + // Skip validators if disabled or if this is SelfSubjectAccessReview (used by RBAC validator) + skipValidation := !rt.validationEnabled || (gvr.Group == "authorization.k8s.io" && gvr.Resource == "selfsubjectaccessreviews") + if skipValidation { + return rt.delegate.RoundTrip(req) + } + + namespace, resourceName := parseURLToNamespaceAndName(req.URL.Path) + verb := httpMethodToVerb(req.Method, req.URL.Path) + + validationReq := &api.HTTPValidationRequest{ + GVR: &gvr, + GVK: &gvk, + HTTPMethod: req.Method, + Verb: verb, + Namespace: namespace, + ResourceName: resourceName, + Path: req.URL.Path, + } + + if req.Body != nil && (req.Method == "POST" || req.Method == "PUT" || req.Method == "PATCH") { + body, readErr := io.ReadAll(req.Body) + _ = req.Body.Close() + if readErr != nil { + return nil, fmt.Errorf("failed to 
read request body: %w", readErr) + } + req.Body = io.NopCloser(bytes.NewReader(body)) + validationReq.Body = body + } + + for _, v := range rt.validators { + if validationErr := v.Validate(req.Context(), validationReq); validationErr != nil { + if ve, ok := validationErr.(*api.ValidationError); ok { + klog.V(4).Infof("Validation failed [%s]: %v", v.Name(), ve) + } + return nil, validationErr + } + } + return rt.delegate.RoundTrip(req) } @@ -102,3 +184,79 @@ func parseURLToGVR(path string) (gvr schema.GroupVersionResource, ok bool) { } return gvr, true } + +func parseURLToNamespaceAndName(path string) (namespace, name string) { + parts := strings.Split(strings.Trim(path, "/"), "/") + + for i, part := range parts { + if part == "namespaces" && i+1 < len(parts) { + namespace = parts[i+1] + break + } + } + + resourceIdx := findResourceTypeIndex(parts) + if resourceIdx >= 0 && resourceIdx+1 < len(parts) { + name = parts[resourceIdx+1] + } + + return namespace, name +} + +func findResourceTypeIndex(parts []string) int { + if len(parts) == 0 { + return -1 + } + + switch parts[0] { + case "api": + if len(parts) < 3 { + return -1 + } + if parts[2] == "namespaces" && len(parts) > 4 { + return 4 + } + return 2 + case "apis": + if len(parts) < 4 { + return -1 + } + if parts[3] == "namespaces" && len(parts) > 5 { + return 5 + } + return 3 + } + return -1 +} + +func httpMethodToVerb(method, path string) string { + switch method { + case "GET": + if isCollectionPath(path) { + return "list" + } + return "get" + case "POST": + return "create" + case "PUT": + return "update" + case "PATCH": + return "patch" + case "DELETE": + if isCollectionPath(path) { + return "deletecollection" + } + return "delete" + default: + return strings.ToLower(method) + } +} + +func isCollectionPath(path string) bool { + parts := strings.Split(strings.Trim(path, "/"), "/") + resourceIdx := findResourceTypeIndex(parts) + if resourceIdx < 0 { + return false + } + return resourceIdx == len(parts)-1 +} diff 
--git a/pkg/kubernetes/accesscontrol_round_tripper_test.go b/pkg/kubernetes/accesscontrol_round_tripper_test.go index c8a5de34a..facfa5bbd 100644 --- a/pkg/kubernetes/accesscontrol_round_tripper_test.go +++ b/pkg/kubernetes/accesscontrol_round_tripper_test.go @@ -287,7 +287,8 @@ func (s *AccessControlRoundTripperTestSuite) TestRoundTripForDeniedAPIResources( s.Error(err) s.Nil(resp) s.False(delegateCalled, "Expected delegate not to be called when RESTMapper fails") - s.Contains(err.Error(), "failed to make request") + s.Contains(err.Error(), "RESOURCE_NOT_FOUND") + s.Contains(err.Error(), "does not exist in the cluster") }) } diff --git a/pkg/kubernetes/auth.go b/pkg/kubernetes/auth.go new file mode 100644 index 000000000..311b9285e --- /dev/null +++ b/pkg/kubernetes/auth.go @@ -0,0 +1,54 @@ +package kubernetes + +import ( + "context" + + authv1 "k8s.io/api/authorization/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + authv1client "k8s.io/client-go/kubernetes/typed/authorization/v1" + "k8s.io/klog/v2" +) + +// CanI checks if the current identity can perform verb on resource. +// Uses SelfSubjectAccessReview to pre-check RBAC permissions. 
+func CanI( + ctx context.Context, + authClient authv1client.AuthorizationV1Interface, + gvr *schema.GroupVersionResource, + namespace, resourceName, verb string, +) (bool, error) { + if authClient == nil { + return true, nil + } + + accessReview := &authv1.SelfSubjectAccessReview{ + Spec: authv1.SelfSubjectAccessReviewSpec{ + ResourceAttributes: &authv1.ResourceAttributes{ + Namespace: namespace, + Verb: verb, + Group: gvr.Group, + Version: gvr.Version, + Resource: gvr.Resource, + Name: resourceName, + }, + }, + } + + response, err := authClient.SelfSubjectAccessReviews().Create(ctx, accessReview, metav1.CreateOptions{}) + if err != nil { + return false, err + } + + if klog.V(5).Enabled() { + if response.Status.Allowed { + klog.V(5).Infof("RBAC check: allowed %s on %s/%s in %s", + verb, gvr.Group, gvr.Resource, namespace) + } else { + klog.V(5).Infof("RBAC check: denied %s on %s/%s in %s: %s", + verb, gvr.Group, gvr.Resource, namespace, response.Status.Reason) + } + } + + return response.Status.Allowed, nil +} diff --git a/pkg/kubernetes/kubernetes.go b/pkg/kubernetes/kubernetes.go index 3cd47c3b4..f44bd46de 100644 --- a/pkg/kubernetes/kubernetes.go +++ b/pkg/kubernetes/kubernetes.go @@ -12,6 +12,7 @@ import ( "k8s.io/client-go/dynamic" "k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes/scheme" + authv1client "k8s.io/client-go/kubernetes/typed/authorization/v1" _ "k8s.io/client-go/plugin/pkg/client/auth/oidc" "k8s.io/client-go/rest" "k8s.io/client-go/restmapper" @@ -50,21 +51,25 @@ type Kubernetes struct { var _ api.KubernetesClient = (*Kubernetes)(nil) -func NewKubernetes(config api.BaseConfig, clientCmdConfig clientcmd.ClientConfig, restConfig *rest.Config) (*Kubernetes, error) { +func NewKubernetes(baseConfig api.BaseConfig, clientCmdConfig clientcmd.ClientConfig, restConfig *rest.Config) (*Kubernetes, error) { k := &Kubernetes{ - config: config, + config: baseConfig, clientCmdConfig: clientCmdConfig, restConfig: rest.CopyConfig(restConfig), } if 
k.restConfig.UserAgent == "" { k.restConfig.UserAgent = rest.DefaultKubernetesUserAgent() } + k.restConfig.Wrap(func(original http.RoundTripper) http.RoundTripper { - return &AccessControlRoundTripper{ - delegate: original, - deniedResourcesProvider: config, - restMapperProvider: func() meta.RESTMapper { return k.restMapper }, - } + return NewAccessControlRoundTripper(AccessControlRoundTripperConfig{ + Delegate: original, + DeniedResourcesProvider: baseConfig, + RestMapperProvider: func() meta.RESTMapper { return k.restMapper }, + DiscoveryProvider: func() discovery.DiscoveryInterface { return k.discoveryClient }, + AuthClientProvider: func() authv1client.AuthorizationV1Interface { return k.AuthorizationV1() }, + ValidationEnabled: baseConfig.IsValidationEnabled(), + }) }) k.restConfig.Wrap(func(original http.RoundTripper) http.RoundTripper { return &UserAgentRoundTripper{delegate: original} diff --git a/pkg/kubernetes/rbac_validator.go b/pkg/kubernetes/rbac_validator.go new file mode 100644 index 000000000..450330f5c --- /dev/null +++ b/pkg/kubernetes/rbac_validator.go @@ -0,0 +1,52 @@ +package kubernetes + +import ( + "context" + + "github.com/containers/kubernetes-mcp-server/pkg/api" + authv1client "k8s.io/client-go/kubernetes/typed/authorization/v1" + "k8s.io/klog/v2" +) + +// RBACValidator pre-checks RBAC permissions before execution. +type RBACValidator struct { + authClientProvider func() authv1client.AuthorizationV1Interface +} + +// NewRBACValidator creates a new RBAC validator. 
+func NewRBACValidator(authClientProvider func() authv1client.AuthorizationV1Interface) *RBACValidator { + return &RBACValidator{ + authClientProvider: authClientProvider, + } +} + +func (v *RBACValidator) Name() string { + return "rbac" +} + +func (v *RBACValidator) Validate(ctx context.Context, req *api.HTTPValidationRequest) error { + if req.GVR == nil || req.Verb == "" { + return nil + } + + authClient := v.authClientProvider() + if authClient == nil { + return nil + } + + allowed, err := CanI(ctx, authClient, req.GVR, req.Namespace, req.ResourceName, req.Verb) + if err != nil { + klog.V(4).Infof("RBAC pre-validation failed with error: %v", err) + return nil + } + + if !allowed { + return api.NewPermissionDeniedError( + req.Verb, + api.FormatResourceName(req.GVR), + req.Namespace, + ) + } + + return nil +} diff --git a/pkg/kubernetes/rbac_validator_test.go b/pkg/kubernetes/rbac_validator_test.go new file mode 100644 index 000000000..28328bbc6 --- /dev/null +++ b/pkg/kubernetes/rbac_validator_test.go @@ -0,0 +1,155 @@ +package kubernetes + +import ( + "context" + "errors" + "testing" + + "github.com/containers/kubernetes-mcp-server/pkg/api" + "github.com/stretchr/testify/suite" + authv1 "k8s.io/api/authorization/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + authv1client "k8s.io/client-go/kubernetes/typed/authorization/v1" + "k8s.io/client-go/rest" +) + +type mockSelfSubjectAccessReviewInterface struct { + allowed bool + err error +} + +func (m *mockSelfSubjectAccessReviewInterface) Create(ctx context.Context, review *authv1.SelfSubjectAccessReview, opts metav1.CreateOptions) (*authv1.SelfSubjectAccessReview, error) { + if m.err != nil { + return nil, m.err + } + review.Status.Allowed = m.allowed + return review, nil +} + +type mockAuthorizationV1Interface struct { + authv1client.AuthorizationV1Interface + selfSubjectAccessReview *mockSelfSubjectAccessReviewInterface +} + +func (m *mockAuthorizationV1Interface) 
RESTClient() rest.Interface { + return nil +} + +func (m *mockAuthorizationV1Interface) SelfSubjectAccessReviews() authv1client.SelfSubjectAccessReviewInterface { + return m.selfSubjectAccessReview +} + +type RBACValidatorTestSuite struct { + suite.Suite +} + +func (s *RBACValidatorTestSuite) TestName() { + v := NewRBACValidator(nil) + s.Equal("rbac", v.Name()) +} + +func (s *RBACValidatorTestSuite) TestValidate() { + testCases := []struct { + name string + req *api.HTTPValidationRequest + authClient authv1client.AuthorizationV1Interface + expectError bool + errorCode api.ValidationErrorCode + }{ + { + name: "nil GVR passes validation", + req: &api.HTTPValidationRequest{GVR: nil, Verb: "get"}, + authClient: nil, + expectError: false, + }, + { + name: "empty verb passes validation", + req: &api.HTTPValidationRequest{ + GVR: &schema.GroupVersionResource{Group: "", Version: "v1", Resource: "pods"}, + Verb: "", + }, + authClient: nil, + expectError: false, + }, + { + name: "nil auth client passes validation", + req: &api.HTTPValidationRequest{ + GVR: &schema.GroupVersionResource{Group: "", Version: "v1", Resource: "pods"}, + Verb: "get", + }, + authClient: nil, + expectError: false, + }, + { + name: "allowed action passes validation", + req: &api.HTTPValidationRequest{ + GVR: &schema.GroupVersionResource{Group: "", Version: "v1", Resource: "pods"}, + Verb: "get", + Namespace: "default", + }, + authClient: &mockAuthorizationV1Interface{ + selfSubjectAccessReview: &mockSelfSubjectAccessReviewInterface{allowed: true}, + }, + expectError: false, + }, + { + name: "denied action fails validation", + req: &api.HTTPValidationRequest{ + GVR: &schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}, + Verb: "delete", + Namespace: "kube-system", + }, + authClient: &mockAuthorizationV1Interface{ + selfSubjectAccessReview: &mockSelfSubjectAccessReviewInterface{allowed: false}, + }, + expectError: true, + errorCode: api.ErrorCodePermissionDenied, + }, + { + name: 
"auth client error passes validation", + req: &api.HTTPValidationRequest{ + GVR: &schema.GroupVersionResource{Group: "", Version: "v1", Resource: "pods"}, + Verb: "get", + Namespace: "default", + }, + authClient: &mockAuthorizationV1Interface{ + selfSubjectAccessReview: &mockSelfSubjectAccessReviewInterface{err: errors.New("connection refused")}, + }, + expectError: false, + }, + { + name: "cluster-scoped resource denied", + req: &api.HTTPValidationRequest{ + GVR: &schema.GroupVersionResource{Group: "", Version: "v1", Resource: "nodes"}, + Verb: "delete", + Namespace: "", + }, + authClient: &mockAuthorizationV1Interface{ + selfSubjectAccessReview: &mockSelfSubjectAccessReviewInterface{allowed: false}, + }, + expectError: true, + errorCode: api.ErrorCodePermissionDenied, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + v := NewRBACValidator(func() authv1client.AuthorizationV1Interface { return tc.authClient }) + err := v.Validate(context.Background(), tc.req) + + if tc.expectError { + s.Error(err) + if ve, ok := err.(*api.ValidationError); ok { + s.Equal(tc.errorCode, ve.Code) + } + } else { + s.NoError(err) + } + }) + } +} + +func TestRBACValidator(t *testing.T) { + suite.Run(t, new(RBACValidatorTestSuite)) +} diff --git a/pkg/kubernetes/resources.go b/pkg/kubernetes/resources.go index 513a51be2..3769361bd 100644 --- a/pkg/kubernetes/resources.go +++ b/pkg/kubernetes/resources.go @@ -11,7 +11,6 @@ import ( "github.com/containers/kubernetes-mcp-server/pkg/api" "github.com/containers/kubernetes-mcp-server/pkg/version" - authv1 "k8s.io/api/authorization/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" metav1beta1 "k8s.io/apimachinery/pkg/apis/meta/v1beta1" @@ -230,19 +229,6 @@ func (c *Core) supportsGroupVersion(groupVersion string) bool { } func (c *Core) canIUse(ctx context.Context, gvr *schema.GroupVersionResource, namespace, verb string) bool { - accessReviews := 
c.AuthorizationV1().SelfSubjectAccessReviews() - response, err := accessReviews.Create(ctx, &authv1.SelfSubjectAccessReview{ - Spec: authv1.SelfSubjectAccessReviewSpec{ResourceAttributes: &authv1.ResourceAttributes{ - Namespace: namespace, - Verb: verb, - Group: gvr.Group, - Version: gvr.Version, - Resource: gvr.Resource, - }}, - }, metav1.CreateOptions{}) - if err != nil { - // TODO: maybe return the error too - return false - } - return response.Status.Allowed + allowed, _ := CanI(ctx, c.AuthorizationV1(), gvr, namespace, "", verb) + return allowed } diff --git a/pkg/kubernetes/schema_validator.go b/pkg/kubernetes/schema_validator.go new file mode 100644 index 000000000..fabd48676 --- /dev/null +++ b/pkg/kubernetes/schema_validator.go @@ -0,0 +1,130 @@ +package kubernetes + +import ( + "context" + "strings" + "sync" + "time" + + "github.com/containers/kubernetes-mcp-server/pkg/api" + "k8s.io/client-go/discovery" + "k8s.io/klog/v2" + kubectlopenapi "k8s.io/kubectl/pkg/util/openapi" + kubectlvalidation "k8s.io/kubectl/pkg/validation" +) + +const schemaCacheTTL = 15 * time.Minute + +// SchemaValidator validates resource manifests against the OpenAPI schema. +type SchemaValidator struct { + discoveryClientProvider func() discovery.DiscoveryInterface + kubectlValidator kubectlvalidation.Schema + validatorMu sync.Mutex + validatorCachedAt time.Time +} + +// NewSchemaValidator creates a new schema validator. 
+func NewSchemaValidator(discoveryClientProvider func() discovery.DiscoveryInterface) *SchemaValidator { + return &SchemaValidator{ + discoveryClientProvider: discoveryClientProvider, + } +} + +func (v *SchemaValidator) Name() string { + return "schema" +} + +func (v *SchemaValidator) Validate(ctx context.Context, req *api.HTTPValidationRequest) error { + if req.GVK == nil || len(req.Body) == 0 { + return nil + } + + // Only validate for create/update operations (exclude patch as partial bodies cause false positives) + if req.Verb != "create" && req.Verb != "update" { + return nil + } + + validator, err := v.getValidator() + if err != nil { + klog.V(4).Infof("Failed to get schema validator: %v", err) + return nil + } + + if validator == nil { + return nil + } + + err = validator.ValidateBytes(req.Body) + if err != nil { + // Check if this is a parsing error (e.g., binary data that can't be parsed as YAML) + // In that case, skip validation rather than blocking the request + errMsg := err.Error() + if strings.Contains(errMsg, "yaml:") || strings.Contains(errMsg, "json:") { + klog.V(4).Infof("Schema validation skipped due to parsing error: %v", err) + return nil + } + return convertKubectlValidationError(err) + } + + return nil +} + +// openAPIResourcesAdapter adapts CachedOpenAPIParser to OpenAPIResourcesGetter interface. 
+type openAPIResourcesAdapter struct { + parser *kubectlopenapi.CachedOpenAPIParser +} + +func (a *openAPIResourcesAdapter) OpenAPISchema() (kubectlopenapi.Resources, error) { + return a.parser.Parse() +} + +func (v *SchemaValidator) getValidator() (kubectlvalidation.Schema, error) { + v.validatorMu.Lock() + defer v.validatorMu.Unlock() + + if v.kubectlValidator != nil && time.Since(v.validatorCachedAt) <= schemaCacheTTL { + return v.kubectlValidator, nil + } + + discoveryClient := v.discoveryClientProvider() + if discoveryClient == nil { + return nil, nil + } + + openAPIClient, ok := discoveryClient.(discovery.OpenAPISchemaInterface) + if !ok { + klog.V(4).Infof("Discovery client does not support OpenAPI schema") + return nil, nil + } + + parser := kubectlopenapi.NewOpenAPIParser(openAPIClient) + adapter := &openAPIResourcesAdapter{parser: parser} + + v.kubectlValidator = kubectlvalidation.NewSchemaValidation(adapter) + v.validatorCachedAt = time.Now() + + return v.kubectlValidator, nil +} + +func convertKubectlValidationError(err error) *api.ValidationError { + if err == nil { + return nil + } + + errMsg := err.Error() + + var field string + if strings.Contains(errMsg, "unknown field") { + if start := strings.Index(errMsg, "\""); start != -1 { + if end := strings.Index(errMsg[start+1:], "\""); end != -1 { + field = errMsg[start+1 : start+1+end] + } + } + } + + return &api.ValidationError{ + Code: api.ErrorCodeInvalidField, + Message: errMsg, + Field: field, + } +} diff --git a/pkg/kubernetes/validator_registry.go b/pkg/kubernetes/validator_registry.go new file mode 100644 index 000000000..a6457c69e --- /dev/null +++ b/pkg/kubernetes/validator_registry.go @@ -0,0 +1,41 @@ +package kubernetes + +import ( + "github.com/containers/kubernetes-mcp-server/pkg/api" + "k8s.io/client-go/discovery" + authv1client "k8s.io/client-go/kubernetes/typed/authorization/v1" +) + +// ValidatorProviders holds the providers needed to create validators. 
+type ValidatorProviders struct {
+	// Discovery supplies the discovery client; init hands it to NewSchemaValidator.
+	Discovery  func() discovery.DiscoveryInterface
+	// AuthClient supplies the authorization client; init hands it to NewRBACValidator.
+	AuthClient func() authv1client.AuthorizationV1Interface
+}
+
+// ValidatorFactory creates a validator given the providers.
+type ValidatorFactory func(ValidatorProviders) api.HTTPValidator
+
+// validatorFactories lists the registered factories in registration order.
+// RegisterValidator appends to it and CreateValidators iterates it; neither
+// takes a lock.
+var validatorFactories []ValidatorFactory
+
+// RegisterValidator adds a validator factory to the registry.
+func RegisterValidator(factory ValidatorFactory) {
+	validatorFactories = append(validatorFactories, factory)
+}
+
+// CreateValidators creates all registered validators with the given providers.
+// Factories are invoked in registration order and receive providers as-is;
+// nil provider funcs are not checked here.
+func CreateValidators(providers ValidatorProviders) []api.HTTPValidator {
+	validators := make([]api.HTTPValidator, 0, len(validatorFactories))
+	for _, factory := range validatorFactories {
+		validators = append(validators, factory(providers))
+	}
+	return validators
+}
+
+// init registers the two built-in validators: the schema validator (fed by
+// the Discovery provider) and the RBAC validator (fed by the AuthClient
+// provider).
+func init() {
+	RegisterValidator(func(p ValidatorProviders) api.HTTPValidator {
+		return NewSchemaValidator(p.Discovery)
+	})
+	RegisterValidator(func(p ValidatorProviders) api.HTTPValidator {
+		return NewRBACValidator(p.AuthClient)
+	})
+}
diff --git a/pkg/kubernetes/validator_registry_test.go b/pkg/kubernetes/validator_registry_test.go
new file mode 100644
index 000000000..8fab8b4f1
--- /dev/null
+++ b/pkg/kubernetes/validator_registry_test.go
@@ -0,0 +1,49 @@
+package kubernetes
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/suite"
+	"k8s.io/client-go/discovery"
+	authv1client "k8s.io/client-go/kubernetes/typed/authorization/v1"
+)
+
+// ValidatorRegistryTestSuite groups the tests for the validator registry
+// (CreateValidators and the validators registered by init).
+type ValidatorRegistryTestSuite struct {
+	suite.Suite
+}
+
+// TestCreateValidatorsReturnsRegisteredValidators checks that CreateValidators
+// yields at least two validators and that validators named "schema" and
+// "rbac" are among them.
+func (s *ValidatorRegistryTestSuite) TestCreateValidatorsReturnsRegisteredValidators() {
+	providers := ValidatorProviders{
+		Discovery:  func() discovery.DiscoveryInterface { return nil },
+		AuthClient: func() authv1client.AuthorizationV1Interface { return nil },
+	}
+
+	validators := CreateValidators(providers)
+
+	s.GreaterOrEqual(len(validators), 2, "Expected at least 2 validators (schema, rbac)")
+
+	// Collect the names so the assertions do not depend on registration order.
+	names := make(map[string]bool)
+	for _, v := range validators {
+		names[v.Name()] = true
+	}
+
+	s.True(names["schema"], "Expected schema validator to be registered")
+	s.True(names["rbac"], "Expected rbac validator to be registered")
+}
+
+// TestCreateValidatorsWithNilProviders checks that constructing the validators
+// with nil provider funcs neither panics nor yields an empty slice. It only
+// constructs the validators; it does not call Validate on them.
+func (s *ValidatorRegistryTestSuite) TestCreateValidatorsWithNilProviders() {
+	providers := ValidatorProviders{
+		Discovery:  nil,
+		AuthClient: nil,
+	}
+
+	// Should not panic
+	s.NotPanics(func() {
+		validators := CreateValidators(providers)
+		s.NotEmpty(validators, "Expected validators to be created even with nil providers")
+	})
+}
+
+// TestValidatorRegistry runs ValidatorRegistryTestSuite under go test.
+func TestValidatorRegistry(t *testing.T) {
+	suite.Run(t, new(ValidatorRegistryTestSuite))
+}
diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go
index 8a6bbb54b..9cdd190aa 100644
--- a/pkg/mcp/mcp.go
+++ b/pkg/mcp/mcp.go
@@ -112,6 +112,7 @@ func NewServer(configuration Configuration, targetProvider internalk8s.Provider)
 	s.server.AddReceivingMiddleware(userAgentPropagationMiddleware(version.BinaryName, version.Version))
 	s.server.AddReceivingMiddleware(toolCallLoggingMiddleware)
 	s.server.AddReceivingMiddleware(s.metricsMiddleware())
+
 	err = s.reloadToolsets()
 	if err != nil {
 		return nil, err