diff --git a/CHANGES.md b/CHANGES.md index 1d8c6e7..ccc8232 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,20 @@ Release Notes. +## 0.2.0 + +### Features + +* TLS certificate verification is now enforced for OAP connections. Added `--sw-insecure` flag to opt out (development/self-signed certs only). +* Sensitive fields (`authorization`, `password`, `token`, `secret`) are redacted in `--log-command` output. +* Environment variable references (`${VAR}`) in `--sw-username`/`--sw-password` now log a warning when the variable is not set, preventing silent unauthenticated requests. +* URL scheme validation rejects non-http/https OAP URLs. +* Regex patterns supplied to `list_mqe_metrics` are validated for complexity before compilation. +* Added `--allowed-origins` flag to `sse` and `streamable` transports for CORS origin enforcement. When unset (default), any `Origin` is reflected back so all browser origins work out of the box. When set, only listed origins receive CORS headers; all others get `403 Forbidden`. Use `*` as an entry to send the wildcard header explicitly. +* Increased reliability of core CLI commands through expanded automated test coverage. +* Removed an unused CLI tool and its associated parameter to simplify the interface and avoid confusion. +* Added validation for tool configuration properties, returning clear errors when required values are missing or invalid. + ## 0.1.0 ### Features diff --git a/CLAUDE.md b/CLAUDE.md index dfaf53c..71b52de 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -45,9 +45,15 @@ The SkyWalking OAP URL is resolved in priority order: SSE and HTTP transports always use the configured server URL. -Basic auth is configured via `--sw-username` / `--sw-password` flags. The startup flags support `${ENV_VAR}` syntax to resolve credentials from environment variables (e.g. `--sw-password ${MY_SECRET}`). +Basic auth is configured via `--sw-username` / `--sw-password` flags. The startup flags support `${ENV_VAR}` syntax to resolve credentials from environment variables (e.g. `--sw-password ${MY_SECRET}`). If a referenced env var is not set, a warning is logged and the credential is treated as empty. -Each transport injects the OAP URL and auth into the request context via `WithSkyWalkingURLAndInsecure()` and `WithSkyWalkingAuth()`. Tools extract them downstream using `skywalking-cli`'s `contextkey.BaseURL{}`, `contextkey.Username{}`, and `contextkey.Password{}`. +TLS verification is enforced by default. Use `--sw-insecure` to skip verification (development/self-signed certs only). + +Each transport injects the OAP URL, insecure flag, and auth into the request context via `WithSkyWalkingURLAndInsecure()` and `WithSkyWalkingAuth()`. Tools extract them downstream using `skywalking-cli`'s `contextkey.BaseURL{}`, `contextkey.Insecure{}`, `contextkey.Username{}`, and `contextkey.Password{}`. + +### CORS / CSRF (`internal/swmcp/cors.go`) + +`sse` and `streamable` transports support `--allowed-origins` (comma-separated). When set, requests with an `Origin` header not in the list are rejected with `403 Forbidden`. CORS response headers are set for allowed origins. When the flag is empty (default), all origins are permitted. The middleware is injected via `WithHTTPServer` / `WithStreamableHTTPServer` so the MCP handler is wrapped rather than forked. ### Server Wiring (`internal/swmcp/server.go`) @@ -60,9 +66,19 @@ Each transport injects the OAP URL and auth into the request context via `WithSk ### Communication with SkyWalking OAP - **Most tools** use `skywalking-cli` packages (`pkg/graphql/...`) which communicate via GraphQL -- **MQE tools** use direct HTTP calls to the OAP `/graphql` endpoint +- **MQE tools** use direct HTTP calls to the OAP `/graphql` endpoint via `executeGraphQLWithContext()` in `mqe.go`. The HTTP client reads `contextkey.Insecure{}` to configure TLS and validates the URL scheme (`http`/`https` only) before each request. - **Time handling**: `common.go` provides `BuildDurationWithContext()` and `GetTimeContext()` which fetch the OAP server's time/timezone for accurate duration calculations +### Input Validation (`internal/tools/mqe.go`) + +All MQE tool inputs are validated before use: +- `validateMQETextField`: UTF-8, max length, no control characters — applied to all entity fields +- `validateLayerField`: additionally enforces `^[A-Z0-9_]+$` for `layer` / `dest_layer` +- `validateMQEExpression`: UTF-8, max 2048 chars, no control characters, max nesting depth 12 +- `validateMetricName`: `^[A-Za-z0-9_.:-]+$` pattern, max 128 chars +- `validateRegexComplexity`: parses the regex AST via `regexp/syntax` and rejects patterns with >50 nodes +- `validateURLScheme` (`common.go`): rejects non-http/https OAP URLs before HTTP requests + ## Extending the Server ### Adding a New Tool diff --git a/README.md b/README.md index af91904..d176830 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ Available Commands: stdio Start stdio server streamable Start Streamable server -Flags: +Global Flags: -h, --help help for swmcp --log-command When true, log commands to the log file --log-file string Path to log file @@ -44,8 +44,19 @@ Flags: --sw-url string Specify the OAP URL to connect to (e.g. http://localhost:12800) --sw-username string Username for basic auth to SkyWalking OAP (supports ${ENV_VAR} syntax) --sw-password string Password for basic auth to SkyWalking OAP (supports ${ENV_VAR} syntax) + --sw-insecure Skip TLS certificate verification for OAP connections (use only in development) -v, --version version for swmcp +SSE-specific Flags: + --sse-address string Host and port for the SSE server (default "localhost:8000") + --base-path string Base path for the SSE server + --allowed-origins string Comma-separated list of allowed CORS origins. Empty reflects any origin (open CORS). Use * to send the wildcard header. + +Streamable-specific Flags: + --address string Host and port for the Streamable HTTP server (default "localhost:8000") + --endpoint-path string Endpoint path for the Streamable HTTP server (default "/mcp") + --allowed-origins string Comma-separated list of allowed CORS origins. Empty reflects any origin (open CORS). Use * to send the wildcard header. + Use "swmcp [command] --help" for more information about a command. ``` @@ -61,8 +72,14 @@ bin/swmcp stdio --sw-url http://localhost:12800 --sw-username admin --sw-passwor # with basic auth (password from environment variable) bin/swmcp stdio --sw-url http://localhost:12800 --sw-username admin --sw-password '${SW_PASSWORD}' +# skip TLS verification (development only, e.g. self-signed certs) +bin/swmcp stdio --sw-url https://localhost:12800 --sw-insecure + # or use SSE server bin/swmcp sse --sse-address localhost:8000 --base-path /mcp --sw-url http://localhost:12800 + +# restrict CORS to specific origins (SSE and streamable transports) +bin/swmcp streamable --sw-url http://localhost:12800 --allowed-origins "http://localhost:3000,https://app.example.com" ``` Transport URL behavior: diff --git a/cmd/skywalking-mcp/main.go b/cmd/skywalking-mcp/main.go index 66babdc..1a80d8f 100644 --- a/cmd/skywalking-mcp/main.go +++ b/cmd/skywalking-mcp/main.go @@ -59,6 +59,7 @@ func init() { rootCmd.PersistentFlags().String("sw-url", "", "Specify the OAP URL to connect to (e.g. http://localhost:12800)") rootCmd.PersistentFlags().String("sw-username", "", "Username for basic auth to SkyWalking OAP (supports ${ENV_VAR} syntax)") rootCmd.PersistentFlags().String("sw-password", "", "Password for basic auth to SkyWalking OAP (supports ${ENV_VAR} syntax)") + rootCmd.PersistentFlags().Bool("sw-insecure", false, "Skip TLS certificate verification for OAP connections (use only in development)") rootCmd.PersistentFlags().String("log-level", "info", "Logging level (debug, info, warn, error)") rootCmd.PersistentFlags().Bool("read-only", false, "Restrict the server to read-only operations") rootCmd.PersistentFlags().Bool("log-command", false, "When true, log commands to the log file") @@ -68,6 +69,7 @@ func init() { _ = viper.BindPFlag("url", rootCmd.PersistentFlags().Lookup("sw-url")) _ = viper.BindPFlag("username", rootCmd.PersistentFlags().Lookup("sw-username")) _ = viper.BindPFlag("password", rootCmd.PersistentFlags().Lookup("sw-password")) + _ = viper.BindPFlag("insecure", rootCmd.PersistentFlags().Lookup("sw-insecure")) _ = viper.BindPFlag("log-level", rootCmd.PersistentFlags().Lookup("log-level")) _ = viper.BindPFlag("read-only", rootCmd.PersistentFlags().Lookup("read-only")) _ = viper.BindPFlag("log-command", rootCmd.PersistentFlags().Lookup("log-command")) diff --git a/internal/swmcp/cors.go b/internal/swmcp/cors.go new file mode 100644 index 0000000..8f1b0f9 --- /dev/null +++ b/internal/swmcp/cors.go @@ -0,0 +1,77 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package swmcp + +import ( + "net/http" + "strings" +) + +// corsMiddleware adds CORS response headers and enforces origin validation. +// When allowedOrigins is empty, every request with an Origin header is +// reflected back — i.e., CORS is open and all browser origins work. +// When allowedOrigins is non-empty, only listed origins receive CORS headers; +// requests from any other origin receive 403 Forbidden. Use "*" as an entry +// to explicitly allow all origins via the wildcard header. +func corsMiddleware(allowedOrigins []string, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin != "" { + if len(allowedOrigins) == 0 || isOriginAllowed(origin, allowedOrigins) { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Accept") + w.Header().Set("Vary", "Origin") + } else { + http.Error(w, "forbidden: origin not allowed", http.StatusForbidden) + return + } + } + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + next.ServeHTTP(w, r) + }) +} + +// isOriginAllowed reports whether origin is in the allowed list. +// The wildcard "*" matches any origin. +func isOriginAllowed(origin string, allowed []string) bool { + for _, a := range allowed { + if a == "*" || a == origin { + return true + } + } + return false +} + +// parseAllowedOrigins splits a comma-separated list of origins. +func parseAllowedOrigins(raw string) []string { + if raw == "" { + return nil + } + parts := strings.Split(raw, ",") + result := make([]string, 0, len(parts)) + for _, p := range parts { + if trimmed := strings.TrimSpace(p); trimmed != "" { + result = append(result, trimmed) + } + } + return result +} diff --git a/internal/swmcp/cors_test.go b/internal/swmcp/cors_test.go new file mode 100644 index 0000000..fdaa1db --- /dev/null +++ b/internal/swmcp/cors_test.go @@ -0,0 +1,210 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package swmcp + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +const ( + allowedOrigin = "http://allowed.example.com" + openOrigin = "http://any.example.com" +) + +// sentinelHandler is a simple handler that records whether it was called. +type sentinelHandler struct{ called bool } + +func (s *sentinelHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request) { + s.called = true + w.WriteHeader(http.StatusOK) +} + +func corsRequest(method, origin string, allowedOrigins []string) (*httptest.ResponseRecorder, *sentinelHandler) { + sentinel := &sentinelHandler{} + handler := corsMiddleware(allowedOrigins, sentinel) + + req := httptest.NewRequest(method, "/mcp", http.NoBody) + if origin != "" { + req.Header.Set("Origin", origin) + } + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + return rr, sentinel +} + +// --- empty allowlist (open CORS) --- + +func TestCORSEmptyAllowlistNoOriginHeader(t *testing.T) { + rr, sentinel := corsRequest(http.MethodPost, "", nil) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rr.Code) + } + if !sentinel.called { + t.Fatal("next handler not called") + } + if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "" { + t.Fatalf("ACAO header should be absent without Origin, got %q", got) + } +} + +func TestCORSEmptyAllowlistReflectsAnyOrigin(t *testing.T) { + rr, sentinel := corsRequest(http.MethodPost, openOrigin, nil) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rr.Code) + } + if !sentinel.called { + t.Fatal("next handler not called") + } + if got := rr.Header().Get("Access-Control-Allow-Origin"); got != openOrigin { + t.Fatalf("ACAO = %q, want reflected origin", got) + } + if got := rr.Header().Get("Vary"); got != "Origin" { + t.Fatalf("Vary = %q, want Origin", got) + } +} + +// --- non-empty allowlist --- + +func TestCORSAllowedOriginReceivesHeaders(t *testing.T) { + allowed := []string{allowedOrigin} + rr, sentinel := corsRequest(http.MethodPost, allowedOrigin, allowed) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rr.Code) + } + if !sentinel.called { + t.Fatal("next handler not called") + } + if got := rr.Header().Get("Access-Control-Allow-Origin"); got != allowedOrigin { + t.Fatalf("ACAO = %q, want reflected origin", got) + } +} + +func TestCORSDisallowedOriginRejects(t *testing.T) { + allowed := []string{allowedOrigin} + rr, sentinel := corsRequest(http.MethodPost, "http://attacker.invalid", allowed) + if rr.Code != http.StatusForbidden { + t.Fatalf("status = %d, want 403", rr.Code) + } + if sentinel.called { + t.Fatal("next handler must not be called for disallowed origin") + } + if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "" { + t.Fatalf("ACAO header should be absent for rejected origin, got %q", got) + } +} + +func TestCORSNoOriginHeaderWithAllowlist(t *testing.T) { + allowed := []string{allowedOrigin} + rr, sentinel := corsRequest(http.MethodPost, "", allowed) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200 (non-browser requests without Origin pass through)", rr.Code) + } + if !sentinel.called { + t.Fatal("next handler not called") + } +} + +// --- wildcard entry --- + +func TestCORSWildcardEntryReflectsOrigin(t *testing.T) { + allowed := []string{"*"} + rr, sentinel := corsRequest(http.MethodPost, "http://anything.example.com", allowed) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rr.Code) + } + if !sentinel.called { + t.Fatal("next handler not called") + } + if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "http://anything.example.com" { + t.Fatalf("ACAO = %q, want reflected origin for wildcard entry", got) + } +} + +// --- preflight (OPTIONS) --- + +func TestCORSPreflightAllowedOrigin(t *testing.T) { + allowed := []string{allowedOrigin} + rr, sentinel := corsRequest(http.MethodOptions, allowedOrigin, allowed) + if rr.Code != http.StatusNoContent { + t.Fatalf("status = %d, want 204", rr.Code) + } + if sentinel.called { + t.Fatal("next handler must not be called for preflight") + } + if got := rr.Header().Get("Access-Control-Allow-Origin"); got != allowedOrigin { + t.Fatalf("ACAO = %q, want reflected origin on preflight", got) + } + if got := rr.Header().Get("Access-Control-Allow-Methods"); got == "" { + t.Fatal("Access-Control-Allow-Methods should be set on preflight") + } +} + +func TestCORSPreflightDisallowedOriginRejects(t *testing.T) { + allowed := []string{allowedOrigin} + rr, sentinel := corsRequest(http.MethodOptions, "http://attacker.invalid", allowed) + if rr.Code != http.StatusForbidden { + t.Fatalf("status = %d, want 403", rr.Code) + } + if sentinel.called { + t.Fatal("next handler must not be called") + } +} + +func TestCORSPreflightEmptyAllowlist(t *testing.T) { + rr, sentinel := corsRequest(http.MethodOptions, openOrigin, nil) + if rr.Code != http.StatusNoContent { + t.Fatalf("status = %d, want 204", rr.Code) + } + if sentinel.called { + t.Fatal("next handler must not be called for preflight") + } + if got := rr.Header().Get("Access-Control-Allow-Origin"); got != openOrigin { + t.Fatalf("ACAO = %q, want reflected origin on open-CORS preflight", got) + } +} + +// --- parseAllowedOrigins --- + +func TestParseAllowedOrigins(t *testing.T) { + tests := []struct { + name string + in string + want []string + }{ + {name: "empty", in: "", want: nil}, + {name: "single", in: "http://a.example.com", want: []string{"http://a.example.com"}}, + {name: "multiple", in: "http://a.example.com,https://b.example.com", want: []string{"http://a.example.com", "https://b.example.com"}}, + {name: "trims spaces", in: " http://a.example.com , https://b.example.com ", want: []string{"http://a.example.com", "https://b.example.com"}}, + {name: "wildcard", in: "*", want: []string{"*"}}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := parseAllowedOrigins(tc.in) + if len(got) != len(tc.want) { + t.Fatalf("len = %d, want %d: %v", len(got), len(tc.want), got) + } + for i := range got { + if got[i] != tc.want[i] { + t.Fatalf("[%d] = %q, want %q", i, got[i], tc.want[i]) + } + } + }) + } +} diff --git a/internal/swmcp/server.go b/internal/swmcp/server.go index 0cdba97..e5307d4 100644 --- a/internal/swmcp/server.go +++ b/internal/swmcp/server.go @@ -107,7 +107,12 @@ func resolveEnvVar(value string) string { trimmed := strings.TrimSpace(value) if strings.HasPrefix(trimmed, "${") && strings.HasSuffix(trimmed, "}") { envName := trimmed[2 : len(trimmed)-1] - return os.Getenv(envName) + resolved, ok := os.LookupEnv(envName) + if !ok { + logrus.Warnf("environment variable %q is referenced but not set", envName) + return "" + } + return resolved } return value } @@ -130,7 +135,7 @@ func withConfiguredAuth(ctx context.Context) context.Context { // with SkyWalking settings from the global configuration. func EnhanceStdioContextFunc() server.StdioContextFunc { return func(ctx context.Context) context.Context { - ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), false) + ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), viper.GetBool("insecure")) ctx = withConfiguredAuth(ctx) return ctx } @@ -140,7 +145,7 @@ func EnhanceStdioContextFunc() server.StdioContextFunc { // with SkyWalking settings from the CLI configuration and configured auth. func EnhanceSSEContextFunc() server.SSEContextFunc { return func(ctx context.Context, _ *http.Request) context.Context { - ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), false) + ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), viper.GetBool("insecure")) ctx = withConfiguredAuth(ctx) return ctx } @@ -150,7 +155,7 @@ func EnhanceSSEContextFunc() server.SSEContextFunc { // with SkyWalking settings from the CLI configuration and configured auth. func EnhanceHTTPContextFunc() server.HTTPContextFunc { return func(ctx context.Context, _ *http.Request) context.Context { - ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), false) + ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), viper.GetBool("insecure")) ctx = withConfiguredAuth(ctx) return ctx } diff --git a/internal/swmcp/server_registry_test.go b/internal/swmcp/server_registry_test.go index 55d1384..ebc7073 100644 --- a/internal/swmcp/server_registry_test.go +++ b/internal/swmcp/server_registry_test.go @@ -27,10 +27,11 @@ import ( ) // These registry tests verify that newMCPServer wires up the expected tools, -// prompts, and resources. mcp-go v0.45.0 does not expose a public inventory API -// for MCPServer, so the tests read server internals through a single helper -// layer below. If mcp-go changes its internal field layout, update only the -// helpers in this file rather than spreading reflect/unsafe access across tests. +// prompts, and resources. Prefer MCPServer public APIs where available. As of +// mcp-go v0.45.0 only tools have a public inventory API, so prompt/resource +// assertions go through the helper layer below. If a future mcp-go release +// exposes prompt/resource listing, replace the reflection helper there rather +// than spreading reflect/unsafe access across tests. func TestNewMCPServerRegistersExpectedTools(t *testing.T) { srv := newMCPServer() @@ -60,8 +61,9 @@ func TestNewMCPServerRegistersExpectedTools(t *testing.T) { func TestNewMCPServerRegistersExpectedPrompts(t *testing.T) { srv := newMCPServer() + inventory := inspectServerInventory(srv) - got := sortedPromptNames(srv) + got := sortedPromptNames(inventory.prompts) want := []string{ "analyze-logs", "analyze-performance", @@ -80,8 +82,9 @@ func TestNewMCPServerRegistersExpectedPrompts(t *testing.T) { func TestNewMCPServerRegistersExpectedResources(t *testing.T) { srv := newMCPServer() + inventory := inspectServerInventory(srv) - resources := resourceMap(srv) + resources := inventory.resources got := make([]string, 0, len(resources)) for uri := range resources { got = append(got, uri) @@ -100,7 +103,7 @@ func TestNewMCPServerRegistersExpectedResources(t *testing.T) { func TestPromptMetadataIncludesExpectedArguments(t *testing.T) { srv := newMCPServer() - prompts := promptMap(srv) + prompts := inspectServerInventory(srv).prompts prompt, ok := prompts["generate_duration"] if !ok { @@ -130,7 +133,7 @@ func TestPromptMetadataIncludesExpectedArguments(t *testing.T) { func TestResourceMetadataIncludesExpectedMIMETypes(t *testing.T) { srv := newMCPServer() - resources := resourceMap(srv) + resources := inspectServerInventory(srv).resources tests := []struct { uri string @@ -192,39 +195,10 @@ func TestToolMetadataIncludesExpectedDescriptionsAndSchemas(t *testing.T) { } func toolMap(srv *server.MCPServer) map[string]mcp.Tool { - serverTools := mustReadServerField(testedServerValue(srv), "tools") - result := make(map[string]mcp.Tool, serverTools.Len()) - - iter := serverTools.MapRange() - for iter.Next() { - name := iter.Key().String() - toolValue := copyReflectValue(iter.Value()) - result[name] = toolValue.FieldByName("Tool").Interface().(mcp.Tool) - } - - return result -} - -func promptMap(srv *server.MCPServer) map[string]mcp.Prompt { - serverPrompts := mustReadServerField(testedServerValue(srv), "prompts") - result := make(map[string]mcp.Prompt, serverPrompts.Len()) - - iter := serverPrompts.MapRange() - for iter.Next() { - result[iter.Key().String()] = copyReflectValue(iter.Value()).Interface().(mcp.Prompt) - } - - return result -} - -func resourceMap(srv *server.MCPServer) map[string]mcp.Resource { - serverResources := mustReadServerField(testedServerValue(srv), "resources") - result := make(map[string]mcp.Resource, serverResources.Len()) - - iter := serverResources.MapRange() - for iter.Next() { - resourceField := copyReflectValue(iter.Value()).FieldByName("resource") - result[iter.Key().String()] = readPrivateField(resourceField).Interface().(mcp.Resource) + serverTools := srv.ListTools() + result := make(map[string]mcp.Tool, len(serverTools)) + for name, tool := range serverTools { + result[name] = tool.Tool } return result @@ -240,8 +214,7 @@ func sortedToolNames(srv *server.MCPServer) []string { return names } -func sortedPromptNames(srv *server.MCPServer) []string { - prompts := promptMap(srv) +func sortedPromptNames(prompts map[string]mcp.Prompt) []string { names := make([]string, 0, len(prompts)) for name := range prompts { names = append(names, name) @@ -273,6 +246,44 @@ func mustReadServerField(srv reflect.Value, fieldName string) reflect.Value { return readPrivateField(field) } +type serverInventory struct { + prompts map[string]mcp.Prompt + resources map[string]mcp.Resource +} + +func inspectServerInventory(srv *server.MCPServer) serverInventory { + serverValue := testedServerValue(srv) + return serverInventory{ + prompts: readPromptMap(serverValue), + resources: readResourceMap(serverValue), + } +} + +func readPromptMap(serverValue reflect.Value) map[string]mcp.Prompt { + serverPrompts := mustReadServerField(serverValue, "prompts") + result := make(map[string]mcp.Prompt, serverPrompts.Len()) + + iter := serverPrompts.MapRange() + for iter.Next() { + result[iter.Key().String()] = copyReflectValue(iter.Value()).Interface().(mcp.Prompt) + } + + return result +} + +func readResourceMap(serverValue reflect.Value) map[string]mcp.Resource { + serverResources := mustReadServerField(serverValue, "resources") + result := make(map[string]mcp.Resource, serverResources.Len()) + + iter := serverResources.MapRange() + for iter.Next() { + resourceField := copyReflectValue(iter.Value()).FieldByName("resource") + result[iter.Key().String()] = readPrivateField(resourceField).Interface().(mcp.Resource) + } + + return result +} + func copyReflectValue(v reflect.Value) reflect.Value { cloned := reflect.New(v.Type()).Elem() cloned.Set(v) diff --git a/internal/swmcp/server_test.go b/internal/swmcp/server_test.go index 28056c8..1007e56 100644 --- a/internal/swmcp/server_test.go +++ b/internal/swmcp/server_test.go @@ -169,3 +169,38 @@ func TestEnhanceSSEContextFuncDoesNotUseSWURLHeader(t *testing.T) { t.Fatalf("base URL = %q", gotURL) } } + +func TestInsecureFlagDefaultsToFalse(t *testing.T) { + t.Cleanup(viper.Reset) + + req, _ := http.NewRequest(http.MethodGet, "http://client/events", http.NoBody) + + for name, ctx := range map[string]context.Context{ + "stdio": EnhanceStdioContextFunc()(context.Background()), + "sse": EnhanceSSEContextFunc()(context.Background(), req), + "streamable": EnhanceHTTPContextFunc()(context.Background(), req), + } { + insecure, _ := ctx.Value(contextkey.Insecure{}).(bool) + if insecure { + t.Errorf("%s: contextkey.Insecure{} should default to false", name) + } + } +} + +func TestInsecureFlagPropagatedToContext(t *testing.T) { + t.Cleanup(viper.Reset) + viper.Set("insecure", true) + + req, _ := http.NewRequest(http.MethodGet, "http://client/events", http.NoBody) + + for name, ctx := range map[string]context.Context{ + "stdio": EnhanceStdioContextFunc()(context.Background()), + "sse": EnhanceSSEContextFunc()(context.Background(), req), + "streamable": EnhanceHTTPContextFunc()(context.Background(), req), + } { + insecure, ok := ctx.Value(contextkey.Insecure{}).(bool) + if !ok || !insecure { + t.Errorf("%s: contextkey.Insecure{} should be true when viper insecure=true", name) + } + } +} diff --git a/internal/swmcp/sse.go b/internal/swmcp/sse.go index 14365a9..ae06a07 100644 --- a/internal/swmcp/sse.go +++ b/internal/swmcp/sse.go @@ -55,8 +55,11 @@ func NewSSEServer() *cobra.Command { "The host and port to start the sse server on") sseCmd.Flags().String("base-path", "", "Base path for the sse server") + sseCmd.Flags().String("allowed-origins", "", + "Comma-separated allowed CORS origins. Empty = open (any origin reflected). Use * for wildcard header.") _ = viper.BindPFlag("sse-address", sseCmd.Flags().Lookup("sse-address")) _ = viper.BindPFlag("base-path", sseCmd.Flags().Lookup("base-path")) + _ = viper.BindPFlag("allowed-origins", sseCmd.Flags().Lookup("allowed-origins")) return sseCmd } @@ -71,11 +74,25 @@ func runSSEServer(ctx context.Context, cfg *config.SSEServerConfig) error { return fmt.Errorf("failed to initialize logger: %w", err) } + allowedOrigins := parseAllowedOrigins(viper.GetString("allowed-origins")) + + // sseServer is assigned after NewSSEServer so the CORS handler closure + // can forward to it via the captured pointer variable. + var sseServerRef *server.SSEServer + customSrv := &http.Server{ + Handler: corsMiddleware(allowedOrigins, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sseServerRef.ServeHTTP(w, r) + })), + ReadHeaderTimeout: 10 * time.Second, + } + sseServer := server.NewSSEServer( newMCPServer(), server.WithStaticBasePath(cfg.BasePath), server.WithSSEContextFunc(EnhanceSSEContextFunc()), + server.WithHTTPServer(customSrv), ) + sseServerRef = sseServer ssePath := sseServer.CompleteSsePath() log.Printf("Starting SkyWalking MCP server using SSE transport listening on http://%s%s\n ", cfg.Address, ssePath) diff --git a/internal/swmcp/streamable.go b/internal/swmcp/streamable.go index 0500352..0f85fb1 100644 --- a/internal/swmcp/streamable.go +++ b/internal/swmcp/streamable.go @@ -19,6 +19,8 @@ package swmcp import ( "fmt" + "net/http" + "time" "github.com/mark3labs/mcp-go/server" log "github.com/sirupsen/logrus" @@ -48,21 +50,39 @@ func NewStreamable() *cobra.Command { "The host and port to start the Streamable server on") streamableCmd.Flags().String("endpoint-path", "/mcp", "The path for the streamable-http server") + streamableCmd.Flags().String("allowed-origins", "", + "Comma-separated allowed CORS origins. Empty = open (any origin reflected). Use * for wildcard header.") _ = viper.BindPFlag("address", streamableCmd.Flags().Lookup("address")) _ = viper.BindPFlag("endpoint-path", streamableCmd.Flags().Lookup("endpoint-path")) + _ = viper.BindPFlag("allowed-origins", streamableCmd.Flags().Lookup("allowed-origins")) return streamableCmd } // runStreamableServer starts the Streamable server with the provided configuration. func runStreamableServer(cfg *config.StreamableServerConfig) error { + allowedOrigins := parseAllowedOrigins(viper.GetString("allowed-origins")) + + // httpServer is assigned after NewStreamableHTTPServer so the CORS handler + // closure can forward to it via the captured pointer variable. + var httpServerRef *server.StreamableHTTPServer + customSrv := &http.Server{ + Handler: corsMiddleware(allowedOrigins, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + httpServerRef.ServeHTTP(w, r) + })), + ReadHeaderTimeout: 10 * time.Second, + } + httpServer := server.NewStreamableHTTPServer( newMCPServer(), server.WithStateLess(true), server.WithLogger(log.StandardLogger()), server.WithHTTPContextFunc(EnhanceHTTPContextFunc()), server.WithEndpointPath(viper.GetString("endpoint-path")), + server.WithStreamableHTTPServer(customSrv), ) + httpServerRef = httpServer + log.Infof("streamable HTTP server listening on %s%s\n", cfg.Address, cfg.EndpointPath) if err := httpServer.Start(cfg.Address); err != nil { diff --git a/internal/tools/alarm_test.go b/internal/tools/alarm_test.go new file mode 100644 index 0000000..a23110c --- /dev/null +++ b/internal/tools/alarm_test.go @@ -0,0 +1,99 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "testing" + "time" + + api "skywalking.apache.org/repo/goapi/query" +) + +func TestBuildAlarmQueryCondition(t *testing.T) { + timeCtx := TimeContext{ + NowUTC: time.Date(2026, 3, 31, 12, 0, 0, 0, time.UTC), + Location: time.UTC, + } + + req := &AlarmQueryRequest{ + Scope: "Service", + Keyword: "timeout", + Tags: []AlarmTag{ + {Key: "level", Value: "critical"}, + {Key: "team", Value: "payments"}, + }, + Start: "-2h", + End: "now", + PageNum: 2, + PageSize: 30, + } + + cond := buildAlarmQueryCondition(req, timeCtx) + + if cond.Scope != api.Scope("Service") { + t.Fatalf("scope = %q", cond.Scope) + } + if cond.Keyword != "timeout" { + t.Fatalf("keyword = %q", cond.Keyword) + } + if cond.Duration == nil { + t.Fatal("duration is nil") + } + if cond.Duration.Start != testTimeStart { + t.Fatalf("start = %q", cond.Duration.Start) + } + if cond.Duration.End != testTimeEnd { + t.Fatalf("end = %q", cond.Duration.End) + } + if cond.Paging == nil || cond.Paging.PageNum == nil || *cond.Paging.PageNum != 2 { + t.Fatalf("page num = %v", cond.Paging) + } + if cond.Paging.PageSize != 30 { + t.Fatalf("page size = %d", cond.Paging.PageSize) + } + if len(cond.Tags) != 2 { + t.Fatalf("tags len = %d", len(cond.Tags)) + } + if cond.Tags[0].Key != "level" || cond.Tags[0].Value == nil || *cond.Tags[0].Value != "critical" { + t.Fatalf("first tag = %+v", cond.Tags[0]) + } +} + +func TestBuildAlarmQueryConditionDefaults(t *testing.T) { + timeCtx := TimeContext{ + NowUTC: time.Date(2026, 3, 31, 12, 0, 0, 0, time.UTC), + Location: time.UTC, + } + + cond := buildAlarmQueryCondition(&AlarmQueryRequest{}, timeCtx) + + if cond.Scope != "" { + t.Fatalf("scope = %q", cond.Scope) + } + if cond.Paging == nil || cond.Paging.PageNum == nil || *cond.Paging.PageNum != DefaultPageNum { + t.Fatalf("default page num = %v", cond.Paging) + } + if cond.Paging.PageSize != DefaultPageSize { + t.Fatalf("default page size = %d", cond.Paging.PageSize) + } + if cond.Duration == nil { + t.Fatal("duration is nil") + } + if cond.Duration.End != testTimeEnd { + t.Fatalf("end = %q", cond.Duration.End) + } +} diff --git a/internal/tools/common.go b/internal/tools/common.go index fc89e0a..86b8fcb 100644 --- a/internal/tools/common.go +++ b/internal/tools/common.go @@ -20,6 +20,7 @@ package tools import ( "context" "fmt" + "net/url" "strconv" "strings" "time" @@ -50,6 +51,18 @@ func FinalizeURL(urlStr string) string { return urlStr } +// validateURLScheme ensures the URL uses http or https. +func validateURLScheme(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid OAP URL: %w", err) + } + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("unsupported OAP URL scheme %q: only http and https are allowed", u.Scheme) + } + return nil +} + // FormatTimeByStep formats time according to step granularity func FormatTimeByStep(t time.Time, step api.Step) string { switch step { diff --git a/internal/tools/common_test.go b/internal/tools/common_test.go new file mode 100644 index 0000000..eb60ce1 --- /dev/null +++ b/internal/tools/common_test.go @@ -0,0 +1,140 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "testing" + "time" + + api "skywalking.apache.org/repo/goapi/query" +) + +const ( + // testTimeStart / testTimeEnd are the formatted minute-step values for + // now=2026-03-31T12:00 UTC minus/plus 2 h, used across multiple tests here. + testTimeStart = "2026-03-31 1000" + testTimeEnd = "2026-03-31 1200" +) + +func TestFinalizeURL(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "adds graphql suffix", in: "http://localhost:12800", want: "http://localhost:12800/graphql"}, + {name: "trims trailing slash", in: "http://localhost:12800/", want: "http://localhost:12800/graphql"}, + {name: "keeps existing graphql", in: "http://localhost:12800/graphql", want: "http://localhost:12800/graphql"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := FinalizeURL(tc.in); got != tc.want { + t.Fatalf("FinalizeURL(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + +func TestParseTimezoneOffset(t *testing.T) { + loc, ok := parseTimezoneOffset("+0830") + if !ok { + t.Fatal("expected timezone to parse") + } + if got := loc.String(); got != "+0830" { + t.Fatalf("location name = %q", got) + } + + _, ok = parseTimezoneOffset("UTC") + if ok { + t.Fatal("expected invalid timezone offset to fail") + } +} + +func TestParseDurationWithContextRelativeDuration(t *testing.T) { + now := time.Date(2026, 3, 31, 12, 0, 0, 0, time.UTC) + timeCtx := TimeContext{NowUTC: now, Location: time.UTC} + + got := ParseDurationWithContext("-2h", false, timeCtx) + + if got.Start != testTimeStart { + t.Fatalf("start = %q", got.Start) + } + if got.End != testTimeEnd { + t.Fatalf("end = %q", got.End) + } + if got.Step != api.StepMinute { + t.Fatalf("step = %q", got.Step) + } +} + +func TestParseDurationWithContextLegacyDays(t *testing.T) { + now := time.Date(2026, 3, 31, 12, 0, 0, 0, time.UTC) + timeCtx := TimeContext{NowUTC: now, Location: time.UTC} + + got := ParseDurationWithContext("7d", true, timeCtx) + + if got.Start != "2026-03-24" { + t.Fatalf("start = %q", got.Start) + } + if got.End != "2026-03-31" { + t.Fatalf("end = %q", got.End) + } + if got.Step != api.StepDay { + t.Fatalf("step = %q", got.Step) + } + if got.ColdStage == nil || !*got.ColdStage { + t.Fatal("expected cold stage to be true") + } +} + +func TestBuildPaginationDefaultsAndCustomValues(t *testing.T) { + gotDefault := BuildPagination(0, 0) + if gotDefault.PageNum == nil || *gotDefault.PageNum != DefaultPageNum { + t.Fatalf("default page num = %v", gotDefault.PageNum) + } + if gotDefault.PageSize != DefaultPageSize { + t.Fatalf("default page size = %d", gotDefault.PageSize) + } + + gotCustom := BuildPagination(3, 50) + if gotCustom.PageNum == nil || *gotCustom.PageNum != 3 { + t.Fatalf("custom page num = %v", gotCustom.PageNum) + } + if gotCustom.PageSize != 50 { + t.Fatalf("custom page size = %d", gotCustom.PageSize) + } +} + +func TestBuildDurationWithContextParsesAbsoluteTimes(t *testing.T) { + timeCtx := TimeContext{ + NowUTC: time.Date(2026, 3, 31, 12, 0, 0, 0, time.UTC), + Location: time.FixedZone("+0800", 8*3600), + } + + got := BuildDurationWithContext("2026-03-31 18:00:00", "2026-03-31 20:00:00", "", false, 30, timeCtx) + + if got.Start != testTimeStart { + t.Fatalf("start = %q", got.Start) + } + if got.End != testTimeEnd { + t.Fatalf("end = %q", got.End) + } + if got.Step != api.StepMinute { + t.Fatalf("step = %q", got.Step) + } +} diff --git a/internal/tools/io.go b/internal/tools/io.go index 3eeaee0..7c2d70f 100644 --- a/internal/tools/io.go +++ b/internal/tools/io.go @@ -18,17 +18,31 @@ package tools import ( + "bytes" "io" + "regexp" log "github.com/sirupsen/logrus" ) -// IOLogger is a wrapper around io.Reader and io.Writer that can be used -// to log the data being read and written from the underlying streams +// sensitiveFieldPattern matches JSON fields whose values should be redacted in logs. +var sensitiveFieldPattern = regexp.MustCompile(`(?i)("(?:authorization|password|token|secret)"\s*:\s*")((?:[^"\\]|\\.)*)(")`) //nolint:lll // regex must be on one line + +// redactSensitiveData masks values of sensitive JSON fields before logging. +func redactSensitiveData(data string) string { + return sensitiveFieldPattern.ReplaceAllString(data, `${1}[REDACTED]${3}`) +} + +// IOLogger is a wrapper around io.Reader and io.Writer that logs complete +// newline-delimited JSON-RPC messages. Partial chunks are held in per-direction +// buffers until a newline arrives, ensuring the redaction regex always sees a +// full message and secrets split across read boundaries are never partially logged. type IOLogger struct { - reader io.Reader - writer io.Writer - logger *log.Logger + reader io.Reader + writer io.Writer + logger *log.Logger + readBuf bytes.Buffer + writeBuf bytes.Buffer } // NewIOLogger creates a new IOLogger instance @@ -40,23 +54,43 @@ func NewIOLogger(r io.Reader, w io.Writer, logger *log.Logger) *IOLogger { } } -// Read reads data from the underlying io.Reader and logs it. +// logCompleteLines drains newline-terminated lines from buf, redacts each one, +// and logs it under the given direction label. Any trailing partial line is left +// in buf for the next call. +func (l *IOLogger) logCompleteLines(buf *bytes.Buffer, direction string) { + data := buf.Bytes() + for { + idx := bytes.IndexByte(data, '\n') + if idx < 0 { + break + } + line := bytes.TrimRight(data[:idx], "\r") + l.logger.Infof("[%s]: %s", direction, redactSensitiveData(string(line))) + data = data[idx+1:] + } + buf.Reset() + buf.Write(data) +} + +// Read reads data from the underlying io.Reader and logs complete lines. func (l *IOLogger) Read(p []byte) (n int, err error) { if l.reader == nil { return 0, io.EOF } n, err = l.reader.Read(p) if n > 0 { - l.logger.Infof("[stdin]: received %d bytes: %s", n, string(p[:n])) + l.readBuf.Write(p[:n]) + l.logCompleteLines(&l.readBuf, "stdin") } return n, err } -// Write writes data to the underlying io.Writer and logs it. +// Write writes data to the underlying io.Writer and logs complete lines. func (l *IOLogger) Write(p []byte) (n int, err error) { if l.writer == nil { return 0, io.ErrClosedPipe } - l.logger.Infof("[stdout]: sending %d bytes: %s", len(p), string(p)) + l.writeBuf.Write(p) + l.logCompleteLines(&l.writeBuf, "stdout") return l.writer.Write(p) } diff --git a/internal/tools/mqe.go b/internal/tools/mqe.go index 3a45396..e9f4e5c 100644 --- a/internal/tools/mqe.go +++ b/internal/tools/mqe.go @@ -20,12 +20,14 @@ package tools import ( "bytes" "context" + "crypto/tls" "encoding/base64" "encoding/json" "fmt" "io" "net/http" "regexp" + "regexp/syntax" "time" "unicode" "unicode/utf8" @@ -53,6 +55,9 @@ const ( var metricNamePattern = regexp.MustCompile(`^[A-Za-z0-9_.:-]+$`) +// layerPattern restricts layer values to the SkyWalking enum format (e.g. GENERAL, K8S_SERVICE). +var layerPattern = regexp.MustCompile(`^[A-Z0-9_]+$`) + // GraphQLRequest represents a GraphQL request type GraphQLRequest struct { Query string `json:"query"` @@ -75,10 +80,22 @@ func getContextString(ctx context.Context, key any) string { return "" } +// getContextBool safely extracts a bool value from context. +func getContextBool(ctx context.Context, key any) bool { + if v, ok := ctx.Value(key).(bool); ok { + return v + } + return false +} + // executeGraphQLWithContext executes a GraphQL query using URL and auth from context. func executeGraphQLWithContext(ctx context.Context, query string, variables map[string]interface{}) (*GraphQLResponse, error) { - url := getContextString(ctx, contextkey.BaseURL{}) - url = FinalizeURL(url) + rawURL := getContextString(ctx, contextkey.BaseURL{}) + rawURL = FinalizeURL(rawURL) + + if err := validateURLScheme(rawURL); err != nil { + return nil, err + } reqBody := GraphQLRequest{ Query: query, @@ -90,7 +107,7 @@ func executeGraphQLWithContext(ctx context.Context, query string, variables map[ return nil, fmt.Errorf("failed to marshal GraphQL request: %w", err) } - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, "POST", rawURL, bytes.NewBuffer(jsonData)) if err != nil { return nil, fmt.Errorf("failed to create HTTP request: %w", err) } @@ -105,7 +122,10 @@ func executeGraphQLWithContext(ctx context.Context, query string, variables map[ req.Header.Set("Authorization", auth) } - client := &http.Client{Timeout: 30 * time.Second} + insecure := getContextBool(ctx, contextkey.Insecure{}) + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: insecure} //nolint:gosec // controlled by --sw-insecure operator flag + client := &http.Client{Transport: transport, Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("failed to execute HTTP request: %w", err) @@ -487,12 +507,10 @@ func validateMQEExpressionRequest(req *MQEExpressionRequest) error { for fieldName, value := range map[string]string{ "service_name": req.ServiceName, - "layer": req.Layer, "service_instance_name": req.ServiceInstanceName, "endpoint_name": req.EndpointName, "process_name": req.ProcessName, "dest_service_name": req.DestServiceName, - "dest_layer": req.DestLayer, "dest_service_instance_name": req.DestServiceInstanceName, "dest_endpoint_name": req.DestEndpointName, "dest_process_name": req.DestProcessName, @@ -502,6 +520,13 @@ func validateMQEExpressionRequest(req *MQEExpressionRequest) error { } } + if err := validateLayerField("layer", req.Layer); err != nil { + return err + } + if err := validateLayerField("dest_layer", req.DestLayer); err != nil { + return err + } + return nil } @@ -512,8 +537,43 @@ func validateMQEMetricsListRequest(req *MQEMetricsListRequest) error { if err := validateMQETextField("regex", req.Regex, maxMQERegexLength); err != nil { return err } - if _, err := regexp.Compile(req.Regex); err != nil { - return fmt.Errorf("regex is invalid") + if err := validateRegexComplexity(req.Regex); err != nil { + return err + } + return nil +} + +const maxRegexNodes = 50 + +// validateRegexComplexity rejects patterns with excessive AST node counts. +func validateRegexComplexity(pattern string) error { + re, err := syntax.Parse(pattern, syntax.Perl) + if err != nil { + return fmt.Errorf("regex is invalid: %w", err) + } + if regexNodeCount(re) > maxRegexNodes { + return fmt.Errorf("regex is too complex") + } + return nil +} + +func regexNodeCount(re *syntax.Regexp) int { + count := 1 + for _, sub := range re.Sub { + count += regexNodeCount(sub) + } + return count +} + +func validateLayerField(fieldName, value string) error { + if value == "" { + return nil + } + if err := validateMQETextField(fieldName, value, maxMQEEntityFieldLen); err != nil { + return err + } + if !layerPattern.MatchString(value) { + return fmt.Errorf("%s contains invalid characters: only uppercase letters, digits, and underscores are allowed", fieldName) } return nil } diff --git a/internal/tools/mqe_test.go b/internal/tools/mqe_test.go index 37d08af..2c1a1c9 100644 --- a/internal/tools/mqe_test.go +++ b/internal/tools/mqe_test.go @@ -40,7 +40,7 @@ func TestValidateMQEExpressionRequestRejectsDeeplyNestedExpression(t *testing.T) func TestValidateMQEMetricsListRequestRejectsInvalidRegex(t *testing.T) { err := validateMQEMetricsListRequest(&MQEMetricsListRequest{Regex: "("}) - if err == nil || err.Error() != "regex is invalid" { + if err == nil || !strings.HasPrefix(err.Error(), "regex is invalid") { t.Fatalf("unexpected error: %v", err) } }