From e522bb87d103bc905ab311c56c8fefc87a981a0f Mon Sep 17 00:00:00 2001 From: Fine0830 Date: Wed, 1 Apr 2026 15:51:30 +0800 Subject: [PATCH 1/2] add validation for sw url --- internal/swmcp/server.go | 16 +++++++++++++++- internal/swmcp/server_test.go | 28 +++++++++++++++++++++++++++ internal/swmcp/sse.go | 4 ++++ internal/swmcp/stdio.go | 4 ++++ internal/swmcp/streamable.go | 4 ++++ internal/tools/common.go | 9 +++++++++ internal/tools/common_test.go | 36 +++++++++++++++++++++++++++++++++++ 7 files changed, 100 insertions(+), 1 deletion(-) diff --git a/internal/swmcp/server.go b/internal/swmcp/server.go index e5307d4..f77254d 100644 --- a/internal/swmcp/server.go +++ b/internal/swmcp/server.go @@ -97,7 +97,21 @@ func configuredSkyWalkingURL() string { if urlStr == "" { urlStr = config.DefaultSWURL } - return tools.FinalizeURL(urlStr) + normalizedURL, err := tools.NormalizeOAPURL(urlStr) + if err != nil { + return tools.FinalizeURL(urlStr) + } + return normalizedURL +} + +func validateConfiguredSkyWalkingURL() error { + urlStr := viper.GetString("url") + if urlStr == "" { + urlStr = config.DefaultSWURL + } + + _, err := tools.NormalizeOAPURL(urlStr) + return err } // resolveEnvVar resolves a value that may contain an environment variable reference diff --git a/internal/swmcp/server_test.go b/internal/swmcp/server_test.go index 1007e56..5f969b1 100644 --- a/internal/swmcp/server_test.go +++ b/internal/swmcp/server_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/apache/skywalking-cli/pkg/contextkey" + "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/apache/skywalking-mcp/internal/config" @@ -52,6 +53,33 @@ func TestConfiguredSkyWalkingURLFinalizesConfiguredValue(t *testing.T) { } } +func TestValidateConfiguredSkyWalkingURLRejectsUnsupportedScheme(t *testing.T) { + t.Cleanup(viper.Reset) + viper.Set("url", "ftp://configured-oap.example.com:12800") + + err := validateConfiguredSkyWalkingURL() + if err == nil { + t.Fatal("validateConfiguredSkyWalkingURL() error = nil, want error") + } +} + +func TestTransportCommandsRejectInvalidSWURL(t *testing.T) { + t.Cleanup(viper.Reset) + viper.Set("url", "ftp://configured-oap.example.com:12800") + + for name, cmd := range map[string]*cobra.Command{ + "stdio": NewStdioServer(), + "sse": NewSSEServer(), + "streamable": NewStreamable(), + } { + t.Run(name, func(t *testing.T) { + if err := cmd.RunE(cmd, nil); err == nil { + t.Fatal("RunE() error = nil, want invalid sw-url error") + } + }) + } +} + func TestResolveEnvVar(t *testing.T) { t.Setenv("SW_TEST_SECRET", "resolved-secret") diff --git a/internal/swmcp/sse.go b/internal/swmcp/sse.go index ae06a07..0dab635 100644 --- a/internal/swmcp/sse.go +++ b/internal/swmcp/sse.go @@ -41,6 +41,10 @@ func NewSSEServer() *cobra.Command { Short: "Start SSE server", Long: `Start a server that listens for Server-Sent Events (SSE) on the specified address.`, RunE: func(_ *cobra.Command, _ []string) error { + if err := validateConfiguredSkyWalkingURL(); err != nil { + return err + } + sseServerConfig := config.SSEServerConfig{ Address: viper.GetString("sse-address"), BasePath: viper.GetString("base-path"), diff --git a/internal/swmcp/stdio.go b/internal/swmcp/stdio.go index 02abb4a..12fe05b 100644 --- a/internal/swmcp/stdio.go +++ b/internal/swmcp/stdio.go @@ -41,6 +41,10 @@ func NewStdioServer() *cobra.Command { Short: "Start stdio server", Long: `Start a server that communicates via standard input/output streams using JSON-RPC messages.`, RunE: func(_ *cobra.Command, _ []string) error { + if err := validateConfiguredSkyWalkingURL(); err != nil { + return err + } + stdioServerConfig := config.StdioServerConfig{ URL: viper.GetString("url"), ReadOnly: viper.GetBool("read-only"), diff --git a/internal/swmcp/streamable.go b/internal/swmcp/streamable.go index 0f85fb1..5d692ad 100644 --- a/internal/swmcp/streamable.go +++ b/internal/swmcp/streamable.go @@ -36,6 +36,10 @@ func NewStreamable() *cobra.Command { Short: "Start Streamable server", Long: `Starting SkyWalking MCP server with Streamable HTTP transport.`, RunE: func(_ *cobra.Command, _ []string) error { + if err := validateConfiguredSkyWalkingURL(); err != nil { + return err + } + streamableConfig := config.StreamableServerConfig{ Address: viper.GetString("address"), EndpointPath: viper.GetString("endpoint-path"), diff --git a/internal/tools/common.go b/internal/tools/common.go index 86b8fcb..eb031e8 100644 --- a/internal/tools/common.go +++ b/internal/tools/common.go @@ -51,6 +51,15 @@ func FinalizeURL(urlStr string) string { return urlStr } +// NormalizeOAPURL appends the GraphQL path when needed and rejects unsupported URL schemes. +func NormalizeOAPURL(rawURL string) (string, error) { + finalizedURL := FinalizeURL(rawURL) + if err := validateURLScheme(finalizedURL); err != nil { + return "", err + } + return finalizedURL, nil +} + // validateURLScheme ensures the URL uses http or https. func validateURLScheme(rawURL string) error { u, err := url.Parse(rawURL) diff --git a/internal/tools/common_test.go b/internal/tools/common_test.go index eb60ce1..e4a985a 100644 --- a/internal/tools/common_test.go +++ b/internal/tools/common_test.go @@ -17,6 +17,7 @@ package tools import ( + "strings" "testing" "time" @@ -50,6 +51,41 @@ func TestFinalizeURL(t *testing.T) { } } +func TestNormalizeOAPURL(t *testing.T) { + tests := []struct { + name string + in string + want string + wantErr string + }{ + {name: "http", in: "http://localhost:12800", want: "http://localhost:12800/graphql"}, + {name: "https", in: "https://localhost:12800/graphql", want: "https://localhost:12800/graphql"}, + {name: "rejects unsupported scheme", in: "ftp://localhost:12800", wantErr: "unsupported OAP URL scheme \"ftp\""}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := NormalizeOAPURL(tc.in) + if tc.wantErr != "" { + if err == nil { + t.Fatalf("NormalizeOAPURL(%q) error = nil, want %q", tc.in, tc.wantErr) + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("NormalizeOAPURL(%q) error = %q, want substring %q", tc.in, err.Error(), tc.wantErr) + } + return + } + + if err != nil { + t.Fatalf("NormalizeOAPURL(%q) unexpected error: %v", tc.in, err) + } + if got != tc.want { + t.Fatalf("NormalizeOAPURL(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + func TestParseTimezoneOffset(t *testing.T) { loc, ok := parseTimezoneOffset("+0830") if !ok { From 33757bd2713a85cfd1c9c5b6a9780568b3ef30d0 Mon Sep 17 00:00:00 2001 From: Fine0830 Date: Wed, 1 Apr 2026 16:40:29 +0800 Subject: [PATCH 2/2] update --- internal/swmcp/server.go | 18 +++++++++--------- internal/swmcp/server_test.go | 10 ++++++++++ internal/tools/common.go | 32 +++++++++++++++++++++----------- internal/tools/common_test.go | 4 ++++ internal/tools/mqe.go | 7 +++---- 5 files changed, 47 insertions(+), 24 deletions(-) diff --git a/internal/swmcp/server.go b/internal/swmcp/server.go index f77254d..c24e888 100644 --- a/internal/swmcp/server.go +++ b/internal/swmcp/server.go @@ -93,24 +93,24 @@ func WithSkyWalkingAuth(ctx context.Context, username, password string) context. // The value is sourced from the CLI/config binding for `--sw-url`, // falling back to the built-in default when unset. func configuredSkyWalkingURL() string { - urlStr := viper.GetString("url") - if urlStr == "" { - urlStr = config.DefaultSWURL - } - normalizedURL, err := tools.NormalizeOAPURL(urlStr) + resolvedURL, err := resolvedConfiguredSkyWalkingURL() if err != nil { - return tools.FinalizeURL(urlStr) + logrus.WithError(err).Warn("invalid SkyWalking OAP URL configuration; falling back to default URL") + return config.DefaultSWURL } - return normalizedURL + return resolvedURL } -func validateConfiguredSkyWalkingURL() error { +func resolvedConfiguredSkyWalkingURL() (string, error) { urlStr := viper.GetString("url") if urlStr == "" { urlStr = config.DefaultSWURL } + return tools.NormalizeOAPURL(urlStr) +} - _, err := tools.NormalizeOAPURL(urlStr) +func validateConfiguredSkyWalkingURL() error { + _, err := resolvedConfiguredSkyWalkingURL() return err } diff --git a/internal/swmcp/server_test.go b/internal/swmcp/server_test.go index 5f969b1..8ec433e 100644 --- a/internal/swmcp/server_test.go +++ b/internal/swmcp/server_test.go @@ -53,6 +53,16 @@ func TestConfiguredSkyWalkingURLFinalizesConfiguredValue(t *testing.T) { } } +func TestConfiguredSkyWalkingURLFallsBackToDefaultOnInvalidValue(t *testing.T) { + t.Cleanup(viper.Reset) + viper.Set("url", "ftp://configured-oap.example.com:12800") + + got := configuredSkyWalkingURL() + if got != config.DefaultSWURL { + t.Fatalf("configuredSkyWalkingURL() = %q, want %q", got, config.DefaultSWURL) + } +} + func TestValidateConfiguredSkyWalkingURLRejectsUnsupportedScheme(t *testing.T) { t.Cleanup(viper.Reset) viper.Set("url", "ftp://configured-oap.example.com:12800") diff --git a/internal/tools/common.go b/internal/tools/common.go index eb031e8..83e3db2 100644 --- a/internal/tools/common.go +++ b/internal/tools/common.go @@ -45,27 +45,37 @@ const ( // FinalizeURL ensures the URL ends with "/graphql". func FinalizeURL(urlStr string) string { - if !strings.HasSuffix(urlStr, "/graphql") { - urlStr = strings.TrimRight(urlStr, "/") + "/graphql" + normalizedURL, err := NormalizeOAPURL(urlStr) + if err == nil { + return normalizedURL } return urlStr } -// NormalizeOAPURL appends the GraphQL path when needed and rejects unsupported URL schemes. +// NormalizeOAPURL parses and validates the OAP URL, then ensures the path ends with /graphql. func NormalizeOAPURL(rawURL string) (string, error) { - finalizedURL := FinalizeURL(rawURL) - if err := validateURLScheme(finalizedURL); err != nil { + u, err := url.Parse(rawURL) + if err != nil { + return "", fmt.Errorf("invalid OAP URL: %w", err) + } + if err := validateURLScheme(u); err != nil { return "", err } - return finalizedURL, nil + if u.Host == "" { + return "", fmt.Errorf("invalid OAP URL %q: host is required", rawURL) + } + + if u.Path == "" || u.Path == "/" { + u.Path = "/graphql" + } else if !strings.HasSuffix(u.Path, "/graphql") { + u.Path = strings.TrimRight(u.Path, "/") + "/graphql" + } + + return u.String(), nil } // validateURLScheme ensures the URL uses http or https. -func validateURLScheme(rawURL string) error { - u, err := url.Parse(rawURL) - if err != nil { - return fmt.Errorf("invalid OAP URL: %w", err) - } +func validateURLScheme(u *url.URL) error { if u.Scheme != "http" && u.Scheme != "https" { return fmt.Errorf("unsupported OAP URL scheme %q: only http and https are allowed", u.Scheme) } diff --git a/internal/tools/common_test.go b/internal/tools/common_test.go index e4a985a..29e06b7 100644 --- a/internal/tools/common_test.go +++ b/internal/tools/common_test.go @@ -40,6 +40,7 @@ func TestFinalizeURL(t *testing.T) { {name: "adds graphql suffix", in: "http://localhost:12800", want: "http://localhost:12800/graphql"}, {name: "trims trailing slash", in: "http://localhost:12800/", want: "http://localhost:12800/graphql"}, {name: "keeps existing graphql", in: "http://localhost:12800/graphql", want: "http://localhost:12800/graphql"}, + {name: "preserves query string", in: "http://localhost:12800?x=1", want: "http://localhost:12800/graphql?x=1"}, } for _, tc := range tests { @@ -60,7 +61,10 @@ func TestNormalizeOAPURL(t *testing.T) { }{ {name: "http", in: "http://localhost:12800", want: "http://localhost:12800/graphql"}, {name: "https", in: "https://localhost:12800/graphql", want: "https://localhost:12800/graphql"}, + {name: "preserves query and fragment", in: "https://localhost:12800/oap?debug=1#frag", want: "https://localhost:12800/oap/graphql?debug=1#frag"}, {name: "rejects unsupported scheme", in: "ftp://localhost:12800", wantErr: "unsupported OAP URL scheme \"ftp\""}, + {name: "rejects missing host", in: "http://", wantErr: "host is required"}, + {name: "rejects malformed hostless path", in: "http:/foo", wantErr: "host is required"}, } for _, tc := range tests { diff --git a/internal/tools/mqe.go b/internal/tools/mqe.go index e9f4e5c..032d456 100644 --- a/internal/tools/mqe.go +++ b/internal/tools/mqe.go @@ -91,9 +91,8 @@ func getContextBool(ctx context.Context, key any) bool { // executeGraphQLWithContext executes a GraphQL query using URL and auth from context. func executeGraphQLWithContext(ctx context.Context, query string, variables map[string]interface{}) (*GraphQLResponse, error) { rawURL := getContextString(ctx, contextkey.BaseURL{}) - rawURL = FinalizeURL(rawURL) - - if err := validateURLScheme(rawURL); err != nil { + normalizedURL, err := NormalizeOAPURL(rawURL) + if err != nil { return nil, err } @@ -107,7 +106,7 @@ func executeGraphQLWithContext(ctx context.Context, query string, variables map[ return nil, fmt.Errorf("failed to marshal GraphQL request: %w", err) } - req, err := http.NewRequestWithContext(ctx, "POST", rawURL, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, "POST", normalizedURL, bytes.NewBuffer(jsonData)) if err != nil { return nil, fmt.Errorf("failed to create HTTP request: %w", err) }