From 8b69a9130fe5f32203e4019e506a31c9c0623a5c Mon Sep 17 00:00:00 2001 From: Fine0830 Date: Tue, 31 Mar 2026 16:59:50 +0800 Subject: [PATCH 1/6] add tests --- internal/tools/alarm_test.go | 99 +++++++++++++++++++++++++ internal/tools/common_test.go | 133 ++++++++++++++++++++++++++++++++++ 2 files changed, 232 insertions(+) create mode 100644 internal/tools/alarm_test.go create mode 100644 internal/tools/common_test.go diff --git a/internal/tools/alarm_test.go b/internal/tools/alarm_test.go new file mode 100644 index 0000000..e469b21 --- /dev/null +++ b/internal/tools/alarm_test.go @@ -0,0 +1,99 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "testing" + "time" + + api "skywalking.apache.org/repo/goapi/query" +) + +func TestBuildAlarmQueryCondition(t *testing.T) { + timeCtx := TimeContext{ + NowUTC: time.Date(2026, 3, 31, 12, 0, 0, 0, time.UTC), + Location: time.UTC, + } + + req := &AlarmQueryRequest{ + Scope: "Service", + Keyword: "timeout", + Tags: []AlarmTag{ + {Key: "level", Value: "critical"}, + {Key: "team", Value: "payments"}, + }, + Start: "-2h", + End: "now", + PageNum: 2, + PageSize: 30, + } + + cond := buildAlarmQueryCondition(req, timeCtx) + + if cond.Scope != api.Scope("Service") { + t.Fatalf("scope = %q", cond.Scope) + } + if cond.Keyword != "timeout" { + t.Fatalf("keyword = %q", cond.Keyword) + } + if cond.Duration == nil { + t.Fatal("duration is nil") + } + if cond.Duration.Start != "2026-03-31 1000" { + t.Fatalf("start = %q", cond.Duration.Start) + } + if cond.Duration.End != "2026-03-31 1200" { + t.Fatalf("end = %q", cond.Duration.End) + } + if cond.Paging == nil || cond.Paging.PageNum == nil || *cond.Paging.PageNum != 2 { + t.Fatalf("page num = %v", cond.Paging) + } + if cond.Paging.PageSize != 30 { + t.Fatalf("page size = %d", cond.Paging.PageSize) + } + if len(cond.Tags) != 2 { + t.Fatalf("tags len = %d", len(cond.Tags)) + } + if cond.Tags[0].Key != "level" || cond.Tags[0].Value == nil || *cond.Tags[0].Value != "critical" { + t.Fatalf("first tag = %+v", cond.Tags[0]) + } +} + +func TestBuildAlarmQueryConditionDefaults(t *testing.T) { + timeCtx := TimeContext{ + NowUTC: time.Date(2026, 3, 31, 12, 0, 0, 0, time.UTC), + Location: time.UTC, + } + + cond := buildAlarmQueryCondition(&AlarmQueryRequest{}, timeCtx) + + if cond.Scope != "" { + t.Fatalf("scope = %q", cond.Scope) + } + if cond.Paging == nil || cond.Paging.PageNum == nil || *cond.Paging.PageNum != DefaultPageNum { + t.Fatalf("default page num = %v", cond.Paging) + } + if cond.Paging.PageSize != DefaultPageSize { + t.Fatalf("default page size = %d", cond.Paging.PageSize) + } + if cond.Duration == nil { + t.Fatal("duration is nil") + } + if cond.Duration.End != "2026-03-31 1200" { + t.Fatalf("end = %q", cond.Duration.End) + } +} diff --git a/internal/tools/common_test.go b/internal/tools/common_test.go new file mode 100644 index 0000000..e5323a8 --- /dev/null +++ b/internal/tools/common_test.go @@ -0,0 +1,133 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "testing" + "time" + + api "skywalking.apache.org/repo/goapi/query" +) + +func TestFinalizeURL(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "adds graphql suffix", in: "http://localhost:12800", want: "http://localhost:12800/graphql"}, + {name: "trims trailing slash", in: "http://localhost:12800/", want: "http://localhost:12800/graphql"}, + {name: "keeps existing graphql", in: "http://localhost:12800/graphql", want: "http://localhost:12800/graphql"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := FinalizeURL(tc.in); got != tc.want { + t.Fatalf("FinalizeURL(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + +func TestParseTimezoneOffset(t *testing.T) { + loc, ok := parseTimezoneOffset("+0830") + if !ok { + t.Fatal("expected timezone to parse") + } + if got := loc.String(); got != "+0830" { + t.Fatalf("location name = %q", got) + } + + _, ok = parseTimezoneOffset("UTC") + if ok { + t.Fatal("expected invalid timezone offset to fail") + } +} + +func TestParseDurationWithContextRelativeDuration(t *testing.T) { + now := time.Date(2026, 3, 31, 12, 0, 0, 0, time.UTC) + timeCtx := TimeContext{NowUTC: now, Location: time.UTC} + + got := ParseDurationWithContext("-2h", false, timeCtx) + + if got.Start != "2026-03-31 1000" { + t.Fatalf("start = %q", got.Start) + } + if got.End != "2026-03-31 1200" { + t.Fatalf("end = %q", got.End) + } + if got.Step != api.StepMinute { + t.Fatalf("step = %q", got.Step) + } +} + +func TestParseDurationWithContextLegacyDays(t *testing.T) { + now := time.Date(2026, 3, 31, 12, 0, 0, 0, time.UTC) + timeCtx := TimeContext{NowUTC: now, Location: time.UTC} + + got := ParseDurationWithContext("7d", true, timeCtx) + + if got.Start != "2026-03-24" { + t.Fatalf("start = %q", got.Start) + } + if got.End != "2026-03-31" { + t.Fatalf("end = %q", got.End) + } + if got.Step != api.StepDay { + t.Fatalf("step = %q", got.Step) + } + if got.ColdStage == nil || !*got.ColdStage { + t.Fatal("expected cold stage to be true") + } +} + +func TestBuildPaginationDefaultsAndCustomValues(t *testing.T) { + gotDefault := BuildPagination(0, 0) + if gotDefault.PageNum == nil || *gotDefault.PageNum != DefaultPageNum { + t.Fatalf("default page num = %v", gotDefault.PageNum) + } + if gotDefault.PageSize != DefaultPageSize { + t.Fatalf("default page size = %d", gotDefault.PageSize) + } + + gotCustom := BuildPagination(3, 50) + if gotCustom.PageNum == nil || *gotCustom.PageNum != 3 { + t.Fatalf("custom page num = %v", gotCustom.PageNum) + } + if gotCustom.PageSize != 50 { + t.Fatalf("custom page size = %d", gotCustom.PageSize) + } +} + +func TestBuildDurationWithContextParsesAbsoluteTimes(t *testing.T) { + timeCtx := TimeContext{ + NowUTC: time.Date(2026, 3, 31, 12, 0, 0, 0, time.UTC), + Location: time.FixedZone("+0800", 8*3600), + } + + got := BuildDurationWithContext("2026-03-31 18:00:00", "2026-03-31 20:00:00", "", false, 30, timeCtx) + + if got.Start != "2026-03-31 1000" { + t.Fatalf("start = %q", got.Start) + } + if got.End != "2026-03-31 1200" { + t.Fatalf("end = %q", got.End) + } + if got.Step != api.StepMinute { + t.Fatalf("step = %q", got.Step) + } +} From 298a3ef35e20812dd31433dbb85fbc7df7bc7aa5 Mon Sep 17 00:00:00 2001 From: Fine0830 Date: Tue, 31 Mar 2026 17:30:58 +0800 Subject: [PATCH 2/6] add verification --- CHANGES.md | 10 +++++++ README.md | 4 +++ cmd/skywalking-mcp/main.go | 2 ++ internal/swmcp/server.go | 13 ++++++--- internal/tools/alarm_test.go | 11 +++++--- internal/tools/common.go | 13 +++++++++ internal/tools/common_test.go | 8 +++--- internal/tools/io.go | 13 +++++++-- internal/tools/mqe.go | 51 +++++++++++++++++++++++++++++++---- 9 files changed, 107 insertions(+), 18 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 1d8c6e7..8b98175 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,16 @@ Release Notes. +## Next + +### Security + +* TLS certificate verification is now enforced for OAP connections. Added `--sw-insecure` flag to opt out (development/self-signed certs only). +* Sensitive fields (`authorization`, `password`, `token`, `secret`) are redacted in `--log-command` output. +* Environment variable references (`${VAR}`) in `--sw-username`/`--sw-password` now log a warning when the variable is not set, preventing silent unauthenticated requests. +* URL scheme validation rejects non-http/https OAP URLs. +* Regex patterns supplied to `list_mqe_metrics` are validated for complexity before compilation. + ## 0.1.0 ### Features diff --git a/README.md b/README.md index af91904..5d6d60e 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ Flags: --sw-url string Specify the OAP URL to connect to (e.g. http://localhost:12800) --sw-username string Username for basic auth to SkyWalking OAP (supports ${ENV_VAR} syntax) --sw-password string Password for basic auth to SkyWalking OAP (supports ${ENV_VAR} syntax) + --sw-insecure Skip TLS certificate verification for OAP connections (use only in development) -v, --version version for swmcp Use "swmcp [command] --help" for more information about a command. @@ -61,6 +62,9 @@ bin/swmcp stdio --sw-url http://localhost:12800 --sw-username admin --sw-passwor # with basic auth (password from environment variable) bin/swmcp stdio --sw-url http://localhost:12800 --sw-username admin --sw-password '${SW_PASSWORD}' +# skip TLS verification (development only, e.g. self-signed certs) +bin/swmcp stdio --sw-url https://localhost:12800 --sw-insecure + # or use SSE server bin/swmcp sse --sse-address localhost:8000 --base-path /mcp --sw-url http://localhost:12800 ``` diff --git a/cmd/skywalking-mcp/main.go b/cmd/skywalking-mcp/main.go index 66babdc..1a80d8f 100644 --- a/cmd/skywalking-mcp/main.go +++ b/cmd/skywalking-mcp/main.go @@ -59,6 +59,7 @@ func init() { rootCmd.PersistentFlags().String("sw-url", "", "Specify the OAP URL to connect to (e.g. http://localhost:12800)") rootCmd.PersistentFlags().String("sw-username", "", "Username for basic auth to SkyWalking OAP (supports ${ENV_VAR} syntax)") rootCmd.PersistentFlags().String("sw-password", "", "Password for basic auth to SkyWalking OAP (supports ${ENV_VAR} syntax)") + rootCmd.PersistentFlags().Bool("sw-insecure", false, "Skip TLS certificate verification for OAP connections (use only in development)") rootCmd.PersistentFlags().String("log-level", "info", "Logging level (debug, info, warn, error)") rootCmd.PersistentFlags().Bool("read-only", false, "Restrict the server to read-only operations") rootCmd.PersistentFlags().Bool("log-command", false, "When true, log commands to the log file") @@ -68,6 +69,7 @@ func init() { _ = viper.BindPFlag("url", rootCmd.PersistentFlags().Lookup("sw-url")) _ = viper.BindPFlag("username", rootCmd.PersistentFlags().Lookup("sw-username")) _ = viper.BindPFlag("password", rootCmd.PersistentFlags().Lookup("sw-password")) + _ = viper.BindPFlag("insecure", rootCmd.PersistentFlags().Lookup("sw-insecure")) _ = viper.BindPFlag("log-level", rootCmd.PersistentFlags().Lookup("log-level")) _ = viper.BindPFlag("read-only", rootCmd.PersistentFlags().Lookup("read-only")) _ = viper.BindPFlag("log-command", rootCmd.PersistentFlags().Lookup("log-command")) diff --git a/internal/swmcp/server.go b/internal/swmcp/server.go index 0cdba97..e5307d4 100644 --- a/internal/swmcp/server.go +++ b/internal/swmcp/server.go @@ -107,7 +107,12 @@ func resolveEnvVar(value string) string { trimmed := strings.TrimSpace(value) if strings.HasPrefix(trimmed, "${") && strings.HasSuffix(trimmed, "}") { envName := trimmed[2 : len(trimmed)-1] - return os.Getenv(envName) + resolved, ok := os.LookupEnv(envName) + if !ok { + logrus.Warnf("environment variable %q is referenced but not set", envName) + return "" + } + return resolved } return value } @@ -130,7 +135,7 @@ func withConfiguredAuth(ctx context.Context) context.Context { // with SkyWalking settings from the global configuration. func EnhanceStdioContextFunc() server.StdioContextFunc { return func(ctx context.Context) context.Context { - ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), false) + ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), viper.GetBool("insecure")) ctx = withConfiguredAuth(ctx) return ctx } @@ -140,7 +145,7 @@ func EnhanceStdioContextFunc() server.StdioContextFunc { // with SkyWalking settings from the CLI configuration and configured auth. func EnhanceSSEContextFunc() server.SSEContextFunc { return func(ctx context.Context, _ *http.Request) context.Context { - ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), false) + ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), viper.GetBool("insecure")) ctx = withConfiguredAuth(ctx) return ctx } @@ -150,7 +155,7 @@ func EnhanceSSEContextFunc() server.SSEContextFunc { // with SkyWalking settings from the CLI configuration and configured auth. func EnhanceHTTPContextFunc() server.HTTPContextFunc { return func(ctx context.Context, _ *http.Request) context.Context { - ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), false) + ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), viper.GetBool("insecure")) ctx = withConfiguredAuth(ctx) return ctx } diff --git a/internal/tools/alarm_test.go b/internal/tools/alarm_test.go index e469b21..7c5e682 100644 --- a/internal/tools/alarm_test.go +++ b/internal/tools/alarm_test.go @@ -23,6 +23,11 @@ import ( api "skywalking.apache.org/repo/goapi/query" ) +const ( + testTimeStart = "2026-03-31 1000" + testTimeEnd = "2026-03-31 1200" +) + func TestBuildAlarmQueryCondition(t *testing.T) { timeCtx := TimeContext{ NowUTC: time.Date(2026, 3, 31, 12, 0, 0, 0, time.UTC), @@ -53,10 +58,10 @@ func TestBuildAlarmQueryCondition(t *testing.T) { if cond.Duration == nil { t.Fatal("duration is nil") } - if cond.Duration.Start != "2026-03-31 1000" { + if cond.Duration.Start != testTimeStart { t.Fatalf("start = %q", cond.Duration.Start) } - if cond.Duration.End != "2026-03-31 1200" { + if cond.Duration.End != testTimeEnd { t.Fatalf("end = %q", cond.Duration.End) } if cond.Paging == nil || cond.Paging.PageNum == nil || *cond.Paging.PageNum != 2 { @@ -93,7 +98,7 @@ func TestBuildAlarmQueryConditionDefaults(t *testing.T) { if cond.Duration == nil { t.Fatal("duration is nil") } - if cond.Duration.End != "2026-03-31 1200" { + if cond.Duration.End != testTimeEnd { t.Fatalf("end = %q", cond.Duration.End) } } diff --git a/internal/tools/common.go b/internal/tools/common.go index fc89e0a..86b8fcb 100644 --- a/internal/tools/common.go +++ b/internal/tools/common.go @@ -20,6 +20,7 @@ package tools import ( "context" "fmt" + "net/url" "strconv" "strings" "time" @@ -50,6 +51,18 @@ func FinalizeURL(urlStr string) string { return urlStr } +// validateURLScheme ensures the URL uses http or https. +func validateURLScheme(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid OAP URL: %w", err) + } + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("unsupported OAP URL scheme %q: only http and https are allowed", u.Scheme) + } + return nil +} + // FormatTimeByStep formats time according to step granularity func FormatTimeByStep(t time.Time, step api.Step) string { switch step { diff --git a/internal/tools/common_test.go b/internal/tools/common_test.go index e5323a8..c9a1ef7 100644 --- a/internal/tools/common_test.go +++ b/internal/tools/common_test.go @@ -64,10 +64,10 @@ func TestParseDurationWithContextRelativeDuration(t *testing.T) { got := ParseDurationWithContext("-2h", false, timeCtx) - if got.Start != "2026-03-31 1000" { + if got.Start != testTimeStart { t.Fatalf("start = %q", got.Start) } - if got.End != "2026-03-31 1200" { + if got.End != testTimeEnd { t.Fatalf("end = %q", got.End) } if got.Step != api.StepMinute { @@ -121,10 +121,10 @@ func TestBuildDurationWithContextParsesAbsoluteTimes(t *testing.T) { got := BuildDurationWithContext("2026-03-31 18:00:00", "2026-03-31 20:00:00", "", false, 30, timeCtx) - if got.Start != "2026-03-31 1000" { + if got.Start != testTimeStart { t.Fatalf("start = %q", got.Start) } - if got.End != "2026-03-31 1200" { + if got.End != testTimeEnd { t.Fatalf("end = %q", got.End) } if got.Step != api.StepMinute { diff --git a/internal/tools/io.go b/internal/tools/io.go index 3eeaee0..d4a6380 100644 --- a/internal/tools/io.go +++ b/internal/tools/io.go @@ -19,10 +19,19 @@ package tools import ( "io" + "regexp" log "github.com/sirupsen/logrus" ) +// sensitiveFieldPattern matches JSON fields whose values should be redacted in logs. +var sensitiveFieldPattern = regexp.MustCompile(`(?i)("(?:authorization|password|token|secret)"\s*:\s*")((?:[^"\\]|\\.)*)(")`) //nolint:lll // regex must be on one line + +// redactSensitiveData masks values of sensitive JSON fields before logging. +func redactSensitiveData(data string) string { + return sensitiveFieldPattern.ReplaceAllString(data, `${1}[REDACTED]${3}`) +} + // IOLogger is a wrapper around io.Reader and io.Writer that can be used // to log the data being read and written from the underlying streams type IOLogger struct { @@ -47,7 +56,7 @@ func (l *IOLogger) Read(p []byte) (n int, err error) { } n, err = l.reader.Read(p) if n > 0 { - l.logger.Infof("[stdin]: received %d bytes: %s", n, string(p[:n])) + l.logger.Infof("[stdin]: received %d bytes: %s", n, redactSensitiveData(string(p[:n]))) } return n, err } @@ -57,6 +66,6 @@ func (l *IOLogger) Write(p []byte) (n int, err error) { if l.writer == nil { return 0, io.ErrClosedPipe } - l.logger.Infof("[stdout]: sending %d bytes: %s", len(p), string(p)) + l.logger.Infof("[stdout]: sending %d bytes: %s", len(p), redactSensitiveData(string(p))) return l.writer.Write(p) } diff --git a/internal/tools/mqe.go b/internal/tools/mqe.go index 3a45396..8b19049 100644 --- a/internal/tools/mqe.go +++ b/internal/tools/mqe.go @@ -20,12 +20,14 @@ package tools import ( "bytes" "context" + "crypto/tls" "encoding/base64" "encoding/json" "fmt" "io" "net/http" "regexp" + "regexp/syntax" "time" "unicode" "unicode/utf8" @@ -75,10 +77,22 @@ func getContextString(ctx context.Context, key any) string { return "" } +// getContextBool safely extracts a bool value from context. +func getContextBool(ctx context.Context, key any) bool { + if v, ok := ctx.Value(key).(bool); ok { + return v + } + return false +} + // executeGraphQLWithContext executes a GraphQL query using URL and auth from context. func executeGraphQLWithContext(ctx context.Context, query string, variables map[string]interface{}) (*GraphQLResponse, error) { - url := getContextString(ctx, contextkey.BaseURL{}) - url = FinalizeURL(url) + rawURL := getContextString(ctx, contextkey.BaseURL{}) + rawURL = FinalizeURL(rawURL) + + if err := validateURLScheme(rawURL); err != nil { + return nil, err + } reqBody := GraphQLRequest{ Query: query, @@ -90,7 +104,7 @@ func executeGraphQLWithContext(ctx context.Context, query string, variables map[ return nil, fmt.Errorf("failed to marshal GraphQL request: %w", err) } - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, "POST", rawURL, bytes.NewBuffer(jsonData)) if err != nil { return nil, fmt.Errorf("failed to create HTTP request: %w", err) } @@ -105,7 +119,12 @@ func executeGraphQLWithContext(ctx context.Context, query string, variables map[ req.Header.Set("Authorization", auth) } - client := &http.Client{Timeout: 30 * time.Second} + insecure := getContextBool(ctx, contextkey.Insecure{}) + //nolint:gosec // InsecureSkipVerify is intentional and controlled by the --sw-insecure operator flag + transport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: insecure}, + } + client := &http.Client{Transport: transport, Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("failed to execute HTTP request: %w", err) @@ -512,12 +531,34 @@ func validateMQEMetricsListRequest(req *MQEMetricsListRequest) error { if err := validateMQETextField("regex", req.Regex, maxMQERegexLength); err != nil { return err } - if _, err := regexp.Compile(req.Regex); err != nil { + if err := validateRegexComplexity(req.Regex); err != nil { + return err + } + return nil +} + +const maxRegexNodes = 50 + +// validateRegexComplexity rejects patterns with excessive AST node counts. +func validateRegexComplexity(pattern string) error { + re, err := syntax.Parse(pattern, syntax.Perl) + if err != nil { return fmt.Errorf("regex is invalid") } + if regexNodeCount(re) > maxRegexNodes { + return fmt.Errorf("regex is too complex") + } return nil } +func regexNodeCount(re *syntax.Regexp) int { + count := 1 + for _, sub := range re.Sub { + count += regexNodeCount(sub) + } + return count +} + func validateMetricName(metricName string) error { if err := validateMQETextField("metric_name", metricName, maxMetricNameLength); err != nil { return err From 64ce6098d457bfffdd1742bb788607174c0adafb Mon Sep 17 00:00:00 2001 From: Fine0830 Date: Tue, 31 Mar 2026 18:13:24 +0800 Subject: [PATCH 3/6] add --allowed-origins flag --- CHANGES.md | 8 ++- CLAUDE.md | 22 +++++- README.md | 15 +++- internal/swmcp/cors.go | 75 ++++++++++++++++++++ internal/swmcp/server_registry_test.go | 97 ++++++++++++++------------ internal/swmcp/sse.go | 17 +++++ internal/swmcp/streamable.go | 20 ++++++ internal/tools/mqe.go | 25 ++++++- 8 files changed, 228 insertions(+), 51 deletions(-) create mode 100644 internal/swmcp/cors.go diff --git a/CHANGES.md b/CHANGES.md index 8b98175..beb8e10 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,15 +2,19 @@ Release Notes. -## Next +## 0.2.0 -### Security +### Features * TLS certificate verification is now enforced for OAP connections. Added `--sw-insecure` flag to opt out (development/self-signed certs only). * Sensitive fields (`authorization`, `password`, `token`, `secret`) are redacted in `--log-command` output. * Environment variable references (`${VAR}`) in `--sw-username`/`--sw-password` now log a warning when the variable is not set, preventing silent unauthenticated requests. * URL scheme validation rejects non-http/https OAP URLs. * Regex patterns supplied to `list_mqe_metrics` are validated for complexity before compilation. +* Added `--allowed-origins` flag to `sse` and `streamable` transports for CORS/CSRF origin enforcement. Requests carrying an `Origin` header not in the allowlist receive `403 Forbidden`. When unset, all origins are permitted (backward compatible). +* Implement unit tests. +* Remove the unnecessary tool and parameter. +* Validate properties for tools. ## 0.1.0 diff --git a/CLAUDE.md b/CLAUDE.md index dfaf53c..71b52de 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -45,9 +45,15 @@ The SkyWalking OAP URL is resolved in priority order: SSE and HTTP transports always use the configured server URL. -Basic auth is configured via `--sw-username` / `--sw-password` flags. The startup flags support `${ENV_VAR}` syntax to resolve credentials from environment variables (e.g. `--sw-password ${MY_SECRET}`). +Basic auth is configured via `--sw-username` / `--sw-password` flags. The startup flags support `${ENV_VAR}` syntax to resolve credentials from environment variables (e.g. `--sw-password ${MY_SECRET}`). If a referenced env var is not set, a warning is logged and the credential is treated as empty. -Each transport injects the OAP URL and auth into the request context via `WithSkyWalkingURLAndInsecure()` and `WithSkyWalkingAuth()`. Tools extract them downstream using `skywalking-cli`'s `contextkey.BaseURL{}`, `contextkey.Username{}`, and `contextkey.Password{}`. +TLS verification is enforced by default. Use `--sw-insecure` to skip verification (development/self-signed certs only). + +Each transport injects the OAP URL, insecure flag, and auth into the request context via `WithSkyWalkingURLAndInsecure()` and `WithSkyWalkingAuth()`. Tools extract them downstream using `skywalking-cli`'s `contextkey.BaseURL{}`, `contextkey.Insecure{}`, `contextkey.Username{}`, and `contextkey.Password{}`. + +### CORS / CSRF (`internal/swmcp/cors.go`) + +`sse` and `streamable` transports support `--allowed-origins` (comma-separated). When set, requests with an `Origin` header not in the list are rejected with `403 Forbidden`. CORS response headers are set for allowed origins. When the flag is empty (default), all origins are permitted. The middleware is injected via `WithHTTPServer` / `WithStreamableHTTPServer` so the MCP handler is wrapped rather than forked. ### Server Wiring (`internal/swmcp/server.go`) @@ -60,9 +66,19 @@ Each transport injects the OAP URL and auth into the request context via `WithSk ### Communication with SkyWalking OAP - **Most tools** use `skywalking-cli` packages (`pkg/graphql/...`) which communicate via GraphQL -- **MQE tools** use direct HTTP calls to the OAP `/graphql` endpoint +- **MQE tools** use direct HTTP calls to the OAP `/graphql` endpoint via `executeGraphQLWithContext()` in `mqe.go`. The HTTP client reads `contextkey.Insecure{}` to configure TLS and validates the URL scheme (`http`/`https` only) before each request. - **Time handling**: `common.go` provides `BuildDurationWithContext()` and `GetTimeContext()` which fetch the OAP server's time/timezone for accurate duration calculations +### Input Validation (`internal/tools/mqe.go`) + +All MQE tool inputs are validated before use: +- `validateMQETextField`: UTF-8, max length, no control characters — applied to all entity fields +- `validateLayerField`: additionally enforces `^[A-Z0-9_]+$` for `layer` / `dest_layer` +- `validateMQEExpression`: UTF-8, max 2048 chars, no control characters, max nesting depth 12 +- `validateMetricName`: `^[A-Za-z0-9_.:-]+$` pattern, max 128 chars +- `validateRegexComplexity`: parses the regex AST via `regexp/syntax` and rejects patterns with >50 nodes +- `validateURLScheme` (`common.go`): rejects non-http/https OAP URLs before HTTP requests + ## Extending the Server ### Adding a New Tool diff --git a/README.md b/README.md index 5d6d60e..cc1f950 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ Available Commands: stdio Start stdio server streamable Start Streamable server -Flags: +Global Flags: -h, --help help for swmcp --log-command When true, log commands to the log file --log-file string Path to log file @@ -47,6 +47,16 @@ Flags: --sw-insecure Skip TLS certificate verification for OAP connections (use only in development) -v, --version version for swmcp +SSE-specific Flags: + --sse-address string Host and port for the SSE server (default "localhost:8000") + --base-path string Base path for the SSE server + --allowed-origins string Comma-separated list of allowed CORS origins. Empty permits all. + +Streamable-specific Flags: + --address string Host and port for the Streamable HTTP server (default "localhost:8000") + --endpoint-path string Endpoint path for the Streamable HTTP server (default "/mcp") + --allowed-origins string Comma-separated list of allowed CORS origins. Empty permits all. + Use "swmcp [command] --help" for more information about a command. ``` @@ -67,6 +77,9 @@ bin/swmcp stdio --sw-url https://localhost:12800 --sw-insecure # or use SSE server bin/swmcp sse --sse-address localhost:8000 --base-path /mcp --sw-url http://localhost:12800 + +# restrict CORS to specific origins (SSE and streamable transports) +bin/swmcp streamable --sw-url http://localhost:12800 --allowed-origins "http://localhost:3000,https://app.example.com" ``` Transport URL behavior: diff --git a/internal/swmcp/cors.go b/internal/swmcp/cors.go new file mode 100644 index 0000000..4814d49 --- /dev/null +++ b/internal/swmcp/cors.go @@ -0,0 +1,75 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package swmcp + +import ( + "net/http" + "strings" +) + +// corsMiddleware adds CORS response headers and enforces origin validation. +// When allowedOrigins is empty, all origins are permitted (no restriction). +// When populated, only listed origins receive CORS headers; requests from +// any other origin with an Origin header receive 403 Forbidden. +func corsMiddleware(allowedOrigins []string, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin != "" { + if isOriginAllowed(origin, allowedOrigins) { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Accept") + w.Header().Set("Vary", "Origin") + } else if len(allowedOrigins) > 0 { + http.Error(w, "forbidden: origin not allowed", http.StatusForbidden) + return + } + } + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + next.ServeHTTP(w, r) + }) +} + +// isOriginAllowed reports whether origin is in the allowed list. +// The wildcard "*" matches any origin. +func isOriginAllowed(origin string, allowed []string) bool { + for _, a := range allowed { + if a == "*" || a == origin { + return true + } + } + return false +} + +// parseAllowedOrigins splits a comma-separated list of origins. +func parseAllowedOrigins(raw string) []string { + if raw == "" { + return nil + } + parts := strings.Split(raw, ",") + result := make([]string, 0, len(parts)) + for _, p := range parts { + if trimmed := strings.TrimSpace(p); trimmed != "" { + result = append(result, trimmed) + } + } + return result +} diff --git a/internal/swmcp/server_registry_test.go b/internal/swmcp/server_registry_test.go index 55d1384..ebc7073 100644 --- a/internal/swmcp/server_registry_test.go +++ b/internal/swmcp/server_registry_test.go @@ -27,10 +27,11 @@ import ( ) // These registry tests verify that newMCPServer wires up the expected tools, -// prompts, and resources. mcp-go v0.45.0 does not expose a public inventory API -// for MCPServer, so the tests read server internals through a single helper -// layer below. If mcp-go changes its internal field layout, update only the -// helpers in this file rather than spreading reflect/unsafe access across tests. +// prompts, and resources. Prefer MCPServer public APIs where available. As of +// mcp-go v0.45.0 only tools have a public inventory API, so prompt/resource +// assertions go through the helper layer below. If a future mcp-go release +// exposes prompt/resource listing, replace the reflection helper there rather +// than spreading reflect/unsafe access across tests. func TestNewMCPServerRegistersExpectedTools(t *testing.T) { srv := newMCPServer() @@ -60,8 +61,9 @@ func TestNewMCPServerRegistersExpectedTools(t *testing.T) { func TestNewMCPServerRegistersExpectedPrompts(t *testing.T) { srv := newMCPServer() + inventory := inspectServerInventory(srv) - got := sortedPromptNames(srv) + got := sortedPromptNames(inventory.prompts) want := []string{ "analyze-logs", "analyze-performance", @@ -80,8 +82,9 @@ func TestNewMCPServerRegistersExpectedPrompts(t *testing.T) { func TestNewMCPServerRegistersExpectedResources(t *testing.T) { srv := newMCPServer() + inventory := inspectServerInventory(srv) - resources := resourceMap(srv) + resources := inventory.resources got := make([]string, 0, len(resources)) for uri := range resources { got = append(got, uri) @@ -100,7 +103,7 @@ func TestNewMCPServerRegistersExpectedResources(t *testing.T) { func TestPromptMetadataIncludesExpectedArguments(t *testing.T) { srv := newMCPServer() - prompts := promptMap(srv) + prompts := inspectServerInventory(srv).prompts prompt, ok := prompts["generate_duration"] if !ok { @@ -130,7 +133,7 @@ func TestPromptMetadataIncludesExpectedArguments(t *testing.T) { func TestResourceMetadataIncludesExpectedMIMETypes(t *testing.T) { srv := newMCPServer() - resources := resourceMap(srv) + resources := inspectServerInventory(srv).resources tests := []struct { uri string @@ -192,39 +195,10 @@ func TestToolMetadataIncludesExpectedDescriptionsAndSchemas(t *testing.T) { } func toolMap(srv *server.MCPServer) map[string]mcp.Tool { - serverTools := mustReadServerField(testedServerValue(srv), "tools") - result := make(map[string]mcp.Tool, serverTools.Len()) - - iter := serverTools.MapRange() - for iter.Next() { - name := iter.Key().String() - toolValue := copyReflectValue(iter.Value()) - result[name] = toolValue.FieldByName("Tool").Interface().(mcp.Tool) - } - - return result -} - -func promptMap(srv *server.MCPServer) map[string]mcp.Prompt { - serverPrompts := mustReadServerField(testedServerValue(srv), "prompts") - result := make(map[string]mcp.Prompt, serverPrompts.Len()) - - iter := serverPrompts.MapRange() - for iter.Next() { - result[iter.Key().String()] = copyReflectValue(iter.Value()).Interface().(mcp.Prompt) - } - - return result -} - -func resourceMap(srv *server.MCPServer) map[string]mcp.Resource { - serverResources := mustReadServerField(testedServerValue(srv), "resources") - result := make(map[string]mcp.Resource, serverResources.Len()) - - iter := serverResources.MapRange() - for iter.Next() { - resourceField := copyReflectValue(iter.Value()).FieldByName("resource") - result[iter.Key().String()] = readPrivateField(resourceField).Interface().(mcp.Resource) + serverTools := srv.ListTools() + result := make(map[string]mcp.Tool, len(serverTools)) + for name, tool := range serverTools { + result[name] = tool.Tool } return result @@ -240,8 +214,7 @@ func sortedToolNames(srv *server.MCPServer) []string { return names } -func sortedPromptNames(srv *server.MCPServer) []string { - prompts := promptMap(srv) +func sortedPromptNames(prompts map[string]mcp.Prompt) []string { names := make([]string, 0, len(prompts)) for name := range prompts { names = append(names, name) @@ -273,6 +246,44 @@ func mustReadServerField(srv reflect.Value, fieldName string) reflect.Value { return readPrivateField(field) } +type serverInventory struct { + prompts map[string]mcp.Prompt + resources map[string]mcp.Resource +} + +func inspectServerInventory(srv *server.MCPServer) serverInventory { + serverValue := testedServerValue(srv) + return serverInventory{ + prompts: readPromptMap(serverValue), + resources: readResourceMap(serverValue), + } +} + +func readPromptMap(serverValue reflect.Value) map[string]mcp.Prompt { + serverPrompts := mustReadServerField(serverValue, "prompts") + result := make(map[string]mcp.Prompt, serverPrompts.Len()) + + iter := serverPrompts.MapRange() + for iter.Next() { + result[iter.Key().String()] = copyReflectValue(iter.Value()).Interface().(mcp.Prompt) + } + + return result +} + +func readResourceMap(serverValue reflect.Value) map[string]mcp.Resource { + serverResources := mustReadServerField(serverValue, "resources") + result := make(map[string]mcp.Resource, serverResources.Len()) + + iter := serverResources.MapRange() + for iter.Next() { + resourceField := copyReflectValue(iter.Value()).FieldByName("resource") + result[iter.Key().String()] = readPrivateField(resourceField).Interface().(mcp.Resource) + } + + return result +} + func copyReflectValue(v reflect.Value) reflect.Value { cloned := reflect.New(v.Type()).Elem() cloned.Set(v) diff --git a/internal/swmcp/sse.go b/internal/swmcp/sse.go index 14365a9..7e92da4 100644 --- a/internal/swmcp/sse.go +++ b/internal/swmcp/sse.go @@ -55,8 +55,11 @@ func NewSSEServer() *cobra.Command { "The host and port to start the sse server on") sseCmd.Flags().String("base-path", "", "Base path for the sse server") + sseCmd.Flags().String("allowed-origins", "", + "Comma-separated list of allowed CORS origins (e.g. http://localhost:3000,https://app.example.com). Empty allows all.") _ = viper.BindPFlag("sse-address", sseCmd.Flags().Lookup("sse-address")) _ = viper.BindPFlag("base-path", sseCmd.Flags().Lookup("base-path")) + _ = viper.BindPFlag("allowed-origins", sseCmd.Flags().Lookup("allowed-origins")) return sseCmd } @@ -71,11 +74,25 @@ func runSSEServer(ctx context.Context, cfg *config.SSEServerConfig) error { return fmt.Errorf("failed to initialize logger: %w", err) } + allowedOrigins := parseAllowedOrigins(viper.GetString("allowed-origins")) + + // sseServer is assigned after NewSSEServer so the CORS handler closure + // can forward to it via the captured pointer variable. + var sseServerRef *server.SSEServer + customSrv := &http.Server{ + Handler: corsMiddleware(allowedOrigins, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sseServerRef.ServeHTTP(w, r) + })), + ReadHeaderTimeout: 10 * time.Second, + } + sseServer := server.NewSSEServer( newMCPServer(), server.WithStaticBasePath(cfg.BasePath), server.WithSSEContextFunc(EnhanceSSEContextFunc()), + server.WithHTTPServer(customSrv), ) + sseServerRef = sseServer ssePath := sseServer.CompleteSsePath() log.Printf("Starting SkyWalking MCP server using SSE transport listening on http://%s%s\n ", cfg.Address, ssePath) diff --git a/internal/swmcp/streamable.go b/internal/swmcp/streamable.go index 0500352..5fc94c5 100644 --- a/internal/swmcp/streamable.go +++ b/internal/swmcp/streamable.go @@ -19,6 +19,8 @@ package swmcp import ( "fmt" + "net/http" + "time" "github.com/mark3labs/mcp-go/server" log "github.com/sirupsen/logrus" @@ -48,21 +50,39 @@ func NewStreamable() *cobra.Command { "The host and port to start the Streamable server on") streamableCmd.Flags().String("endpoint-path", "/mcp", "The path for the streamable-http server") + streamableCmd.Flags().String("allowed-origins", "", + "Comma-separated list of allowed CORS origins (e.g. http://localhost:3000,https://app.example.com). Empty allows all.") _ = viper.BindPFlag("address", streamableCmd.Flags().Lookup("address")) _ = viper.BindPFlag("endpoint-path", streamableCmd.Flags().Lookup("endpoint-path")) + _ = viper.BindPFlag("allowed-origins", streamableCmd.Flags().Lookup("allowed-origins")) return streamableCmd } // runStreamableServer starts the Streamable server with the provided configuration. func runStreamableServer(cfg *config.StreamableServerConfig) error { + allowedOrigins := parseAllowedOrigins(viper.GetString("allowed-origins")) + + // httpServer is assigned after NewStreamableHTTPServer so the CORS handler + // closure can forward to it via the captured pointer variable. + var httpServerRef *server.StreamableHTTPServer + customSrv := &http.Server{ + Handler: corsMiddleware(allowedOrigins, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + httpServerRef.ServeHTTP(w, r) + })), + ReadHeaderTimeout: 10 * time.Second, + } + httpServer := server.NewStreamableHTTPServer( newMCPServer(), server.WithStateLess(true), server.WithLogger(log.StandardLogger()), server.WithHTTPContextFunc(EnhanceHTTPContextFunc()), server.WithEndpointPath(viper.GetString("endpoint-path")), + server.WithStreamableHTTPServer(customSrv), ) + httpServerRef = httpServer + log.Infof("streamable HTTP server listening on %s%s\n", cfg.Address, cfg.EndpointPath) if err := httpServer.Start(cfg.Address); err != nil { diff --git a/internal/tools/mqe.go b/internal/tools/mqe.go index 8b19049..6abb286 100644 --- a/internal/tools/mqe.go +++ b/internal/tools/mqe.go @@ -55,6 +55,9 @@ const ( var metricNamePattern = regexp.MustCompile(`^[A-Za-z0-9_.:-]+$`) +// layerPattern restricts layer values to the SkyWalking enum format (e.g. GENERAL, K8S_SERVICE). +var layerPattern = regexp.MustCompile(`^[A-Z0-9_]+$`) + // GraphQLRequest represents a GraphQL request type GraphQLRequest struct { Query string `json:"query"` @@ -506,12 +509,10 @@ func validateMQEExpressionRequest(req *MQEExpressionRequest) error { for fieldName, value := range map[string]string{ "service_name": req.ServiceName, - "layer": req.Layer, "service_instance_name": req.ServiceInstanceName, "endpoint_name": req.EndpointName, "process_name": req.ProcessName, "dest_service_name": req.DestServiceName, - "dest_layer": req.DestLayer, "dest_service_instance_name": req.DestServiceInstanceName, "dest_endpoint_name": req.DestEndpointName, "dest_process_name": req.DestProcessName, @@ -521,6 +522,13 @@ func validateMQEExpressionRequest(req *MQEExpressionRequest) error { } } + if err := validateLayerField("layer", req.Layer); err != nil { + return err + } + if err := validateLayerField("dest_layer", req.DestLayer); err != nil { + return err + } + return nil } @@ -559,6 +567,19 @@ func regexNodeCount(re *syntax.Regexp) int { return count } +func validateLayerField(fieldName, value string) error { + if value == "" { + return nil + } + if err := validateMQETextField(fieldName, value, maxMQEEntityFieldLen); err != nil { + return err + } + if !layerPattern.MatchString(value) { + return fmt.Errorf("%s contains invalid characters: only uppercase letters, digits, and underscores are allowed", fieldName) + } + return nil +} + func validateMetricName(metricName string) error { if err := validateMQETextField("metric_name", metricName, maxMetricNameLength); err != nil { return err From cb320bde5738f07407230105d86307e1e9a41915 Mon Sep 17 00:00:00 2001 From: Fine0830 Date: Tue, 31 Mar 2026 19:47:37 +0800 Subject: [PATCH 4/6] Update internal/tools/mqe.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- internal/tools/mqe.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/tools/mqe.go b/internal/tools/mqe.go index 6abb286..66e40c6 100644 --- a/internal/tools/mqe.go +++ b/internal/tools/mqe.go @@ -551,7 +551,7 @@ const maxRegexNodes = 50 func validateRegexComplexity(pattern string) error { re, err := syntax.Parse(pattern, syntax.Perl) if err != nil { - return fmt.Errorf("regex is invalid") + return fmt.Errorf("regex is invalid: %w", err) } if regexNodeCount(re) > maxRegexNodes { return fmt.Errorf("regex is too complex") From 7dc199029f5055b32c3f1915239d24bd87a4227b Mon Sep 17 00:00:00 2001 From: Fine0830 Date: Tue, 31 Mar 2026 20:01:16 +0800 Subject: [PATCH 5/6] update --- CHANGES.md | 2 +- README.md | 4 ++-- internal/swmcp/cors.go | 12 +++++----- internal/swmcp/sse.go | 2 +- internal/swmcp/streamable.go | 2 +- internal/tools/io.go | 43 ++++++++++++++++++++++++++++-------- internal/tools/mqe.go | 6 ++--- 7 files changed, 48 insertions(+), 23 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index beb8e10..6dc7d4e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -11,7 +11,7 @@ Release Notes. * Environment variable references (`${VAR}`) in `--sw-username`/`--sw-password` now log a warning when the variable is not set, preventing silent unauthenticated requests. * URL scheme validation rejects non-http/https OAP URLs. * Regex patterns supplied to `list_mqe_metrics` are validated for complexity before compilation. -* Added `--allowed-origins` flag to `sse` and `streamable` transports for CORS/CSRF origin enforcement. Requests carrying an `Origin` header not in the allowlist receive `403 Forbidden`. When unset, all origins are permitted (backward compatible). +* Added `--allowed-origins` flag to `sse` and `streamable` transports for CORS origin enforcement. When unset (default), any `Origin` is reflected back so all browser origins work out of the box. When set, only listed origins receive CORS headers; all others get `403 Forbidden`. Use `*` as an entry to send the wildcard header explicitly. * Implement unit tests. * Remove the unnecessary tool and parameter. * Validate properties for tools. diff --git a/README.md b/README.md index cc1f950..d176830 100644 --- a/README.md +++ b/README.md @@ -50,12 +50,12 @@ Global Flags: SSE-specific Flags: --sse-address string Host and port for the SSE server (default "localhost:8000") --base-path string Base path for the SSE server - --allowed-origins string Comma-separated list of allowed CORS origins. Empty permits all. + --allowed-origins string Comma-separated list of allowed CORS origins. Empty reflects any origin (open CORS). Use * to send the wildcard header. Streamable-specific Flags: --address string Host and port for the Streamable HTTP server (default "localhost:8000") --endpoint-path string Endpoint path for the Streamable HTTP server (default "/mcp") - --allowed-origins string Comma-separated list of allowed CORS origins. Empty permits all. + --allowed-origins string Comma-separated list of allowed CORS origins. Empty reflects any origin (open CORS). Use * to send the wildcard header. Use "swmcp [command] --help" for more information about a command. ``` diff --git a/internal/swmcp/cors.go b/internal/swmcp/cors.go index 4814d49..8f1b0f9 100644 --- a/internal/swmcp/cors.go +++ b/internal/swmcp/cors.go @@ -23,19 +23,21 @@ import ( ) // corsMiddleware adds CORS response headers and enforces origin validation. -// When allowedOrigins is empty, all origins are permitted (no restriction). -// When populated, only listed origins receive CORS headers; requests from -// any other origin with an Origin header receive 403 Forbidden. +// When allowedOrigins is empty, every request with an Origin header is +// reflected back — i.e., CORS is open and all browser origins work. +// When allowedOrigins is non-empty, only listed origins receive CORS headers; +// requests from any other origin receive 403 Forbidden. Use "*" as an entry +// to explicitly allow all origins via the wildcard header. func corsMiddleware(allowedOrigins []string, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") if origin != "" { - if isOriginAllowed(origin, allowedOrigins) { + if len(allowedOrigins) == 0 || isOriginAllowed(origin, allowedOrigins) { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Accept") w.Header().Set("Vary", "Origin") - } else if len(allowedOrigins) > 0 { + } else { http.Error(w, "forbidden: origin not allowed", http.StatusForbidden) return } diff --git a/internal/swmcp/sse.go b/internal/swmcp/sse.go index 7e92da4..ae06a07 100644 --- a/internal/swmcp/sse.go +++ b/internal/swmcp/sse.go @@ -56,7 +56,7 @@ func NewSSEServer() *cobra.Command { sseCmd.Flags().String("base-path", "", "Base path for the sse server") sseCmd.Flags().String("allowed-origins", "", - "Comma-separated list of allowed CORS origins (e.g. http://localhost:3000,https://app.example.com). Empty allows all.") + "Comma-separated allowed CORS origins. Empty = open (any origin reflected). Use * for wildcard header.") _ = viper.BindPFlag("sse-address", sseCmd.Flags().Lookup("sse-address")) _ = viper.BindPFlag("base-path", sseCmd.Flags().Lookup("base-path")) _ = viper.BindPFlag("allowed-origins", sseCmd.Flags().Lookup("allowed-origins")) diff --git a/internal/swmcp/streamable.go b/internal/swmcp/streamable.go index 5fc94c5..0f85fb1 100644 --- a/internal/swmcp/streamable.go +++ b/internal/swmcp/streamable.go @@ -51,7 +51,7 @@ func NewStreamable() *cobra.Command { streamableCmd.Flags().String("endpoint-path", "/mcp", "The path for the streamable-http server") streamableCmd.Flags().String("allowed-origins", "", - "Comma-separated list of allowed CORS origins (e.g. http://localhost:3000,https://app.example.com). Empty allows all.") + "Comma-separated allowed CORS origins. Empty = open (any origin reflected). Use * for wildcard header.") _ = viper.BindPFlag("address", streamableCmd.Flags().Lookup("address")) _ = viper.BindPFlag("endpoint-path", streamableCmd.Flags().Lookup("endpoint-path")) _ = viper.BindPFlag("allowed-origins", streamableCmd.Flags().Lookup("allowed-origins")) diff --git a/internal/tools/io.go b/internal/tools/io.go index d4a6380..7c2d70f 100644 --- a/internal/tools/io.go +++ b/internal/tools/io.go @@ -18,6 +18,7 @@ package tools import ( + "bytes" "io" "regexp" @@ -32,12 +33,16 @@ func redactSensitiveData(data string) string { return sensitiveFieldPattern.ReplaceAllString(data, `${1}[REDACTED]${3}`) } -// IOLogger is a wrapper around io.Reader and io.Writer that can be used -// to log the data being read and written from the underlying streams +// IOLogger is a wrapper around io.Reader and io.Writer that logs complete +// newline-delimited JSON-RPC messages. Partial chunks are held in per-direction +// buffers until a newline arrives, ensuring the redaction regex always sees a +// full message and secrets split across read boundaries are never partially logged. type IOLogger struct { - reader io.Reader - writer io.Writer - logger *log.Logger + reader io.Reader + writer io.Writer + logger *log.Logger + readBuf bytes.Buffer + writeBuf bytes.Buffer } // NewIOLogger creates a new IOLogger instance @@ -49,23 +54,43 @@ func NewIOLogger(r io.Reader, w io.Writer, logger *log.Logger) *IOLogger { } } -// Read reads data from the underlying io.Reader and logs it. +// logCompleteLines drains newline-terminated lines from buf, redacts each one, +// and logs it under the given direction label. Any trailing partial line is left +// in buf for the next call. +func (l *IOLogger) logCompleteLines(buf *bytes.Buffer, direction string) { + data := buf.Bytes() + for { + idx := bytes.IndexByte(data, '\n') + if idx < 0 { + break + } + line := bytes.TrimRight(data[:idx], "\r") + l.logger.Infof("[%s]: %s", direction, redactSensitiveData(string(line))) + data = data[idx+1:] + } + buf.Reset() + buf.Write(data) +} + +// Read reads data from the underlying io.Reader and logs complete lines. func (l *IOLogger) Read(p []byte) (n int, err error) { if l.reader == nil { return 0, io.EOF } n, err = l.reader.Read(p) if n > 0 { - l.logger.Infof("[stdin]: received %d bytes: %s", n, redactSensitiveData(string(p[:n]))) + l.readBuf.Write(p[:n]) + l.logCompleteLines(&l.readBuf, "stdin") } return n, err } -// Write writes data to the underlying io.Writer and logs it. +// Write writes data to the underlying io.Writer and logs complete lines. func (l *IOLogger) Write(p []byte) (n int, err error) { if l.writer == nil { return 0, io.ErrClosedPipe } - l.logger.Infof("[stdout]: sending %d bytes: %s", len(p), redactSensitiveData(string(p))) + l.writeBuf.Write(p) + l.logCompleteLines(&l.writeBuf, "stdout") return l.writer.Write(p) } diff --git a/internal/tools/mqe.go b/internal/tools/mqe.go index 6abb286..dcb1ee4 100644 --- a/internal/tools/mqe.go +++ b/internal/tools/mqe.go @@ -123,10 +123,8 @@ func executeGraphQLWithContext(ctx context.Context, query string, variables map[ } insecure := getContextBool(ctx, contextkey.Insecure{}) - //nolint:gosec // InsecureSkipVerify is intentional and controlled by the --sw-insecure operator flag - transport := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: insecure}, - } + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: insecure} //nolint:gosec // controlled by --sw-insecure operator flag client := &http.Client{Transport: transport, Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { From 3ca71ac447a1c7c7d32c75e7cbe7617923a9b9e9 Mon Sep 17 00:00:00 2001 From: Fine0830 Date: Tue, 31 Mar 2026 20:02:21 +0800 Subject: [PATCH 6/6] Update CHANGES.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- CHANGES.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index beb8e10..5d1245c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -12,9 +12,9 @@ Release Notes. * URL scheme validation rejects non-http/https OAP URLs. * Regex patterns supplied to `list_mqe_metrics` are validated for complexity before compilation. * Added `--allowed-origins` flag to `sse` and `streamable` transports for CORS/CSRF origin enforcement. Requests carrying an `Origin` header not in the allowlist receive `403 Forbidden`. When unset, all origins are permitted (backward compatible). -* Implement unit tests. -* Remove the unnecessary tool and parameter. -* Validate properties for tools. +* Increased reliability of core CLI commands through expanded automated test coverage. +* Removed an unused CLI tool and its associated parameter to simplify the interface and avoid confusion. +* Added validation for tool configuration properties, returning clear errors when required values are missing or invalid. ## 0.1.0