diff --git a/internal/analysis/analyzer.go b/internal/analysis/analyzer.go index 6c798fd..53789f5 100644 --- a/internal/analysis/analyzer.go +++ b/internal/analysis/analyzer.go @@ -9,33 +9,13 @@ import ( "io" "log/slog" "net/http" - "strings" "time" - "unicode" + "github.com/go-authgate/agent-scanner/internal/httperrors" "github.com/go-authgate/agent-scanner/internal/models" "github.com/go-authgate/agent-scanner/internal/tlsutil" ) -// clientError is a non-retryable HTTP error (4xx). -type clientError struct { - StatusCode int - Body string -} - -func (e *clientError) Error() string { - return fmt.Sprintf("status %d: %s", e.StatusCode, e.Body) -} - -// nonRetryableError wraps errors that should not be retried -// (e.g., request construction failures, JSON decode errors). -type nonRetryableError struct { - err error -} - -func (e *nonRetryableError) Error() string { return e.err.Error() } -func (e *nonRetryableError) Unwrap() error { return e.err } - // Analyzer performs security analysis on scan results. type Analyzer interface { Analyze(ctx context.Context, results []models.ScanPathResult) ([]models.ScanPathResult, error) @@ -144,12 +124,12 @@ func (a *remoteAnalyzer) analyzePathResult( break } // Do not retry non-retryable errors (bad URL, JSON decode, etc.) - var nre *nonRetryableError + var nre *httperrors.NonRetryableError if errors.As(err, &nre) { return fmt.Errorf("analysis API: %w", err) } // Do not retry client errors (4xx) - var ce *clientError + var ce *httperrors.ClientError if errors.As(err, &ce) { return fmt.Errorf("analysis API: %w", err) } @@ -182,7 +162,7 @@ func (a *remoteAnalyzer) doRequest(ctx context.Context, body []byte, resp *analy bytes.NewReader(body), ) if err != nil { - return &nonRetryableError{err: err} + return &httperrors.NonRetryableError{Err: err} } req.Header.Set("Content-Type", "application/json") @@ -194,30 +174,15 @@ func (a *remoteAnalyzer) doRequest(ctx context.Context, body []byte, resp *analy if httpResp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(httpResp.Body, 4096)) - bodySnippet := sanitizeBodySnippet(string(respBody), 512) + bodySnippet := httperrors.SanitizeBodySnippet(string(respBody), 512) if httpResp.StatusCode < 500 { - return &clientError{StatusCode: httpResp.StatusCode, Body: bodySnippet} + return &httperrors.ClientError{StatusCode: httpResp.StatusCode, Body: bodySnippet} } return fmt.Errorf("status %d: %s", httpResp.StatusCode, bodySnippet) } if err := json.NewDecoder(httpResp.Body).Decode(resp); err != nil { - return &nonRetryableError{err: fmt.Errorf("decode response: %w", err)} + return &httperrors.NonRetryableError{Err: fmt.Errorf("decode response: %w", err)} } return nil } - -// sanitizeBodySnippet truncates s to approximately maxLen bytes (the -// returned string may be slightly longer due to a " [truncated]" suffix) -// and replaces all Unicode control characters with spaces for safe single-line logging. -func sanitizeBodySnippet(s string, maxLen int) string { - if len(s) > maxLen { - s = s[:maxLen] + " [truncated]" - } - return strings.Map(func(r rune) rune { - if unicode.IsControl(r) { - return ' ' - } - return r - }, s) -} diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index a95023d..05ed53a 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -98,3 +98,83 @@ func TestParseControlServers_MoreServersThanIdentifiers(t *testing.T) { t.Errorf("server[2] Identifier = %q, want empty", servers[2].Identifier) } } + +func TestParseControlServers_WithHeaders(t *testing.T) { + orig := scanFlags + defer func() { scanFlags = orig }() + + scanFlags = ScanFlags{ + ControlServers: []string{"https://a.example.com", "https://b.example.com"}, + ControlHeaders: []string{"Authorization: Bearer token123; X-Custom: value"}, + } + + servers := parseControlServers() + if len(servers) != 2 { + t.Fatalf("expected 2 servers, got %d", len(servers)) + } + + if servers[0].Headers == nil { + t.Fatal("expected headers for server[0]") + } + if servers[0].Headers["Authorization"] != "Bearer token123" { + t.Errorf( + "Authorization = %q, want %q", + servers[0].Headers["Authorization"], + "Bearer token123", + ) + } + if servers[0].Headers["X-Custom"] != "value" { + t.Errorf("X-Custom = %q, want %q", servers[0].Headers["X-Custom"], "value") + } + + // server[1] has no matching header entry + if servers[1].Headers != nil { + t.Errorf("expected nil headers for server[1], got %v", servers[1].Headers) + } +} + +func TestParseHeaders(t *testing.T) { + tests := []struct { + name string + raw string + want map[string]string + }{ + { + name: "single header", + raw: "Authorization: Bearer abc", + want: map[string]string{"Authorization": "Bearer abc"}, + }, + { + name: "multiple headers", + raw: "Key1: Val1; Key2: Val2", + want: map[string]string{"Key1": "Val1", "Key2": "Val2"}, + }, + { + name: "empty string", + raw: "", + want: nil, + }, + { + name: "whitespace only", + raw: " ; ", + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseHeaders(tt.raw) + if tt.want == nil { + if got != nil { + t.Errorf("expected nil, got %v", got) + } + return + } + for k, v := range tt.want { + if got[k] != v { + t.Errorf("header %q = %q, want %q", k, got[k], v) + } + } + }) + } +} diff --git a/internal/cli/flags.go b/internal/cli/flags.go index 0dce9f4..85d889c 100644 --- a/internal/cli/flags.go +++ b/internal/cli/flags.go @@ -7,7 +7,6 @@ type CommonFlags struct { StorageFile string AnalysisURL string VerificationH []string - OAuthTokensPath string Verbose bool PrintErrors bool PrintFullDescs bool @@ -48,8 +47,6 @@ func addCommonFlags(cmd *cobra.Command) { cmd.Flags().StringVar(&commonFlags.AnalysisURL, "analysis-url", "", "Verification server URL") cmd.Flags(). StringSliceVar(&commonFlags.VerificationH, "verification-H", nil, "Additional headers for verification API") - cmd.Flags(). - StringVar(&commonFlags.OAuthTokensPath, "mcp-oauth-tokens-path", "", "OAuth token storage path") cmd.Flags().BoolVar(&commonFlags.Verbose, "verbose", false, "Enable verbose logging") cmd.Flags(). BoolVar(&commonFlags.PrintErrors, "print-errors", false, "Print server startup errors/tracebacks") diff --git a/internal/cli/scan.go b/internal/cli/scan.go index 8d0ac94..996abba 100644 --- a/internal/cli/scan.go +++ b/internal/cli/scan.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "os" + "strings" "time" "github.com/go-authgate/agent-scanner/internal/analysis" @@ -108,7 +109,30 @@ func parseControlServers() []pipeline.ControlServerConfig { if i < len(scanFlags.ControlIdentifier) { cs.Identifier = scanFlags.ControlIdentifier[i] } + if i < len(scanFlags.ControlHeaders) { + cs.Headers = parseHeaders(scanFlags.ControlHeaders[i]) + } servers = append(servers, cs) } return servers } + +// parseHeaders parses a semicolon-separated header string into a map. +// Each header is in "Key: Value" format. +func parseHeaders(raw string) map[string]string { + headers := make(map[string]string) + for part := range strings.SplitSeq(raw, ";") { + part = strings.TrimSpace(part) + if key, value, ok := strings.Cut(part, ":"); ok { + key = strings.TrimSpace(key) + value = strings.TrimSpace(value) + if key != "" { + headers[key] = value + } + } + } + if len(headers) == 0 { + return nil + } + return headers +} diff --git a/internal/httperrors/httperrors.go b/internal/httperrors/httperrors.go new file mode 100644 index 0000000..bd0b01f --- /dev/null +++ b/internal/httperrors/httperrors.go @@ -0,0 +1,41 @@ +package httperrors + +import ( + "fmt" + "strings" + "unicode" +) + +// ClientError is a non-retryable HTTP error (4xx). +type ClientError struct { + StatusCode int + Body string +} + +func (e *ClientError) Error() string { + return fmt.Sprintf("status %d: %s", e.StatusCode, e.Body) +} + +// NonRetryableError wraps errors that should not be retried +// (e.g., request construction failures, JSON decode errors). +type NonRetryableError struct { + Err error +} + +func (e *NonRetryableError) Error() string { return e.Err.Error() } +func (e *NonRetryableError) Unwrap() error { return e.Err } + +// SanitizeBodySnippet truncates s to approximately maxLen bytes (the +// returned string may be slightly longer due to a " [truncated]" suffix) +// and replaces all Unicode control characters with spaces for safe single-line logging. +func SanitizeBodySnippet(s string, maxLen int) string { + if len(s) > maxLen { + s = s[:maxLen] + " [truncated]" + } + return strings.Map(func(r rune) rune { + if unicode.IsControl(r) { + return ' ' + } + return r + }, s) +} diff --git a/internal/httperrors/httperrors_test.go b/internal/httperrors/httperrors_test.go new file mode 100644 index 0000000..cecc58d --- /dev/null +++ b/internal/httperrors/httperrors_test.go @@ -0,0 +1,77 @@ +package httperrors + +import ( + "errors" + "strings" + "testing" +) + +func TestClientError_Error(t *testing.T) { + err := &ClientError{StatusCode: 403, Body: "forbidden"} + got := err.Error() + if got != "status 403: forbidden" { + t.Errorf("got %q, want %q", got, "status 403: forbidden") + } +} + +func TestNonRetryableError_Unwrap(t *testing.T) { + inner := errors.New("bad request") + err := &NonRetryableError{Err: inner} + + if err.Error() != "bad request" { + t.Errorf("Error() = %q, want %q", err.Error(), "bad request") + } + if !errors.Is(err, inner) { + t.Error("expected errors.Is to find inner error") + } +} + +func TestSanitizeBodySnippet(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + checks func(t *testing.T, result string) + }{ + { + name: "short string unchanged", + input: "hello world", + maxLen: 100, + checks: func(t *testing.T, result string) { + if result != "hello world" { + t.Errorf("got %q", result) + } + }, + }, + { + name: "truncated", + input: "abcdefghij", + maxLen: 5, + checks: func(t *testing.T, result string) { + if !strings.HasPrefix(result, "abcde") { + t.Errorf("got %q, want prefix 'abcde'", result) + } + if !strings.Contains(result, "[truncated]") { + t.Error("expected [truncated] suffix") + } + }, + }, + { + name: "control chars replaced", + input: "line1\nline2\ttab\x00null", + maxLen: 100, + checks: func(t *testing.T, result string) { + if strings.ContainsAny(result, "\n\t\x00") { + t.Errorf("control characters not replaced: %q", result) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeBodySnippet(tt.input, tt.maxLen) + tt.checks(t, result) + }) + } +} diff --git a/internal/inspect/skill.go b/internal/inspect/skill.go index 5a03f37..d5cec79 100644 --- a/internal/inspect/skill.go +++ b/internal/inspect/skill.go @@ -11,6 +11,14 @@ import ( // ParseSkillDirectory reads a skill directory and returns its signature. func ParseSkillDirectory(path string) (*models.ServerSignature, error) { + // Resolve symlinks on the base path so traverseSkillTree can + // verify that all children stay within the real directory. + resolvedPath, err := filepath.EvalSymlinks(path) + if err != nil { + return nil, fmt.Errorf("resolve skill directory: %w", err) + } + path = resolvedPath + // Find SKILL.md skillMDPath := findSkillMD(path) if skillMDPath == "" { @@ -104,7 +112,15 @@ func parseSkillFrontmatter(content string) (string, string) { return name, description } +// isWithinBase checks that resolved stays within the resolvedBase directory. +func isWithinBase(resolvedBase, resolved string) bool { + // Ensure base ends with separator for prefix check + base := resolvedBase + string(filepath.Separator) + return resolved == resolvedBase || strings.HasPrefix(resolved, base) +} + // traverseSkillTree recursively scans a skill directory for entities. +// basePath must already be resolved via filepath.EvalSymlinks. func traverseSkillTree( basePath, relativePath string, ) ([]models.Prompt, []models.Resource, []models.Tool) { @@ -127,7 +143,16 @@ func traverseSkillTree( relPath := filepath.Join(relativePath, name) fullPath := filepath.Join(basePath, relPath) - if entry.IsDir() { + // Resolve symlinks and verify the target stays within basePath + resolved, err := filepath.EvalSymlinks(fullPath) + if err != nil { + continue + } + if !isWithinBase(basePath, resolved) { + continue + } + + if entry.IsDir() || (entry.Type()&os.ModeSymlink != 0 && isDir(resolved)) { p, r, t := traverseSkillTree(basePath, relPath) prompts = append(prompts, p...) resources = append(resources, r...) @@ -141,7 +166,7 @@ func traverseSkillTree( } ext := strings.ToLower(filepath.Ext(name)) - content, err := os.ReadFile(fullPath) + content, err := os.ReadFile(resolved) if err != nil { continue } @@ -169,6 +194,11 @@ func traverseSkillTree( return prompts, resources, tools } +func isDir(path string) bool { + fi, err := os.Stat(path) + return err == nil && fi.IsDir() +} + func guessMimeType(ext string) string { switch ext { case ".json": diff --git a/internal/inspect/skill_test.go b/internal/inspect/skill_test.go index f19feb4..af0b81f 100644 --- a/internal/inspect/skill_test.go +++ b/internal/inspect/skill_test.go @@ -3,6 +3,7 @@ package inspect import ( "os" "path/filepath" + "runtime" "testing" ) @@ -97,6 +98,13 @@ func TestParseSkillFrontmatter(t *testing.T) { func TestTraverseSkillTree(t *testing.T) { dir := t.TempDir() + // Resolve symlinks (e.g. macOS /var → /private/var) so basePath + // matches what ParseSkillDirectory would produce. + dir, err := filepath.EvalSymlinks(dir) + if err != nil { + t.Fatal(err) + } + // Create files of different types os.WriteFile(filepath.Join(dir, "SKILL.md"), []byte("---\nname: test\n---"), 0o644) os.WriteFile(filepath.Join(dir, "readme.md"), []byte("# Readme"), 0o644) @@ -115,3 +123,61 @@ func TestTraverseSkillTree(t *testing.T) { t.Errorf("expected 1 resource, got %d", len(resources)) } } + +func TestTraverseSkillTree_SymlinkOutsideBase(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlink test not supported on Windows") + } + + // Create a skill directory and an outside directory with a secret file + skillDir := t.TempDir() + outsideDir := t.TempDir() + + os.WriteFile(filepath.Join(outsideDir, "secret.py"), []byte("SECRET_KEY='abc'"), 0o644) + + // Create a symlink inside the skill directory pointing outside + symlinkPath := filepath.Join(skillDir, "escape") + if err := os.Symlink(outsideDir, symlinkPath); err != nil { + t.Fatal(err) + } + + // Resolve basePath like ParseSkillDirectory does + resolvedBase, err := filepath.EvalSymlinks(skillDir) + if err != nil { + t.Fatal(err) + } + + prompts, _, tools := traverseSkillTree(resolvedBase, "") + + // The symlinked directory should be skipped — no tools from outside + for _, tool := range tools { + if tool.Name == "escape/secret.py" || tool.Description == "SECRET_KEY='abc'" { + t.Error("symlinked file outside base was included as a tool") + } + } + for _, prompt := range prompts { + if prompt.Description == "SECRET_KEY='abc'" { + t.Error("symlinked file outside base was included as a prompt") + } + } +} + +func TestIsWithinBase(t *testing.T) { + tests := []struct { + base, path string + want bool + }{ + {"/a/b", "/a/b", true}, + {"/a/b", "/a/b/c", true}, + {"/a/b", "/a/b/c/d", true}, + {"/a/b", "/a/bc", false}, + {"/a/b", "/a", false}, + {"/a/b", "/x/y", false}, + } + for _, tt := range tests { + got := isWithinBase(tt.base, tt.path) + if got != tt.want { + t.Errorf("isWithinBase(%q, %q) = %v, want %v", tt.base, tt.path, got, tt.want) + } + } +} diff --git a/internal/mcpclient/client.go b/internal/mcpclient/client.go index 7634146..e0a6f7d 100644 --- a/internal/mcpclient/client.go +++ b/internal/mcpclient/client.go @@ -3,6 +3,7 @@ package mcpclient import ( "context" "fmt" + "time" "github.com/go-authgate/agent-scanner/internal/models" ) @@ -59,6 +60,6 @@ func (c *client) Connect( return nil, fmt.Errorf("connect transport: %w", err) } - session := NewSession(transport) + session := NewSession(transport, time.Duration(timeout)*time.Second) return session, nil } diff --git a/internal/mcpclient/http.go b/internal/mcpclient/http.go index 7d54f63..6d729b5 100644 --- a/internal/mcpclient/http.go +++ b/internal/mcpclient/http.go @@ -68,7 +68,7 @@ func NewHTTPTransport(server *models.RemoteServer, timeout int, skipSSLVerify bo } func (t *httpTransport) Connect(_ context.Context) error { - slog.Debug("HTTP transport ready", "url", t.server.URL) + slog.Debug("HTTP transport ready", "url", sanitizeURL(t.server.URL)) return nil } @@ -110,7 +110,7 @@ func (t *httpTransport) Send(ctx context.Context, msg *JSONRPCMessage) error { if strings.Contains(contentType, "text/event-stream") { if resp.StatusCode >= 400 { - body, _ := io.ReadAll(resp.Body) + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) resp.Body.Close() return fmt.Errorf("HTTP send: status %d: %s", resp.StatusCode, string(body)) } @@ -131,11 +131,11 @@ func (t *httpTransport) Send(ctx context.Context, msg *JSONRPCMessage) error { // Regular JSON response defer resp.Body.Close() if resp.StatusCode >= 400 { - body, _ := io.ReadAll(resp.Body) + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) return fmt.Errorf("HTTP send: status %d: %s", resp.StatusCode, string(body)) } - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024)) // 10MB max if err != nil { return fmt.Errorf("read response: %w", err) } diff --git a/internal/mcpclient/sanitize.go b/internal/mcpclient/sanitize.go new file mode 100644 index 0000000..b3edd4e --- /dev/null +++ b/internal/mcpclient/sanitize.go @@ -0,0 +1,37 @@ +package mcpclient + +import ( + "net/url" + + "github.com/go-authgate/agent-scanner/internal/redact" +) + +// sanitizeURL parses a URL and redacts query parameter values for safe logging. +func sanitizeURL(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil { + return rawURL + } + if u.RawQuery == "" { + return rawURL + } + q := u.Query() + for key := range q { + q.Set(key, redact.RedactedValue) + } + u.RawQuery = q.Encode() + return u.String() +} + +// sanitizeArgs returns a copy of args with path-like and secret-like values redacted. +func sanitizeArgs(args []string) []string { + out := make([]string, len(args)) + for i, arg := range args { + if redact.IsPath(arg) || redact.LooksLikeSecret(arg) { + out[i] = redact.RedactedValue + } else { + out[i] = arg + } + } + return out +} diff --git a/internal/mcpclient/sanitize_test.go b/internal/mcpclient/sanitize_test.go new file mode 100644 index 0000000..9f7147b --- /dev/null +++ b/internal/mcpclient/sanitize_test.go @@ -0,0 +1,88 @@ +package mcpclient + +import ( + "strings" + "testing" +) + +func TestSanitizeURL(t *testing.T) { + tests := []struct { + name string + input string + checks func(t *testing.T, result string) + }{ + { + name: "no query params", + input: "https://example.com/mcp", + checks: func(t *testing.T, result string) { + if result != "https://example.com/mcp" { + t.Errorf("got %s, want unchanged", result) + } + }, + }, + { + name: "query params redacted", + input: "https://example.com/mcp?token=secret123&key=abc", + checks: func(t *testing.T, result string) { + if strings.Contains(result, "secret123") { + t.Error("token value not redacted") + } + if strings.Contains(result, "abc") { + t.Error("key value not redacted") + } + if !strings.Contains(result, "token=") { + t.Error("token key should be preserved") + } + }, + }, + { + name: "invalid URL returned as-is", + input: "://invalid", + checks: func(t *testing.T, result string) { + if result != "://invalid" { + t.Errorf("got %s, want unchanged for invalid URL", result) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := sanitizeURL(tt.input) + tt.checks(t, result) + }) + } +} + +func TestSanitizeArgs(t *testing.T) { + args := []string{ + "--config", + "/home/user/secret.json", + "sk-abc123", + "--port", + "8080", + } + + result := sanitizeArgs(args) + + if result[0] != "--config" { + t.Errorf("expected --config, got %s", result[0]) + } + if result[1] != "**REDACTED**" { + t.Errorf("expected path redacted, got %s", result[1]) + } + if result[2] != "**REDACTED**" { + t.Errorf("expected secret redacted, got %s", result[2]) + } + if result[3] != "--port" { + t.Errorf("expected --port, got %s", result[3]) + } + if result[4] != "8080" { + t.Errorf("expected 8080, got %s", result[4]) + } + + // Verify original args not modified + if args[1] != "/home/user/secret.json" { + t.Error("original args should not be modified") + } +} diff --git a/internal/mcpclient/session.go b/internal/mcpclient/session.go index 7001dfe..9ca249f 100644 --- a/internal/mcpclient/session.go +++ b/internal/mcpclient/session.go @@ -21,20 +21,29 @@ type Session interface { Close() error } +const defaultCallTimeout = 30 * time.Second + type session struct { - transport Transport - nextID atomic.Int64 - pending map[int]chan *JSONRPCMessage - mu sync.Mutex - done chan struct{} + transport Transport + nextID atomic.Int64 + pending map[int]chan *JSONRPCMessage + mu sync.Mutex + done chan struct{} + callTimeout time.Duration } // NewSession creates a new MCP session from a connected transport. -func NewSession(transport Transport) Session { +// The timeout parameter controls how long each RPC call waits for a response. +// If timeout is <= 0, a default of 30 seconds is used. +func NewSession(transport Transport, timeout time.Duration) Session { + if timeout <= 0 { + timeout = defaultCallTimeout + } s := &session{ - transport: transport, - pending: make(map[int]chan *JSONRPCMessage), - done: make(chan struct{}), + transport: transport, + pending: make(map[int]chan *JSONRPCMessage), + done: make(chan struct{}), + callTimeout: timeout, } go s.readLoop() return s @@ -96,7 +105,7 @@ func (s *session) call(ctx context.Context, method string, params any) (*JSONRPC delete(s.pending, id) s.mu.Unlock() return nil, ctx.Err() - case <-time.After(30 * time.Second): + case <-time.After(s.callTimeout): s.mu.Lock() delete(s.pending, id) s.mu.Unlock() diff --git a/internal/mcpclient/sse.go b/internal/mcpclient/sse.go index 1df4b67..3a21f63 100644 --- a/internal/mcpclient/sse.go +++ b/internal/mcpclient/sse.go @@ -10,6 +10,7 @@ import ( "io" "log/slog" "net/http" + "net/url" "strings" "time" @@ -59,7 +60,7 @@ func (t *sseTransport) Connect(ctx context.Context) error { go t.readSSE(resp.Body) - slog.Debug("SSE transport connected", "url", t.server.URL) + slog.Debug("SSE transport connected", "url", sanitizeURL(t.server.URL)) return nil } @@ -97,20 +98,43 @@ func (t *sseTransport) readSSE(body io.ReadCloser) { } } +// resolveEndpointURL resolves and validates an endpoint URL received from an SSE +// server. Absolute URLs must share the same origin (scheme+host) as the base +// server URL to prevent SSRF. Relative URLs are resolved against the base. +func resolveEndpointURL(baseURL, endpoint string) (string, error) { + base, err := url.Parse(baseURL) + if err != nil { + return "", fmt.Errorf("parse base URL: %w", err) + } + + ep, err := url.Parse(endpoint) + if err != nil { + return "", fmt.Errorf("parse endpoint URL: %w", err) + } + + resolved := base.ResolveReference(ep) + + // Validate same origin (scheme + host) + if resolved.Scheme != base.Scheme || resolved.Host != base.Host { + return "", fmt.Errorf( + "endpoint origin %s://%s does not match server origin %s://%s", + resolved.Scheme, resolved.Host, base.Scheme, base.Host, + ) + } + + return resolved.String(), nil +} + func (t *sseTransport) handleSSEEvent(eventType, data string) { switch eventType { case "endpoint": - // The server sends the message endpoint URL - t.messageURL = data - if !strings.HasPrefix(t.messageURL, "http") { - // Relative URL — resolve against base - base := t.server.URL - if idx := strings.LastIndex(base, "/"); idx > 8 { // After "https://" - base = base[:idx] - } - t.messageURL = base + "/" + strings.TrimPrefix(t.messageURL, "/") + resolved, err := resolveEndpointURL(t.server.URL, data) + if err != nil { + slog.Warn("rejecting SSE endpoint", "error", err, "endpoint", data) + return } - slog.Debug("SSE endpoint received", "url", t.messageURL) + t.messageURL = resolved + slog.Debug("SSE endpoint received", "url", sanitizeURL(t.messageURL)) case "message", "": var msg JSONRPCMessage if err := json.Unmarshal([]byte(data), &msg); err != nil { @@ -152,7 +176,7 @@ func (t *sseTransport) Send(ctx context.Context, msg *JSONRPCMessage) error { defer resp.Body.Close() if resp.StatusCode >= 400 { - body, _ := io.ReadAll(resp.Body) + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) return fmt.Errorf("SSE send: status %d: %s", resp.StatusCode, string(body)) } diff --git a/internal/mcpclient/sse_test.go b/internal/mcpclient/sse_test.go new file mode 100644 index 0000000..0649632 --- /dev/null +++ b/internal/mcpclient/sse_test.go @@ -0,0 +1,82 @@ +package mcpclient + +import "testing" + +func TestResolveEndpointURL_RelativePath(t *testing.T) { + resolved, err := resolveEndpointURL("https://example.com/sse", "/messages?id=123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resolved != "https://example.com/messages?id=123" { + t.Errorf("got %s, want https://example.com/messages?id=123", resolved) + } +} + +func TestResolveEndpointURL_RelativePathNoLeadingSlash(t *testing.T) { + resolved, err := resolveEndpointURL("https://example.com/v1/sse", "messages?id=123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resolved != "https://example.com/v1/messages?id=123" { + t.Errorf("got %s, want https://example.com/v1/messages?id=123", resolved) + } +} + +func TestResolveEndpointURL_AbsoluteSameOrigin(t *testing.T) { + resolved, err := resolveEndpointURL( + "https://example.com/sse", + "https://example.com/messages?id=456", + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resolved != "https://example.com/messages?id=456" { + t.Errorf("got %s, want https://example.com/messages?id=456", resolved) + } +} + +func TestResolveEndpointURL_AbsoluteDifferentOrigin(t *testing.T) { + _, err := resolveEndpointURL( + "https://example.com/sse", + "https://evil.com/steal-data", + ) + if err == nil { + t.Fatal("expected error for different origin, got nil") + } +} + +func TestResolveEndpointURL_DifferentScheme(t *testing.T) { + _, err := resolveEndpointURL( + "https://example.com/sse", + "http://example.com/messages", + ) + if err == nil { + t.Fatal("expected error for different scheme, got nil") + } +} + +func TestResolveEndpointURL_PathTraversal(t *testing.T) { + // url.ResolveReference normalizes path traversal, so this should + // stay on the same origin (resolved to https://example.com/evil.com). + resolved, err := resolveEndpointURL( + "https://example.com/v1/sse", + "../../../evil.com", + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // The resolved URL should stay on example.com + if resolved != "https://example.com/evil.com" { + t.Errorf("got %s, want https://example.com/evil.com", resolved) + } +} + +func TestResolveEndpointURL_DifferentPort(t *testing.T) { + _, err := resolveEndpointURL( + "https://example.com:8443/sse", + "https://example.com:9999/messages", + ) + if err == nil { + t.Fatal("expected error for different port, got nil") + } +} diff --git a/internal/mcpclient/stdio.go b/internal/mcpclient/stdio.go index b600428..fa11e89 100644 --- a/internal/mcpclient/stdio.go +++ b/internal/mcpclient/stdio.go @@ -9,6 +9,7 @@ import ( "log/slog" "os/exec" "strings" + "sync" "github.com/go-authgate/agent-scanner/internal/models" ) @@ -20,7 +21,8 @@ type stdioTransport struct { stdout io.ReadCloser stderr io.ReadCloser recvCh chan *JSONRPCMessage - Stderr []string + mu sync.Mutex + lines []string } // NewStdioTransport creates a transport that communicates via subprocess stdio. @@ -75,7 +77,7 @@ func (t *stdioTransport) Connect(ctx context.Context) error { // Capture stderr in background go t.readStderr() - slog.Debug("stdio transport connected", "command", command, "args", args) + slog.Debug("stdio transport connected", "command", command, "args", sanitizeArgs(args)) return nil } @@ -97,15 +99,30 @@ func (t *stdioTransport) readStdout() { } } +const maxStderrLines = 10000 + func (t *stdioTransport) readStderr() { scanner := bufio.NewScanner(t.stderr) for scanner.Scan() { line := scanner.Text() - t.Stderr = append(t.Stderr, line) + t.mu.Lock() + if len(t.lines) < maxStderrLines { + t.lines = append(t.lines, line) + } + t.mu.Unlock() slog.Debug("server stderr", "line", line) } } +// GetStderr returns a copy of the captured stderr lines. +func (t *stdioTransport) GetStderr() []string { + t.mu.Lock() + defer t.mu.Unlock() + out := make([]string, len(t.lines)) + copy(out, t.lines) + return out +} + func (t *stdioTransport) Send(_ context.Context, msg *JSONRPCMessage) error { data, err := json.Marshal(msg) if err != nil { diff --git a/internal/redact/redact.go b/internal/redact/redact.go index 46ed92f..acd04a2 100644 --- a/internal/redact/redact.go +++ b/internal/redact/redact.go @@ -7,7 +7,8 @@ import ( "github.com/go-authgate/agent-scanner/internal/models" ) -const redactedValue = "**REDACTED**" +// RedactedValue is the placeholder used when redacting sensitive data. +const RedactedValue = "**REDACTED**" var absolutePathPatterns = []*regexp.Regexp{ regexp.MustCompile(`(?:/[a-zA-Z0-9._-]+){3,}`), // Unix paths @@ -20,7 +21,7 @@ var absolutePathPatterns = []*regexp.Regexp{ // AbsolutePaths replaces absolute file paths in text. func AbsolutePaths(text string) string { for _, pattern := range absolutePathPatterns { - text = pattern.ReplaceAllString(text, redactedValue) + text = pattern.ReplaceAllString(text, RedactedValue) } return text } @@ -31,22 +32,22 @@ func ServerResult(result *models.ServerScanResult) { case *models.StdioServer: // Redact environment variable values for k := range srv.Env { - srv.Env[k] = redactedValue + srv.Env[k] = RedactedValue } // Redact command arguments that look like paths or secrets for i, arg := range srv.Args { - if isPath(arg) || looksLikeSecret(arg) { - srv.Args[i] = redactedValue + if IsPath(arg) || LooksLikeSecret(arg) { + srv.Args[i] = RedactedValue } } case *models.RemoteServer: // Redact header values for k := range srv.Headers { - srv.Headers[k] = redactedValue + srv.Headers[k] = RedactedValue } // Redact URL query parameters if idx := strings.IndexByte(srv.URL, '?'); idx >= 0 { - srv.URL = srv.URL[:idx] + "?" + redactedValue + srv.URL = srv.URL[:idx] + "?" + RedactedValue } } @@ -66,7 +67,8 @@ func ScanPathResult(result *models.ScanPathResult) { } } -func isPath(arg string) bool { +// IsPath returns true if arg looks like an absolute or home-relative file path. +func IsPath(arg string) bool { if len(arg) == 0 { return false } @@ -74,12 +76,59 @@ func isPath(arg string) bool { (len(arg) >= 3 && arg[1] == ':' && (arg[2] == '\\' || arg[2] == '/')) } -func looksLikeSecret(arg string) bool { +// secretPrefixes lists known API key and token prefixes, pre-lowercased for +// efficient case-insensitive matching. +var secretPrefixes = []string{ + "sk-", // OpenAI / generic + "sk-ant-", // Anthropic + "ghp_", // GitHub personal access token + "gho_", // GitHub OAuth token + "github_pat_", // GitHub fine-grained PAT + "bearer ", // Authorization bearer token + "akia", // AWS access key ID + "xoxb-", // Slack bot token + "xoxp-", // Slack user token + "xapp-", // Slack app token + "xoxs-", // Slack session token + "glpat-", // GitLab personal access token + "npm_", // npm token + "pypi-", // PyPI token + "whsec_", // Stripe webhook secret + "sk_live_", // Stripe live secret key + "sk_test_", // Stripe test secret key + "rk_live_", // Stripe restricted key + "age-secret-key-", // age encryption key +} + +// LooksLikeSecret returns true if arg looks like an API key or secret token. +func LooksLikeSecret(arg string) bool { lower := strings.ToLower(arg) - for _, prefix := range []string{"sk-", "ghp_", "gho_", "github_pat_", "Bearer "} { - if strings.HasPrefix(lower, strings.ToLower(prefix)) { + for _, prefix := range secretPrefixes { + if strings.HasPrefix(lower, prefix) { return true } } + + // High-entropy heuristic: long strings with mixed character classes + if len(arg) > 20 && !strings.Contains(arg, " ") && looksHighEntropy(arg) { + return true + } + return false } + +// looksHighEntropy returns true if s contains a mix of uppercase, lowercase, and digits. +func looksHighEntropy(s string) bool { + var hasUpper, hasLower, hasDigit bool + for _, c := range s { + switch { + case c >= 'A' && c <= 'Z': + hasUpper = true + case c >= 'a' && c <= 'z': + hasLower = true + case c >= '0' && c <= '9': + hasDigit = true + } + } + return hasUpper && hasLower && hasDigit +} diff --git a/internal/redact/redact_test.go b/internal/redact/redact_test.go index 33513d1..d20c829 100644 --- a/internal/redact/redact_test.go +++ b/internal/redact/redact_test.go @@ -12,9 +12,9 @@ func TestAbsolutePaths(t *testing.T) { input string contains string }{ - {"/Users/john/Documents/secret.txt", redactedValue}, - {"C:\\Users\\john\\secret.txt", redactedValue}, - {"~/Documents/secret.txt", redactedValue}, + {"/Users/john/Documents/secret.txt", RedactedValue}, + {"C:\\Users\\john\\secret.txt", RedactedValue}, + {"~/Documents/secret.txt", RedactedValue}, } for _, tt := range tests { @@ -42,10 +42,10 @@ func TestServerResult_Stdio(t *testing.T) { ServerResult(result) stdio := result.Server.(*models.StdioServer) - if stdio.Env["API_KEY"] != redactedValue { + if stdio.Env["API_KEY"] != RedactedValue { t.Errorf("expected env redacted, got %s", stdio.Env["API_KEY"]) } - if stdio.Args[1] != redactedValue { + if stdio.Args[1] != RedactedValue { t.Errorf("expected path arg redacted, got %s", stdio.Args[1]) } } @@ -61,7 +61,7 @@ func TestServerResult_Remote(t *testing.T) { ServerResult(result) remote := result.Server.(*models.RemoteServer) - if remote.Headers["Authorization"] != redactedValue { + if remote.Headers["Authorization"] != RedactedValue { t.Errorf("expected header redacted, got %s", remote.Headers["Authorization"]) } if strings.Contains(remote.URL, "secret123") { @@ -83,8 +83,70 @@ func TestIsPath(t *testing.T) { } for _, tt := range tests { - if got := isPath(tt.arg); got != tt.expected { - t.Errorf("isPath(%q) = %v, want %v", tt.arg, got, tt.expected) + if got := IsPath(tt.arg); got != tt.expected { + t.Errorf("IsPath(%q) = %v, want %v", tt.arg, got, tt.expected) } } } + +func TestLooksLikeSecret(t *testing.T) { + positives := []string{ + "sk-abc123", // OpenAI + "sk-ant-api03-abc", // Anthropic + "ghp_abcdef1234567890", // GitHub PAT + "gho_token", // GitHub OAuth + "github_pat_abc", // GitHub fine-grained + "Bearer my-token", // Bearer token + "AKIAIOSFODNN7EXAMPLE", // AWS access key + "xoxb-slack-bot-token", // Slack bot + "xoxp-slack-user-token", // Slack user + "xapp-slack-app-token", // Slack app + "glpat-xxxxxxxxxxxx", // GitLab PAT + "npm_xxxxxxxx", // npm token + "pypi-AgEIcHlwaS5vcmc", // PyPI token + "whsec_abcdef123456", // Stripe webhook + "sk_live_abc123", // Stripe live key + "sk_test_abc123", // Stripe test key + "rk_live_abc123", // Stripe restricted + "AGE-SECRET-KEY-1abc", // age key + } + + for _, s := range positives { + if !LooksLikeSecret(s) { + t.Errorf("LooksLikeSecret(%q) = false, want true", s) + } + } + + negatives := []string{ + "--port", + "8080", + "localhost", + "my-server", + "true", + "", + "short", + } + + for _, s := range negatives { + if LooksLikeSecret(s) { + t.Errorf("LooksLikeSecret(%q) = true, want false", s) + } + } +} + +func TestLooksLikeSecret_HighEntropy(t *testing.T) { + // Long mixed-case alphanumeric string should be detected + if !LooksLikeSecret("aB3cD4eF5gH6iJ7kL8mN9oP") { + t.Error("expected high-entropy string to be detected as secret") + } + + // Short string should not trigger entropy heuristic + if LooksLikeSecret("aB3c") { + t.Error("short mixed-case string should not be detected") + } + + // String with spaces should not trigger + if LooksLikeSecret("this Is A Regular Sentence 123") { + t.Error("string with spaces should not trigger entropy heuristic") + } +} diff --git a/internal/upload/uploader.go b/internal/upload/uploader.go index b931349..09af66a 100644 --- a/internal/upload/uploader.go +++ b/internal/upload/uploader.go @@ -11,34 +11,14 @@ import ( "net/http" "os" "os/user" - "strings" "time" - "unicode" + "github.com/go-authgate/agent-scanner/internal/httperrors" "github.com/go-authgate/agent-scanner/internal/models" "github.com/go-authgate/agent-scanner/internal/redact" "github.com/go-authgate/agent-scanner/internal/version" ) -// clientError is a non-retryable HTTP error (4xx). -type clientError struct { - StatusCode int - Body string -} - -func (e *clientError) Error() string { - return fmt.Sprintf("status %d: %s", e.StatusCode, e.Body) -} - -// nonRetryableError wraps errors that should not be retried -// (e.g., request construction failures). -type nonRetryableError struct { - err error -} - -func (e *nonRetryableError) Error() string { return e.err.Error() } -func (e *nonRetryableError) Unwrap() error { return e.err } - // Uploader pushes scan results to control servers. type Uploader interface { Upload(ctx context.Context, results []models.ScanPathResult, server models.ControlServer) error @@ -98,11 +78,11 @@ func (u *uploader) Upload( return nil } // Do not retry client errors (4xx) or non-retryable errors (e.g., bad URL) - var nre *nonRetryableError + var nre *httperrors.NonRetryableError if errors.As(err, &nre) { return fmt.Errorf("upload failed: %w", err) } - var ce *clientError + var ce *httperrors.ClientError if errors.As(err, &ce) { return fmt.Errorf( "upload failed due to non-retryable client error after %d attempt(s): %w", @@ -127,7 +107,7 @@ func (u *uploader) Upload( func (u *uploader) doUpload(ctx context.Context, server models.ControlServer, body []byte) error { req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, bytes.NewReader(body)) if err != nil { - return &nonRetryableError{err: err} + return &httperrors.NonRetryableError{Err: err} } req.Header.Set("Content-Type", "application/json") for k, v := range server.Headers { @@ -142,9 +122,9 @@ func (u *uploader) doUpload(ctx context.Context, server models.ControlServer, bo if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - bodySnippet := sanitizeBodySnippet(string(respBody), 512) + bodySnippet := httperrors.SanitizeBodySnippet(string(respBody), 512) if resp.StatusCode < 500 { - return &clientError{StatusCode: resp.StatusCode, Body: bodySnippet} + return &httperrors.ClientError{StatusCode: resp.StatusCode, Body: bodySnippet} } return fmt.Errorf("status %d: %s", resp.StatusCode, bodySnippet) } @@ -173,18 +153,3 @@ func getUsername() string { } return u.Username } - -// sanitizeBodySnippet truncates s to approximately maxLen bytes (the -// returned string may be slightly longer due to a " [truncated]" suffix) -// and replaces all Unicode control characters with spaces for safe single-line logging. -func sanitizeBodySnippet(s string, maxLen int) string { - if len(s) > maxLen { - s = s[:maxLen] + " [truncated]" - } - return strings.Map(func(r rune) rune { - if unicode.IsControl(r) { - return ' ' - } - return r - }, s) -} diff --git a/internal/upload/uploader_test.go b/internal/upload/uploader_test.go index 539484d..28f8b43 100644 --- a/internal/upload/uploader_test.go +++ b/internal/upload/uploader_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/go-authgate/agent-scanner/internal/httperrors" "github.com/go-authgate/agent-scanner/internal/models" ) @@ -118,10 +119,10 @@ func TestUpload_4xxNoRetry(t *testing.T) { t.Errorf("expected exactly 1 request (no retry on 4xx), got %d", count) } - // Verify it's a clientError - var ce *clientError + // Verify it's a ClientError + var ce *httperrors.ClientError if !errors.As(err, &ce) { - t.Errorf("expected clientError in chain, got %T: %v", err, err) + t.Errorf("expected ClientError in chain, got %T: %v", err, err) } }