diff --git a/internal/analysis/analyzer.go b/internal/analysis/analyzer.go index 0abc7d1..6c798fd 100644 --- a/internal/analysis/analyzer.go +++ b/internal/analysis/analyzer.go @@ -4,16 +4,38 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "log/slog" "net/http" + "strings" "time" + "unicode" "github.com/go-authgate/agent-scanner/internal/models" "github.com/go-authgate/agent-scanner/internal/tlsutil" ) +// clientError is a non-retryable HTTP error (4xx). +type clientError struct { + StatusCode int + Body string +} + +func (e *clientError) Error() string { + return fmt.Sprintf("status %d: %s", e.StatusCode, e.Body) +} + +// nonRetryableError wraps errors that should not be retried +// (e.g., request construction failures, JSON decode errors). +type nonRetryableError struct { + err error +} + +func (e *nonRetryableError) Error() string { return e.err.Error() } +func (e *nonRetryableError) Unwrap() error { return e.err } + // Analyzer performs security analysis on scan results. type Analyzer interface { Analyze(ctx context.Context, results []models.ScanPathResult) ([]models.ScanPathResult, error) @@ -113,7 +135,7 @@ func (a *remoteAnalyzer) analyzePathResult( return fmt.Errorf("marshal request: %w", err) } - // Retry with exponential backoff + // Retry with exponential backoff (only retry on 5xx / network errors) var resp analysisResponse maxRetries := 3 for attempt := range maxRetries { @@ -121,6 +143,16 @@ func (a *remoteAnalyzer) analyzePathResult( if err == nil { break } + // Do not retry non-retryable errors (bad URL, JSON decode, etc.) + var nre *nonRetryableError + if errors.As(err, &nre) { + return fmt.Errorf("analysis API: %w", err) + } + // Do not retry client errors (4xx) + var ce *clientError + if errors.As(err, &ce) { + return fmt.Errorf("analysis API: %w", err) + } if attempt < maxRetries-1 { backoff := time.Duration(1<= 400 { - respBody, _ := io.ReadAll(httpResp.Body) - return fmt.Errorf("status %d: %s", httpResp.StatusCode, string(respBody)) + respBody, _ := io.ReadAll(io.LimitReader(httpResp.Body, 4096)) + bodySnippet := sanitizeBodySnippet(string(respBody), 512) + if httpResp.StatusCode < 500 { + return &clientError{StatusCode: httpResp.StatusCode, Body: bodySnippet} + } + return fmt.Errorf("status %d: %s", httpResp.StatusCode, bodySnippet) + } + + if err := json.NewDecoder(httpResp.Body).Decode(resp); err != nil { + return &nonRetryableError{err: fmt.Errorf("decode response: %w", err)} } + return nil +} - return json.NewDecoder(httpResp.Body).Decode(resp) +// sanitizeBodySnippet truncates s to approximately maxLen bytes (the +// returned string may be slightly longer due to a " [truncated]" suffix) +// and replaces all Unicode control characters with spaces for safe single-line logging. +func sanitizeBodySnippet(s string, maxLen int) string { + if len(s) > maxLen { + s = s[:maxLen] + " [truncated]" + } + return strings.Map(func(r rune) rune { + if unicode.IsControl(r) { + return ' ' + } + return r + }, s) } diff --git a/internal/analysis/analyzer_test.go b/internal/analysis/analyzer_test.go new file mode 100644 index 0000000..e96604a --- /dev/null +++ b/internal/analysis/analyzer_test.go @@ -0,0 +1,335 @@ +package analysis + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/go-authgate/agent-scanner/internal/models" +) + +func TestAnalyze_EmptyURL(t *testing.T) { + var requestMade atomic.Bool + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestMade.Store(true) + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + a := NewAnalyzer("", false) + results := []models.ScanPathResult{ + { + Client: "test", + Path: "/p", + Servers: []models.ServerScanResult{ + { + Name: "srv", + Server: &models.StdioServer{Command: "test"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{{Name: "t1", Description: "desc"}}, + }, + }, + }, + }, + } + + out, err := a.Analyze(context.Background(), results) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if requestMade.Load() { + t.Error("expected no HTTP request when analysis URL is empty") + } + if len(out) != 1 { + t.Errorf("expected results returned unchanged, got %d", len(out)) + } +} + +func TestAnalyze_Success(t *testing.T) { + respData := analysisResponse{ + Issues: []models.Issue{ + {Code: "E001", Message: "remote issue"}, + }, + Labels: [][]models.ScalarToolLabels{ + {{IsPublicSink: 0.9, Destructive: 0.1}}, + }, + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("expected Content-Type application/json, got %s", ct) + } + + // Verify request body + var req analysisRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Errorf("failed to decode request body: %v", err) + } + if len(req.Servers) != 1 { + t.Errorf("expected 1 server in request, got %d", len(req.Servers)) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(respData) + })) + defer ts.Close() + + a := NewAnalyzer(ts.URL, false) + results := []models.ScanPathResult{ + { + Client: "test", + Path: "/p", + Servers: []models.ServerScanResult{ + { + Name: "srv1", + Server: &models.StdioServer{Command: "cmd"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{ + {Name: "tool1", Description: "a tool"}, + }, + }, + }, + }, + Issues: []models.Issue{ + {Code: "W001", Message: "existing issue"}, + }, + }, + } + + out, err := a.Analyze(context.Background(), results) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(out) != 1 { + t.Fatalf("expected 1 result, got %d", len(out)) + } + + // Check that remote issues were merged in + if len(out[0].Issues) != 2 { + t.Fatalf("expected 2 issues (1 existing + 1 remote), got %d", len(out[0].Issues)) + } + if out[0].Issues[0].Code != "W001" { + t.Errorf("expected first issue code W001, got %s", out[0].Issues[0].Code) + } + if out[0].Issues[1].Code != "E001" { + t.Errorf("expected second issue code E001, got %s", out[0].Issues[1].Code) + } + + // Check labels were merged + if len(out[0].Labels) != 1 { + t.Fatalf("expected 1 label set, got %d", len(out[0].Labels)) + } + if out[0].Labels[0][0].IsPublicSink != 0.9 { + t.Errorf("expected IsPublicSink 0.9, got %f", out[0].Labels[0][0].IsPublicSink) + } +} + +func TestAnalyze_NilSignatureSkipped(t *testing.T) { + var requestMade atomic.Bool + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestMade.Store(true) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(analysisResponse{}) + })) + defer ts.Close() + + a := NewAnalyzer(ts.URL, false) + results := []models.ScanPathResult{ + { + Client: "test", + Path: "/p", + Servers: []models.ServerScanResult{ + { + Name: "srv-no-sig", + Server: &models.StdioServer{Command: "cmd"}, + Signature: nil, // nil signature should be skipped + }, + }, + }, + } + + out, err := a.Analyze(context.Background(), results) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if requestMade.Load() { + t.Error("expected no HTTP request when all signatures are nil") + } + if len(out) != 1 { + t.Errorf("expected results returned, got %d", len(out)) + } +} + +func TestAnalyze_4xxNoRetry(t *testing.T) { + var requestCount atomic.Int32 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount.Add(1) + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("bad request")) + })) + defer ts.Close() + + a := NewAnalyzer(ts.URL, false) + results := []models.ScanPathResult{ + { + Client: "test", + Path: "/p", + Servers: []models.ServerScanResult{ + { + Name: "srv", + Server: &models.StdioServer{Command: "cmd"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{{Name: "t", Description: "d"}}, + }, + }, + }, + }, + } + + _, err := a.Analyze(context.Background(), results) + if err != nil { + t.Fatalf("Analyze should not return error (it logs warning instead), got: %v", err) + } + + count := requestCount.Load() + if count != 1 { + t.Errorf("expected exactly 1 request (no retry on 4xx), got %d", count) + } +} + +func TestAnalyze_5xxRetries(t *testing.T) { + var requestCount atomic.Int32 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount.Add(1) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("internal server error")) + })) + defer ts.Close() + + a := NewAnalyzer(ts.URL, false) + results := []models.ScanPathResult{ + { + Client: "test", + Path: "/p", + Servers: []models.ServerScanResult{ + { + Name: "srv", + Server: &models.StdioServer{Command: "cmd"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{{Name: "t", Description: "d"}}, + }, + }, + }, + }, + } + + // Use a context with a deadline so we don't wait for full backoff. + // First request is immediate, then 1s backoff, then 2s backoff. + // With a 4s timeout we should reliably get at least 2 attempts without flakiness. + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + + _, err := a.Analyze(ctx, results) + // Analyze doesn't return error even on failure (it logs a warning) + if err != nil { + t.Fatalf("Analyze should not return error, got: %v", err) + } + + count := requestCount.Load() + if count < 2 { + t.Errorf("expected at least 2 requests (retries on 5xx), got %d", count) + } +} + +func TestAnalyze_FailureDoesNotFailOverall(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("error")) + })) + defer ts.Close() + + a := NewAnalyzer(ts.URL, false) + results := []models.ScanPathResult{ + { + Client: "test", + Path: "/p1", + Servers: []models.ServerScanResult{ + { + Name: "srv", + Server: &models.StdioServer{Command: "cmd"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{{Name: "t", Description: "d"}}, + }, + }, + }, + Issues: []models.Issue{ + {Code: "W001", Message: "existing"}, + }, + }, + } + + out, err := a.Analyze(context.Background(), results) + if err != nil { + t.Fatalf("expected no error from Analyze even when analysis fails, got: %v", err) + } + + // Results should still be returned + if len(out) != 1 { + t.Fatalf("expected 1 result, got %d", len(out)) + } + // Existing issues should be preserved + if len(out[0].Issues) != 1 { + t.Errorf("expected existing issues preserved, got %d", len(out[0].Issues)) + } +} + +func TestAnalyze_AllNilSignatures(t *testing.T) { + var requestMade atomic.Bool + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestMade.Store(true) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(analysisResponse{}) + })) + defer ts.Close() + + a := NewAnalyzer(ts.URL, false) + results := []models.ScanPathResult{ + { + Client: "test", + Path: "/p", + Servers: []models.ServerScanResult{ + { + Name: "srv1", + Server: &models.StdioServer{Command: "cmd1"}, + Signature: nil, + }, + { + Name: "srv2", + Server: &models.StdioServer{Command: "cmd2"}, + Signature: nil, + }, + }, + }, + } + + out, err := a.Analyze(context.Background(), results) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if requestMade.Load() { + t.Error("expected no HTTP request when all signatures are nil") + } + if len(out) != 1 { + t.Errorf("expected results returned unchanged, got %d", len(out)) + } +} diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go new file mode 100644 index 0000000..a95023d --- /dev/null +++ b/internal/cli/cli_test.go @@ -0,0 +1,100 @@ +package cli + +import ( + "testing" +) + +func TestParseControlServers_Empty(t *testing.T) { + // Save and restore original scanFlags. + orig := scanFlags + defer func() { scanFlags = orig }() + + scanFlags = ScanFlags{} + + servers := parseControlServers() + if servers != nil { + t.Errorf("expected nil, got %v", servers) + } +} + +func TestParseControlServers_OneServer_NoIdentifier(t *testing.T) { + orig := scanFlags + defer func() { scanFlags = orig }() + + scanFlags = ScanFlags{ + ControlServers: []string{"https://control.example.com"}, + } + + servers := parseControlServers() + if len(servers) != 1 { + t.Fatalf("expected 1 server, got %d", len(servers)) + } + if servers[0].URL != "https://control.example.com" { + t.Errorf("unexpected URL: %s", servers[0].URL) + } + if servers[0].Identifier != "" { + t.Errorf("expected empty identifier, got %q", servers[0].Identifier) + } +} + +func TestParseControlServers_MatchingServersAndIdentifiers(t *testing.T) { + orig := scanFlags + defer func() { scanFlags = orig }() + + scanFlags = ScanFlags{ + ControlServers: []string{"https://a.example.com", "https://b.example.com"}, + ControlIdentifier: []string{"id-a", "id-b"}, + } + + servers := parseControlServers() + if len(servers) != 2 { + t.Fatalf("expected 2 servers, got %d", len(servers)) + } + + if servers[0].URL != "https://a.example.com" { + t.Errorf("server[0] URL = %q, want %q", servers[0].URL, "https://a.example.com") + } + if servers[0].Identifier != "id-a" { + t.Errorf("server[0] Identifier = %q, want %q", servers[0].Identifier, "id-a") + } + if servers[1].URL != "https://b.example.com" { + t.Errorf("server[1] URL = %q, want %q", servers[1].URL, "https://b.example.com") + } + if servers[1].Identifier != "id-b" { + t.Errorf("server[1] Identifier = %q, want %q", servers[1].Identifier, "id-b") + } +} + +func TestParseControlServers_MoreServersThanIdentifiers(t *testing.T) { + orig := scanFlags + defer func() { scanFlags = orig }() + + scanFlags = ScanFlags{ + ControlServers: []string{ + "https://a.example.com", + "https://b.example.com", + "https://c.example.com", + }, + ControlIdentifier: []string{"id-a"}, + } + + servers := parseControlServers() + if len(servers) != 3 { + t.Fatalf("expected 3 servers, got %d", len(servers)) + } + + if servers[0].URL != "https://a.example.com" { + t.Errorf("server[0] URL = %q, want %q", servers[0].URL, "https://a.example.com") + } + if servers[0].Identifier != "id-a" { + t.Errorf("server[0] Identifier = %q, want %q", servers[0].Identifier, "id-a") + } + + // servers[1] and servers[2] have index >= len(ControlIdentifier), so identifier should be empty. + if servers[1].Identifier != "" { + t.Errorf("server[1] Identifier = %q, want empty", servers[1].Identifier) + } + if servers[2].Identifier != "" { + t.Errorf("server[2] Identifier = %q, want empty", servers[2].Identifier) + } +} diff --git a/internal/output/output_test.go b/internal/output/output_test.go new file mode 100644 index 0000000..5678daf --- /dev/null +++ b/internal/output/output_test.go @@ -0,0 +1,291 @@ +package output + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/go-authgate/agent-scanner/internal/models" +) + +// --------------------------------------------------------------------------- +// TextFormatter tests +// --------------------------------------------------------------------------- + +func TestTextFormatter_EmptyResults(t *testing.T) { + var buf bytes.Buffer + f := NewTextFormatter(&buf) + + if err := f.FormatResults(nil, FormatOptions{}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + + got := buf.String() + if !strings.Contains(got, "No MCP configurations found.") { + t.Errorf("expected 'No MCP configurations found.' in output, got:\n%s", got) + } +} + +func TestTextFormatter_ServersAndIssues(t *testing.T) { + results := []models.ScanPathResult{ + { + Client: "cursor", + Path: "/home/user/.cursor/mcp.json", + Servers: []models.ServerScanResult{ + { + Name: "my-server", + Server: &models.StdioServer{Command: "node"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{ + {Name: "read_file", Description: "Reads a file"}, + {Name: "write_file", Description: "Writes a file"}, + }, + }, + }, + }, + Issues: []models.Issue{ + {Code: "E001", Message: "Prompt injection detected"}, + }, + }, + } + + var buf bytes.Buffer + f := NewTextFormatter(&buf) + if err := f.FormatResults(results, FormatOptions{}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + out := buf.String() + + // Status icon for issues (E001 is high severity -> redCross) + if !strings.Contains(out, redCross) { + t.Error("expected red cross icon in output for high-severity issue") + } + // Server name + if !strings.Contains(out, "my-server") { + t.Error("expected server name 'my-server' in output") + } + // Entity count + if !strings.Contains(out, "2 entities") { + t.Error("expected '2 entities' in output") + } + // Issue code + if !strings.Contains(out, "E001") { + t.Error("expected issue code 'E001' in output") + } +} + +func TestTextFormatter_Summary(t *testing.T) { + results := []models.ScanPathResult{ + { + Client: "claude", + Path: "/path/a", + Servers: []models.ServerScanResult{ + { + Name: "server-a", + Server: &models.StdioServer{Command: "a"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{{Name: "t1"}}, + }, + }, + { + Name: "server-b", + Server: &models.StdioServer{Command: "b"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{{Name: "t2"}, {Name: "t3"}}, + Prompts: []models.Prompt{{Name: "p1"}}, + }, + }, + }, + Issues: []models.Issue{ + {Code: "E002", Message: "cross-server ref"}, // high + {Code: "W001", Message: "suspicious trigger"}, // medium (warning) + }, + }, + } + + var buf bytes.Buffer + f := NewTextFormatter(&buf) + if err := f.FormatResults(results, FormatOptions{}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + out := buf.String() + + // Summary should show 2 servers, 4 entities (1 + 2 tools + 1 prompt) + if !strings.Contains(out, "2 server(s)") { + t.Errorf("expected '2 server(s)' in summary, got:\n%s", out) + } + if !strings.Contains(out, "4 entities") { + t.Errorf("expected '4 entities' in summary, got:\n%s", out) + } + if !strings.Contains(out, "1 issue(s) found") { + t.Errorf("expected '1 issue(s) found' in summary, got:\n%s", out) + } + if !strings.Contains(out, "1 warning(s)") { + t.Errorf("expected '1 warning(s)' in summary, got:\n%s", out) + } +} + +func TestTextFormatter_NilSignature(t *testing.T) { + results := []models.ScanPathResult{ + { + Client: "vscode", + Path: "/path/to/config", + Servers: []models.ServerScanResult{ + { + Name: "dead-server", + Server: &models.StdioServer{Command: "dead"}, + Signature: nil, + }, + }, + }, + } + + var buf bytes.Buffer + f := NewTextFormatter(&buf) + if err := f.FormatResults(results, FormatOptions{}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + out := buf.String() + + if !strings.Contains(out, "(no response)") { + t.Errorf("expected '(no response)' for nil signature, got:\n%s", out) + } +} + +func TestTextFormatter_ServerError_PrintErrors(t *testing.T) { + results := []models.ScanPathResult{ + { + Client: "windsurf", + Path: "/path/to/config", + Servers: []models.ServerScanResult{ + { + Name: "broken-server", + Server: &models.StdioServer{Command: "broken"}, + Error: &models.ScanError{ + Message: "connection refused", + Category: models.ErrCatServerStartup, + }, + }, + }, + }, + } + + // Without PrintErrors + var buf bytes.Buffer + f := NewTextFormatter(&buf) + if err := f.FormatResults(results, FormatOptions{PrintErrors: false}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + out := buf.String() + if strings.Contains(out, "connection refused") { + t.Error("expected error message to be hidden when PrintErrors=false") + } + + // With PrintErrors + buf.Reset() + f = NewTextFormatter(&buf) + if err := f.FormatResults(results, FormatOptions{PrintErrors: true}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + out = buf.String() + if !strings.Contains(out, "connection refused") { + t.Errorf("expected error message 'connection refused' when PrintErrors=true, got:\n%s", out) + } +} + +func TestTextFormatter_TruncatesLongDescriptions(t *testing.T) { + longDesc := strings.Repeat("A", 300) + results := []models.ScanPathResult{ + { + Client: "cursor", + Path: "/path/to/config", + Servers: []models.ServerScanResult{ + { + Name: "long-desc-server", + Server: &models.StdioServer{Command: "node"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{ + {Name: "verbose-tool", Description: longDesc}, + }, + }, + }, + }, + }, + } + + var buf bytes.Buffer + f := NewTextFormatter(&buf) + if err := f.FormatResults(results, FormatOptions{PrintFullDescs: false}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + out := buf.String() + + // Should be truncated to 200 chars + "..." + if !strings.Contains(out, "...") { + t.Error("expected truncation ellipsis '...' in output") + } + // The full 300-char string should NOT appear + if strings.Contains(out, longDesc) { + t.Error("expected description to be truncated, but found full 300-char string") + } + // First 200 chars should appear + truncated := longDesc[:200] + if !strings.Contains(out, truncated) { + t.Error("expected first 200 chars of description to appear") + } +} + +// --------------------------------------------------------------------------- +// JSONFormatter tests +// --------------------------------------------------------------------------- + +func TestJSONFormatter_ValidJSON(t *testing.T) { + results := []models.ScanPathResult{ + { + Client: "claude", + Path: "/path/to/config", + Servers: []models.ServerScanResult{ + { + Name: "test-server", + Server: &models.StdioServer{Command: "node"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{ + {Name: "hello", Description: "says hello"}, + }, + }, + }, + }, + Issues: []models.Issue{ + {Code: "W001", Message: "suspicious words"}, + }, + }, + } + + var buf bytes.Buffer + f := NewJSONFormatter(&buf) + if err := f.FormatResults(results, FormatOptions{}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + + var parsed []map[string]any + if err := json.Unmarshal(buf.Bytes(), &parsed); err != nil { + t.Fatalf("output is not valid JSON: %v\nraw output:\n%s", err, buf.String()) + } + if len(parsed) != 1 { + t.Errorf("expected 1 result in JSON array, got %d", len(parsed)) + } +} + +func TestJSONFormatter_EmptyResults(t *testing.T) { + var buf bytes.Buffer + f := NewJSONFormatter(&buf) + if err := f.FormatResults([]models.ScanPathResult{}, FormatOptions{}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + + got := strings.TrimSpace(buf.String()) + if got != "[]" { + t.Errorf("expected '[]' for empty results, got: %s", got) + } +} diff --git a/internal/pipeline/pipeline_test.go b/internal/pipeline/pipeline_test.go new file mode 100644 index 0000000..2880e0c --- /dev/null +++ b/internal/pipeline/pipeline_test.go @@ -0,0 +1,624 @@ +package pipeline + +import ( + "context" + "errors" + "testing" + + "github.com/go-authgate/agent-scanner/internal/models" + "github.com/go-authgate/agent-scanner/internal/rules" +) + +// --- Mock types --- + +type mockDiscoverer struct { + discoverClientsFn func(ctx context.Context, allUsers bool) []models.CandidateClient + resolveClientFn func(ctx context.Context, candidate models.CandidateClient) ([]*models.ClientToInspect, error) + clientFromPathFn func(ctx context.Context, path string, scanSkills bool) ([]*models.ClientToInspect, error) +} + +func (m *mockDiscoverer) DiscoverClients( + ctx context.Context, + allUsers bool, +) []models.CandidateClient { + if m.discoverClientsFn != nil { + return m.discoverClientsFn(ctx, allUsers) + } + return nil +} + +func (m *mockDiscoverer) ResolveClient( + ctx context.Context, + candidate models.CandidateClient, +) ([]*models.ClientToInspect, error) { + if m.resolveClientFn != nil { + return m.resolveClientFn(ctx, candidate) + } + return []*models.ClientToInspect{}, nil +} + +func (m *mockDiscoverer) ClientFromPath( + ctx context.Context, + path string, + scanSkills bool, +) ([]*models.ClientToInspect, error) { + if m.clientFromPathFn != nil { + return m.clientFromPathFn(ctx, path, scanSkills) + } + return []*models.ClientToInspect{}, nil +} + +type mockInspector struct { + inspectClientFn func(ctx context.Context, client *models.ClientToInspect, scanSkills bool) (*models.InspectedClient, error) +} + +func (m *mockInspector) InspectServer( + _ context.Context, + _ string, + _ models.ServerConfig, +) (*models.InspectedExtension, error) { + return &models.InspectedExtension{}, nil +} + +func (m *mockInspector) InspectSkill( + _ context.Context, + _ string, + _ *models.SkillServer, +) (*models.InspectedExtension, error) { + return &models.InspectedExtension{}, nil +} + +func (m *mockInspector) InspectClient( + ctx context.Context, + client *models.ClientToInspect, + scanSkills bool, +) (*models.InspectedClient, error) { + if m.inspectClientFn != nil { + return m.inspectClientFn(ctx, client, scanSkills) + } + return &models.InspectedClient{ + Name: client.Name, + Extensions: map[string][]models.InspectedExtension{}, + }, nil +} + +type mockRuleEngine struct { + runFn func(ctx *rules.RuleContext) []models.Issue +} + +func (m *mockRuleEngine) Register(_ rules.Rule) {} + +func (m *mockRuleEngine) Run(ctx *rules.RuleContext) []models.Issue { + if m.runFn != nil { + return m.runFn(ctx) + } + return nil +} + +type mockAnalyzer struct { + analyzeFn func(ctx context.Context, results []models.ScanPathResult) ([]models.ScanPathResult, error) + called bool +} + +func (m *mockAnalyzer) Analyze( + ctx context.Context, + results []models.ScanPathResult, +) ([]models.ScanPathResult, error) { + m.called = true + if m.analyzeFn != nil { + return m.analyzeFn(ctx, results) + } + return results, nil +} + +type mockUploader struct { + uploadFn func(ctx context.Context, results []models.ScanPathResult, server models.ControlServer) error + called bool + servers []models.ControlServer +} + +func (m *mockUploader) Upload( + ctx context.Context, + results []models.ScanPathResult, + server models.ControlServer, +) error { + m.called = true + m.servers = append(m.servers, server) + if m.uploadFn != nil { + return m.uploadFn(ctx, results, server) + } + return nil +} + +// --- Helpers --- + +// newTestClient returns a simple ClientToInspect for testing. +func newTestClient(name string) *models.ClientToInspect { + return &models.ClientToInspect{ + Name: name, + MCPConfigs: map[string]models.MCPConfigOrError{}, + } +} + +// newInspectedClient returns an InspectedClient with one extension under a config path. +func newInspectedClient(name, configPath, extName string) *models.InspectedClient { + return &models.InspectedClient{ + Name: name, + Extensions: map[string][]models.InspectedExtension{ + configPath: { + { + Name: extName, + Config: &models.StdioServer{Command: "echo"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{{Name: "test-tool", Description: "a test tool"}}, + }, + }, + }, + }, + } +} + +// --- Tests --- + +func TestRun_InspectOnly(t *testing.T) { + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + return []models.CandidateClient{{Name: "test-client"}} + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/config.json", "server-a"), nil + }, + } + analyzer := &mockAnalyzer{} + uploader := &mockUploader{} + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + Analyzer: analyzer, + Uploader: uploader, + InspectOnly: true, + ControlServers: []ControlServerConfig{ + {URL: "https://control.example.com"}, + }, + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(results) == 0 { + t.Fatal("expected at least one result from inspect stage") + } + if analyzer.called { + t.Error("analyzer should NOT be called in InspectOnly mode") + } + if uploader.called { + t.Error("uploader should NOT be called in InspectOnly mode") + } +} + +func TestRun_FullPipeline(t *testing.T) { + discoverCalled := false + resolveCalled := false + + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + discoverCalled = true + return []models.CandidateClient{{Name: "claude"}} + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + resolveCalled = true + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/config.json", "srv"), nil + }, + } + + ruleEngine := &mockRuleEngine{ + runFn: func(_ *rules.RuleContext) []models.Issue { + return []models.Issue{{Code: "W001", Message: "suspicious"}} + }, + } + + analyzer := &mockAnalyzer{ + analyzeFn: func(_ context.Context, results []models.ScanPathResult) ([]models.ScanPathResult, error) { + // Simulate adding labels + for i := range results { + results[i].Labels = [][]models.ScalarToolLabels{{{IsPublicSink: 0.9}}} + } + return results, nil + }, + } + + uploader := &mockUploader{} + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + RuleEngine: ruleEngine, + Analyzer: analyzer, + Uploader: uploader, + ControlServers: []ControlServerConfig{ + {URL: "https://ctrl.example.com", Identifier: "id-1"}, + }, + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !discoverCalled { + t.Error("DiscoverClients should be called when no Paths are set") + } + if !resolveCalled { + t.Error("ResolveClient should be called when no Paths are set") + } + if !analyzer.called { + t.Error("analyzer should be called in full pipeline") + } + if !uploader.called { + t.Error("uploader should be called in full pipeline") + } + if len(results) == 0 { + t.Fatal("expected at least one result") + } + // Check that rule engine issues were appended. + foundIssue := false + for _, r := range results { + for _, iss := range r.Issues { + if iss.Code == "W001" { + foundIssue = true + } + } + } + if !foundIssue { + t.Error("expected rule engine issue W001 in results") + } + // Check that uploader received the correct server. + if len(uploader.servers) != 1 { + t.Fatalf("expected uploader to be called once, got %d", len(uploader.servers)) + } + if uploader.servers[0].URL != "https://ctrl.example.com" { + t.Errorf("unexpected control server URL: %s", uploader.servers[0].URL) + } + if uploader.servers[0].Identifier != "id-1" { + t.Errorf("unexpected control server identifier: %s", uploader.servers[0].Identifier) + } +} + +func TestRun_WithPaths_UsesClientFromPath(t *testing.T) { + clientFromPathCalled := false + discoverClientsCalled := false + + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + discoverClientsCalled = true + return nil + }, + clientFromPathFn: func(_ context.Context, path string, _ bool) ([]*models.ClientToInspect, error) { + clientFromPathCalled = true + return []*models.ClientToInspect{newTestClient("path-client-" + path)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/some/path", "ext"), nil + }, + } + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + Paths: []string{"/path/to/config.json", "/another/config.json"}, + InspectOnly: true, + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !clientFromPathCalled { + t.Error("ClientFromPath should be called when Paths are provided") + } + if discoverClientsCalled { + t.Error("DiscoverClients should NOT be called when Paths are provided") + } + if len(results) != 2 { + t.Errorf("expected 2 results (one per path), got %d", len(results)) + } +} + +func TestRun_NoPaths_UsesDiscoverAndResolve(t *testing.T) { + discoverCalled := false + resolveCalled := false + clientFromPathCalled := false + + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + discoverCalled = true + return []models.CandidateClient{ + {Name: "client-a"}, + {Name: "client-b"}, + } + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + resolveCalled = true + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + clientFromPathFn: func(_ context.Context, _ string, _ bool) ([]*models.ClientToInspect, error) { + clientFromPathCalled = true + return nil, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/cfg", "ext"), nil + }, + } + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + InspectOnly: true, + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !discoverCalled { + t.Error("DiscoverClients should be called") + } + if !resolveCalled { + t.Error("ResolveClient should be called") + } + if clientFromPathCalled { + t.Error("ClientFromPath should NOT be called when no Paths are set") + } + if len(results) != 2 { + t.Errorf("expected 2 results (one per discovered client), got %d", len(results)) + } +} + +func TestRun_NilAnalyzer_SkipsAnalysis(t *testing.T) { + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + return []models.CandidateClient{{Name: "c"}} + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/cfg", "ext"), nil + }, + } + + ruleEngine := &mockRuleEngine{ + runFn: func(_ *rules.RuleContext) []models.Issue { + return []models.Issue{{Code: "W002", Message: "too many entities"}} + }, + } + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + RuleEngine: ruleEngine, + Analyzer: nil, // nil analyzer + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(results) == 0 { + t.Fatal("expected results") + } + // Rule engine should still have run. + foundIssue := false + for _, r := range results { + for _, iss := range r.Issues { + if iss.Code == "W002" { + foundIssue = true + } + } + } + if !foundIssue { + t.Error("expected rule engine issue W002 even with nil analyzer") + } +} + +func TestRun_NilUploader_SkipsPush(t *testing.T) { + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + return []models.CandidateClient{{Name: "c"}} + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/cfg", "ext"), nil + }, + } + + // nil Uploader, with ControlServers configured + p := New(Config{ + Discoverer: disc, + Inspector: insp, + Uploader: nil, + ControlServers: []ControlServerConfig{ + {URL: "https://ctrl.example.com"}, + }, + }) + + // Should not panic. + _, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestRun_EmptyControlServers_SkipsPush(t *testing.T) { + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + return []models.CandidateClient{{Name: "c"}} + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/cfg", "ext"), nil + }, + } + + uploader := &mockUploader{} + + // Uploader provided, but no ControlServers + p := New(Config{ + Discoverer: disc, + Inspector: insp, + Uploader: uploader, + ControlServers: nil, + }) + + _, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if uploader.called { + t.Error("uploader should NOT be called with empty ControlServers") + } +} + +func TestRun_ClientDiscoveryFailure_Continues(t *testing.T) { + resolveCallCount := 0 + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + return []models.CandidateClient{ + {Name: "failing-client"}, + {Name: "ok-client"}, + } + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + resolveCallCount++ + if c.Name == "failing-client" { + return nil, errors.New("resolution failed") + } + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/cfg", "ext"), nil + }, + } + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + InspectOnly: true, + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resolveCallCount != 2 { + t.Errorf("expected ResolveClient to be called 2 times, got %d", resolveCallCount) + } + + // Only the successful client should produce results. + if len(results) != 1 { + t.Errorf("expected 1 result from the successful client, got %d", len(results)) + } + if len(results) > 0 && results[0].Client != "ok-client" { + t.Errorf("expected result from ok-client, got %q", results[0].Client) + } +} + +func TestRun_ClientFromPathFailure_Continues(t *testing.T) { + disc := &mockDiscoverer{ + clientFromPathFn: func(_ context.Context, path string, _ bool) ([]*models.ClientToInspect, error) { + if path == "/bad/path" { + return nil, errors.New("path not found") + } + return []*models.ClientToInspect{newTestClient("good-client")}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/cfg", "ext"), nil + }, + } + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + Paths: []string{"/bad/path", "/good/path"}, + InspectOnly: true, + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Only the good path should produce results. + if len(results) != 1 { + t.Errorf("expected 1 result from the good path, got %d", len(results)) + } +} + +func TestRun_InspectClientFailure_Continues(t *testing.T) { + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + return []models.CandidateClient{ + {Name: "fail-inspect"}, + {Name: "ok-inspect"}, + } + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + if client.Name == "fail-inspect" { + return nil, errors.New("inspection failed") + } + return newInspectedClient(client.Name, "/cfg", "ext"), nil + }, + } + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + InspectOnly: true, + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(results) != 1 { + t.Errorf("expected 1 result (from ok-inspect only), got %d", len(results)) + } +} diff --git a/internal/tlsutil/tlsutil_test.go b/internal/tlsutil/tlsutil_test.go new file mode 100644 index 0000000..95468be --- /dev/null +++ b/internal/tlsutil/tlsutil_test.go @@ -0,0 +1,102 @@ +package tlsutil + +import ( + "crypto/tls" + "net/http" + "testing" +) + +func TestCloneTransport_ReturnsNonNil(t *testing.T) { + tr := CloneTransport() + if tr == nil { + t.Fatal("CloneTransport() returned nil") + } +} + +func TestCloneTransport_ReturnsSeparateInstance(t *testing.T) { + tr := CloneTransport() + base, ok := http.DefaultTransport.(*http.Transport) + if !ok { + t.Skip("http.DefaultTransport is not *http.Transport; cannot compare pointers") + } + if tr == base { + t.Fatal("CloneTransport() returned the same pointer as http.DefaultTransport") + } +} + +func TestCloneTransport_MatchesDefaultTransport(t *testing.T) { + tr := CloneTransport() + + base, ok := http.DefaultTransport.(*http.Transport) + if !ok { + t.Skip("http.DefaultTransport is not *http.Transport") + } + + if tr.MaxIdleConns != base.MaxIdleConns { + t.Errorf("MaxIdleConns = %d, want %d", tr.MaxIdleConns, base.MaxIdleConns) + } + if tr.IdleConnTimeout != base.IdleConnTimeout { + t.Errorf("IdleConnTimeout = %v, want %v", tr.IdleConnTimeout, base.IdleConnTimeout) + } + if tr.TLSHandshakeTimeout != base.TLSHandshakeTimeout { + t.Errorf( + "TLSHandshakeTimeout = %v, want %v", + tr.TLSHandshakeTimeout, + base.TLSHandshakeTimeout, + ) + } + if tr.ForceAttemptHTTP2 != base.ForceAttemptHTTP2 { + t.Errorf("ForceAttemptHTTP2 = %v, want %v", tr.ForceAttemptHTTP2, base.ForceAttemptHTTP2) + } +} + +func TestApplyInsecureSkipVerify_NilTLSConfig(t *testing.T) { + tr := &http.Transport{} + if tr.TLSClientConfig != nil { + t.Fatal("precondition failed: TLSClientConfig should be nil") + } + + ApplyInsecureSkipVerify(tr) + + if tr.TLSClientConfig == nil { + t.Fatal("TLSClientConfig is still nil after ApplyInsecureSkipVerify") + } + if !tr.TLSClientConfig.InsecureSkipVerify { + t.Error("InsecureSkipVerify = false, want true") + } +} + +func TestApplyInsecureSkipVerify_ExistingTLSConfig(t *testing.T) { + original := &tls.Config{ + ServerName: "example.com", + MinVersion: tls.VersionTLS12, + } + tr := &http.Transport{ + TLSClientConfig: original, + } + + ApplyInsecureSkipVerify(tr) + + // Must not mutate the original config. + if original.InsecureSkipVerify { + t.Error("original tls.Config was mutated; InsecureSkipVerify should still be false") + } + + // The transport must have a new config. + if tr.TLSClientConfig == original { + t.Error("TLSClientConfig is the same pointer as the original; should be a clone") + } + + // New config must have InsecureSkipVerify set. + if !tr.TLSClientConfig.InsecureSkipVerify { + t.Error("InsecureSkipVerify = false on new config, want true") + } + + // Cloned config should preserve existing fields. + if tr.TLSClientConfig.ServerName != "example.com" { + t.Errorf("ServerName = %q, want %q", tr.TLSClientConfig.ServerName, "example.com") + } + if tr.TLSClientConfig.MinVersion != tls.VersionTLS12 { + t.Errorf("MinVersion = %d, want %d", tr.TLSClientConfig.MinVersion, tls.VersionTLS12) + } +} diff --git a/internal/upload/uploader.go b/internal/upload/uploader.go index 9c65be5..b931349 100644 --- a/internal/upload/uploader.go +++ b/internal/upload/uploader.go @@ -4,18 +4,41 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "log/slog" "net/http" "os" "os/user" + "strings" "time" + "unicode" "github.com/go-authgate/agent-scanner/internal/models" "github.com/go-authgate/agent-scanner/internal/redact" + "github.com/go-authgate/agent-scanner/internal/version" ) +// clientError is a non-retryable HTTP error (4xx). +type clientError struct { + StatusCode int + Body string +} + +func (e *clientError) Error() string { + return fmt.Sprintf("status %d: %s", e.StatusCode, e.Body) +} + +// nonRetryableError wraps errors that should not be retried +// (e.g., request construction failures). +type nonRetryableError struct { + err error +} + +func (e *nonRetryableError) Error() string { return e.err.Error() } +func (e *nonRetryableError) Unwrap() error { return e.err } + // Uploader pushes scan results to control servers. type Uploader interface { Upload(ctx context.Context, results []models.ScanPathResult, server models.ControlServer) error @@ -56,6 +79,9 @@ func (u *uploader) Upload( Hostname: getHostname(), Username: getUsername(), }, + ScanMetadata: &models.ScanMetadata{ + Version: version.Version, + }, } body, err := json.Marshal(payload) @@ -63,7 +89,7 @@ func (u *uploader) Upload( return fmt.Errorf("marshal upload payload: %w", err) } - // Retry with exponential backoff + // Retry with exponential backoff (only retry on 5xx / network errors) maxRetries := 3 for attempt := range maxRetries { err = u.doUpload(ctx, server, body) @@ -71,6 +97,19 @@ func (u *uploader) Upload( slog.Info("upload successful", "url", server.URL) return nil } + // Do not retry client errors (4xx) or non-retryable errors (e.g., bad URL) + var nre *nonRetryableError + if errors.As(err, &nre) { + return fmt.Errorf("upload failed: %w", err) + } + var ce *clientError + if errors.As(err, &ce) { + return fmt.Errorf( + "upload failed due to non-retryable client error after %d attempt(s): %w", + attempt+1, + err, + ) + } if attempt < maxRetries-1 { backoff := time.Duration(1<= 400 { - respBody, _ := io.ReadAll(resp.Body) - return fmt.Errorf("status %d: %s", resp.StatusCode, string(respBody)) + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + bodySnippet := sanitizeBodySnippet(string(respBody), 512) + if resp.StatusCode < 500 { + return &clientError{StatusCode: resp.StatusCode, Body: bodySnippet} + } + return fmt.Errorf("status %d: %s", resp.StatusCode, bodySnippet) } + // Drain response body on success to allow HTTP connection reuse. + _, _ = io.Copy(io.Discard, resp.Body) + return nil } @@ -127,3 +173,18 @@ func getUsername() string { } return u.Username } + +// sanitizeBodySnippet truncates s to approximately maxLen bytes (the +// returned string may be slightly longer due to a " [truncated]" suffix) +// and replaces all Unicode control characters with spaces for safe single-line logging. +func sanitizeBodySnippet(s string, maxLen int) string { + if len(s) > maxLen { + s = s[:maxLen] + " [truncated]" + } + return strings.Map(func(r rune) rune { + if unicode.IsControl(r) { + return ' ' + } + return r + }, s) +} diff --git a/internal/upload/uploader_test.go b/internal/upload/uploader_test.go new file mode 100644 index 0000000..539484d --- /dev/null +++ b/internal/upload/uploader_test.go @@ -0,0 +1,249 @@ +package upload + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/go-authgate/agent-scanner/internal/models" +) + +func TestUpload_Success(t *testing.T) { + var receivedBody models.ScanPathResultsCreate + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("expected Content-Type application/json, got %s", ct) + } + if err := json.NewDecoder(r.Body).Decode(&receivedBody); err != nil { + t.Errorf("failed to decode request body: %v", err) + } + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + u := NewUploader() + results := []models.ScanPathResult{ + { + Client: "test-client", + Path: "/test/path", + Issues: []models.Issue{ + {Code: "E001", Message: "test issue"}, + }, + }, + } + server := models.ControlServer{ + URL: ts.URL, + Identifier: "test-id", + } + + err := u.Upload(context.Background(), results, server) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if len(receivedBody.ScanPathResults) != 1 { + t.Fatalf("expected 1 scan path result, got %d", len(receivedBody.ScanPathResults)) + } + if receivedBody.ScanPathResults[0].Client != "test-client" { + t.Errorf("expected client 'test-client', got %q", receivedBody.ScanPathResults[0].Client) + } + if receivedBody.ScanUserInfo.Hostname == "" { + t.Error("expected non-empty hostname") + } + if receivedBody.ScanUserInfo.Username == "" { + t.Error("expected non-empty username") + } +} + +func TestUpload_EmptyResults(t *testing.T) { + var requestMade atomic.Bool + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestMade.Store(true) + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + u := NewUploader() + server := models.ControlServer{URL: ts.URL} + + err := u.Upload(context.Background(), nil, server) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if requestMade.Load() { + t.Error("expected no HTTP request for empty results") + } + + err = u.Upload(context.Background(), []models.ScanPathResult{}, server) + if err != nil { + t.Fatalf("expected nil error for empty slice, got %v", err) + } + if requestMade.Load() { + t.Error("expected no HTTP request for empty slice") + } +} + +func TestUpload_4xxNoRetry(t *testing.T) { + var requestCount atomic.Int32 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount.Add(1) + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("bad request")) + })) + defer ts.Close() + + u := NewUploader() + results := []models.ScanPathResult{ + {Client: "test", Path: "/p"}, + } + server := models.ControlServer{URL: ts.URL} + + err := u.Upload(context.Background(), results, server) + if err == nil { + t.Fatal("expected error for 400 response") + } + + count := requestCount.Load() + if count != 1 { + t.Errorf("expected exactly 1 request (no retry on 4xx), got %d", count) + } + + // Verify it's a clientError + var ce *clientError + if !errors.As(err, &ce) { + t.Errorf("expected clientError in chain, got %T: %v", err, err) + } +} + +func TestUpload_5xxRetries(t *testing.T) { + var requestCount atomic.Int32 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount.Add(1) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("server error")) + })) + defer ts.Close() + + u := NewUploader() + results := []models.ScanPathResult{ + {Client: "test", Path: "/p"}, + } + server := models.ControlServer{URL: ts.URL} + + // Use a context with a bounded deadline to avoid waiting for full backoff + // while still allowing multiple retries. The first request is immediate, + // then backoff is 1s, 2s, so we give ample time for at least two attempts. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := u.Upload(ctx, results, server) + if err == nil { + t.Fatal("expected error for 500 response") + } + + count := requestCount.Load() + if count < 2 { + t.Errorf("expected at least 2 requests (retries on 5xx), got %d", count) + } +} + +func TestUpload_CustomHeaders(t *testing.T) { + var receivedHeaders http.Header + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + u := NewUploader() + results := []models.ScanPathResult{ + {Client: "test", Path: "/p"}, + } + server := models.ControlServer{ + URL: ts.URL, + Headers: map[string]string{ + "Authorization": "Bearer test-token", + "X-Custom": "custom-value", + }, + } + + err := u.Upload(context.Background(), results, server) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if receivedHeaders.Get("Authorization") != "Bearer test-token" { + t.Errorf( + "expected Authorization header 'Bearer test-token', got %q", + receivedHeaders.Get("Authorization"), + ) + } + if receivedHeaders.Get("X-Custom") != "custom-value" { + t.Errorf("expected X-Custom header 'custom-value', got %q", receivedHeaders.Get("X-Custom")) + } +} + +func TestUpload_ScanMetadataVersionPopulated(t *testing.T) { + var receivedBody models.ScanPathResultsCreate + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&receivedBody); err != nil { + t.Errorf("failed to decode request body: %v", err) + } + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + u := NewUploader() + results := []models.ScanPathResult{ + {Client: "test", Path: "/p"}, + } + server := models.ControlServer{URL: ts.URL} + + err := u.Upload(context.Background(), results, server) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if receivedBody.ScanMetadata == nil { + t.Fatal("expected ScanMetadata to be non-nil") + } + if receivedBody.ScanMetadata.Version == "" { + t.Error("expected ScanMetadata.Version to be populated") + } +} + +func TestUpload_ContextCancellation(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simulate a slow server; sleep just long enough to exceed the client context deadline + time.Sleep(500 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + u := NewUploader() + results := []models.ScanPathResult{ + {Client: "test", Path: "/p"}, + } + server := models.ControlServer{URL: ts.URL} + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + err := u.Upload(ctx, results, server) + if err == nil { + t.Fatal("expected error due to context cancellation") + } +}