From ad082a68b9cea5afce6a64bde80ed9e609a32d89 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Tue, 24 Mar 2026 23:16:56 +0800 Subject: [PATCH 01/12] fix(upload): stop retrying 4xx errors and add tests for 6 packages - Introduce clientError type to distinguish non-retryable 4xx from retryable 5xx HTTP errors in upload and analysis packages - Populate ScanMetadata.Version in upload payload - Add test suite for internal/pipeline with 10 tests (96.5% coverage) - Add test suite for internal/output with 8 tests (88.3% coverage) - Add test suite for internal/analysis with 7 tests (88.7% coverage) - Add test suite for internal/upload with 7 tests (86.8% coverage) - Add test suite for internal/tlsutil with 5 tests (87.5% coverage) - Add test suite for internal/cli with 4 tests (46.9% coverage) Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/analysis/analyzer.go | 21 +- internal/analysis/analyzer_test.go | 335 ++++++++++++++++ internal/cli/cli_test.go | 100 +++++ internal/output/output_test.go | 291 ++++++++++++++ internal/pipeline/pipeline_test.go | 624 +++++++++++++++++++++++++++++ internal/tlsutil/tlsutil_test.go | 93 +++++ internal/upload/uploader.go | 25 +- internal/upload/uploader_test.go | 267 ++++++++++++ 8 files changed, 1754 insertions(+), 2 deletions(-) create mode 100644 internal/analysis/analyzer_test.go create mode 100644 internal/cli/cli_test.go create mode 100644 internal/output/output_test.go create mode 100644 internal/pipeline/pipeline_test.go create mode 100644 internal/tlsutil/tlsutil_test.go create mode 100644 internal/upload/uploader_test.go diff --git a/internal/analysis/analyzer.go b/internal/analysis/analyzer.go index 0abc7d1..b32139b 100644 --- a/internal/analysis/analyzer.go +++ b/internal/analysis/analyzer.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -14,6 +15,16 @@ import ( "github.com/go-authgate/agent-scanner/internal/tlsutil" ) +// clientError is a non-retryable HTTP error (4xx). +type clientError struct { + StatusCode int + Body string +} + +func (e *clientError) Error() string { + return fmt.Sprintf("status %d: %s", e.StatusCode, e.Body) +} + // Analyzer performs security analysis on scan results. type Analyzer interface { Analyze(ctx context.Context, results []models.ScanPathResult) ([]models.ScanPathResult, error) @@ -113,7 +124,7 @@ func (a *remoteAnalyzer) analyzePathResult( return fmt.Errorf("marshal request: %w", err) } - // Retry with exponential backoff + // Retry with exponential backoff (only retry on 5xx / network errors) var resp analysisResponse maxRetries := 3 for attempt := range maxRetries { @@ -121,6 +132,11 @@ func (a *remoteAnalyzer) analyzePathResult( if err == nil { break } + // Do not retry client errors (4xx) + var ce *clientError + if errors.As(err, &ce) { + return fmt.Errorf("analysis API: %w", err) + } if attempt < maxRetries-1 { backoff := time.Duration(1<= 400 { respBody, _ := io.ReadAll(httpResp.Body) + if httpResp.StatusCode < 500 { + return &clientError{StatusCode: httpResp.StatusCode, Body: string(respBody)} + } return fmt.Errorf("status %d: %s", httpResp.StatusCode, string(respBody)) } diff --git a/internal/analysis/analyzer_test.go b/internal/analysis/analyzer_test.go new file mode 100644 index 0000000..e758b8b --- /dev/null +++ b/internal/analysis/analyzer_test.go @@ -0,0 +1,335 @@ +package analysis + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/go-authgate/agent-scanner/internal/models" +) + +func TestAnalyze_EmptyURL(t *testing.T) { + requestMade := false + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestMade = true + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + a := NewAnalyzer("", false) + results := []models.ScanPathResult{ + { + Client: "test", + Path: "/p", + Servers: []models.ServerScanResult{ + { + Name: "srv", + Server: &models.StdioServer{Command: "test"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{{Name: "t1", Description: "desc"}}, + }, + }, + }, + }, + } + + out, err := a.Analyze(context.Background(), results) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if requestMade { + t.Error("expected no HTTP request when analysis URL is empty") + } + if len(out) != 1 { + t.Errorf("expected results returned unchanged, got %d", len(out)) + } +} + +func TestAnalyze_Success(t *testing.T) { + respData := analysisResponse{ + Issues: []models.Issue{ + {Code: "E001", Message: "remote issue"}, + }, + Labels: [][]models.ScalarToolLabels{ + {{IsPublicSink: 0.9, Destructive: 0.1}}, + }, + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("expected Content-Type application/json, got %s", ct) + } + + // Verify request body + var req analysisRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Errorf("failed to decode request body: %v", err) + } + if len(req.Servers) != 1 { + t.Errorf("expected 1 server in request, got %d", len(req.Servers)) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(respData) + })) + defer ts.Close() + + a := NewAnalyzer(ts.URL, false) + results := []models.ScanPathResult{ + { + Client: "test", + Path: "/p", + Servers: []models.ServerScanResult{ + { + Name: "srv1", + Server: &models.StdioServer{Command: "cmd"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{ + {Name: "tool1", Description: "a tool"}, + }, + }, + }, + }, + Issues: []models.Issue{ + {Code: "W001", Message: "existing issue"}, + }, + }, + } + + out, err := a.Analyze(context.Background(), results) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(out) != 1 { + t.Fatalf("expected 1 result, got %d", len(out)) + } + + // Check that remote issues were merged in + if len(out[0].Issues) != 2 { + t.Fatalf("expected 2 issues (1 existing + 1 remote), got %d", len(out[0].Issues)) + } + if out[0].Issues[0].Code != "W001" { + t.Errorf("expected first issue code W001, got %s", out[0].Issues[0].Code) + } + if out[0].Issues[1].Code != "E001" { + t.Errorf("expected second issue code E001, got %s", out[0].Issues[1].Code) + } + + // Check labels were merged + if len(out[0].Labels) != 1 { + t.Fatalf("expected 1 label set, got %d", len(out[0].Labels)) + } + if out[0].Labels[0][0].IsPublicSink != 0.9 { + t.Errorf("expected IsPublicSink 0.9, got %f", out[0].Labels[0][0].IsPublicSink) + } +} + +func TestAnalyze_NilSignatureSkipped(t *testing.T) { + requestMade := false + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestMade = true + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(analysisResponse{}) + })) + defer ts.Close() + + a := NewAnalyzer(ts.URL, false) + results := []models.ScanPathResult{ + { + Client: "test", + Path: "/p", + Servers: []models.ServerScanResult{ + { + Name: "srv-no-sig", + Server: &models.StdioServer{Command: "cmd"}, + Signature: nil, // nil signature should be skipped + }, + }, + }, + } + + out, err := a.Analyze(context.Background(), results) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if requestMade { + t.Error("expected no HTTP request when all signatures are nil") + } + if len(out) != 1 { + t.Errorf("expected results returned, got %d", len(out)) + } +} + +func TestAnalyze_4xxNoRetry(t *testing.T) { + var requestCount atomic.Int32 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount.Add(1) + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("bad request")) + })) + defer ts.Close() + + a := NewAnalyzer(ts.URL, false) + results := []models.ScanPathResult{ + { + Client: "test", + Path: "/p", + Servers: []models.ServerScanResult{ + { + Name: "srv", + Server: &models.StdioServer{Command: "cmd"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{{Name: "t", Description: "d"}}, + }, + }, + }, + }, + } + + _, err := a.Analyze(context.Background(), results) + if err != nil { + t.Fatalf("Analyze should not return error (it logs warning instead), got: %v", err) + } + + count := requestCount.Load() + if count != 1 { + t.Errorf("expected exactly 1 request (no retry on 4xx), got %d", count) + } +} + +func TestAnalyze_5xxRetries(t *testing.T) { + var requestCount atomic.Int32 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount.Add(1) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("internal server error")) + })) + defer ts.Close() + + a := NewAnalyzer(ts.URL, false) + results := []models.ScanPathResult{ + { + Client: "test", + Path: "/p", + Servers: []models.ServerScanResult{ + { + Name: "srv", + Server: &models.StdioServer{Command: "cmd"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{{Name: "t", Description: "d"}}, + }, + }, + }, + }, + } + + // Use a context with a short deadline so we don't wait for full backoff. + // First request is immediate, then 1s backoff, then 2s backoff. + // With 1500ms we should get at least 2 attempts. + ctx, cancel := context.WithTimeout(context.Background(), 1500*time.Millisecond) + defer cancel() + + _, err := a.Analyze(ctx, results) + // Analyze doesn't return error even on failure (it logs a warning) + if err != nil { + t.Fatalf("Analyze should not return error, got: %v", err) + } + + count := requestCount.Load() + if count < 2 { + t.Errorf("expected at least 2 requests (retries on 5xx), got %d", count) + } +} + +func TestAnalyze_FailureDoesNotFailOverall(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("error")) + })) + defer ts.Close() + + a := NewAnalyzer(ts.URL, false) + results := []models.ScanPathResult{ + { + Client: "test", + Path: "/p1", + Servers: []models.ServerScanResult{ + { + Name: "srv", + Server: &models.StdioServer{Command: "cmd"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{{Name: "t", Description: "d"}}, + }, + }, + }, + Issues: []models.Issue{ + {Code: "W001", Message: "existing"}, + }, + }, + } + + out, err := a.Analyze(context.Background(), results) + if err != nil { + t.Fatalf("expected no error from Analyze even when analysis fails, got: %v", err) + } + + // Results should still be returned + if len(out) != 1 { + t.Fatalf("expected 1 result, got %d", len(out)) + } + // Existing issues should be preserved + if len(out[0].Issues) != 1 { + t.Errorf("expected existing issues preserved, got %d", len(out[0].Issues)) + } +} + +func TestAnalyze_AllNilSignatures(t *testing.T) { + requestMade := false + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestMade = true + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(analysisResponse{}) + })) + defer ts.Close() + + a := NewAnalyzer(ts.URL, false) + results := []models.ScanPathResult{ + { + Client: "test", + Path: "/p", + Servers: []models.ServerScanResult{ + { + Name: "srv1", + Server: &models.StdioServer{Command: "cmd1"}, + Signature: nil, + }, + { + Name: "srv2", + Server: &models.StdioServer{Command: "cmd2"}, + Signature: nil, + }, + }, + }, + } + + out, err := a.Analyze(context.Background(), results) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if requestMade { + t.Error("expected no HTTP request when all signatures are nil") + } + if len(out) != 1 { + t.Errorf("expected results returned unchanged, got %d", len(out)) + } +} diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go new file mode 100644 index 0000000..a95023d --- /dev/null +++ b/internal/cli/cli_test.go @@ -0,0 +1,100 @@ +package cli + +import ( + "testing" +) + +func TestParseControlServers_Empty(t *testing.T) { + // Save and restore original scanFlags. + orig := scanFlags + defer func() { scanFlags = orig }() + + scanFlags = ScanFlags{} + + servers := parseControlServers() + if servers != nil { + t.Errorf("expected nil, got %v", servers) + } +} + +func TestParseControlServers_OneServer_NoIdentifier(t *testing.T) { + orig := scanFlags + defer func() { scanFlags = orig }() + + scanFlags = ScanFlags{ + ControlServers: []string{"https://control.example.com"}, + } + + servers := parseControlServers() + if len(servers) != 1 { + t.Fatalf("expected 1 server, got %d", len(servers)) + } + if servers[0].URL != "https://control.example.com" { + t.Errorf("unexpected URL: %s", servers[0].URL) + } + if servers[0].Identifier != "" { + t.Errorf("expected empty identifier, got %q", servers[0].Identifier) + } +} + +func TestParseControlServers_MatchingServersAndIdentifiers(t *testing.T) { + orig := scanFlags + defer func() { scanFlags = orig }() + + scanFlags = ScanFlags{ + ControlServers: []string{"https://a.example.com", "https://b.example.com"}, + ControlIdentifier: []string{"id-a", "id-b"}, + } + + servers := parseControlServers() + if len(servers) != 2 { + t.Fatalf("expected 2 servers, got %d", len(servers)) + } + + if servers[0].URL != "https://a.example.com" { + t.Errorf("server[0] URL = %q, want %q", servers[0].URL, "https://a.example.com") + } + if servers[0].Identifier != "id-a" { + t.Errorf("server[0] Identifier = %q, want %q", servers[0].Identifier, "id-a") + } + if servers[1].URL != "https://b.example.com" { + t.Errorf("server[1] URL = %q, want %q", servers[1].URL, "https://b.example.com") + } + if servers[1].Identifier != "id-b" { + t.Errorf("server[1] Identifier = %q, want %q", servers[1].Identifier, "id-b") + } +} + +func TestParseControlServers_MoreServersThanIdentifiers(t *testing.T) { + orig := scanFlags + defer func() { scanFlags = orig }() + + scanFlags = ScanFlags{ + ControlServers: []string{ + "https://a.example.com", + "https://b.example.com", + "https://c.example.com", + }, + ControlIdentifier: []string{"id-a"}, + } + + servers := parseControlServers() + if len(servers) != 3 { + t.Fatalf("expected 3 servers, got %d", len(servers)) + } + + if servers[0].URL != "https://a.example.com" { + t.Errorf("server[0] URL = %q, want %q", servers[0].URL, "https://a.example.com") + } + if servers[0].Identifier != "id-a" { + t.Errorf("server[0] Identifier = %q, want %q", servers[0].Identifier, "id-a") + } + + // servers[1] and servers[2] have index >= len(ControlIdentifier), so identifier should be empty. + if servers[1].Identifier != "" { + t.Errorf("server[1] Identifier = %q, want empty", servers[1].Identifier) + } + if servers[2].Identifier != "" { + t.Errorf("server[2] Identifier = %q, want empty", servers[2].Identifier) + } +} diff --git a/internal/output/output_test.go b/internal/output/output_test.go new file mode 100644 index 0000000..5678daf --- /dev/null +++ b/internal/output/output_test.go @@ -0,0 +1,291 @@ +package output + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/go-authgate/agent-scanner/internal/models" +) + +// --------------------------------------------------------------------------- +// TextFormatter tests +// --------------------------------------------------------------------------- + +func TestTextFormatter_EmptyResults(t *testing.T) { + var buf bytes.Buffer + f := NewTextFormatter(&buf) + + if err := f.FormatResults(nil, FormatOptions{}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + + got := buf.String() + if !strings.Contains(got, "No MCP configurations found.") { + t.Errorf("expected 'No MCP configurations found.' in output, got:\n%s", got) + } +} + +func TestTextFormatter_ServersAndIssues(t *testing.T) { + results := []models.ScanPathResult{ + { + Client: "cursor", + Path: "/home/user/.cursor/mcp.json", + Servers: []models.ServerScanResult{ + { + Name: "my-server", + Server: &models.StdioServer{Command: "node"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{ + {Name: "read_file", Description: "Reads a file"}, + {Name: "write_file", Description: "Writes a file"}, + }, + }, + }, + }, + Issues: []models.Issue{ + {Code: "E001", Message: "Prompt injection detected"}, + }, + }, + } + + var buf bytes.Buffer + f := NewTextFormatter(&buf) + if err := f.FormatResults(results, FormatOptions{}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + out := buf.String() + + // Status icon for issues (E001 is high severity -> redCross) + if !strings.Contains(out, redCross) { + t.Error("expected red cross icon in output for high-severity issue") + } + // Server name + if !strings.Contains(out, "my-server") { + t.Error("expected server name 'my-server' in output") + } + // Entity count + if !strings.Contains(out, "2 entities") { + t.Error("expected '2 entities' in output") + } + // Issue code + if !strings.Contains(out, "E001") { + t.Error("expected issue code 'E001' in output") + } +} + +func TestTextFormatter_Summary(t *testing.T) { + results := []models.ScanPathResult{ + { + Client: "claude", + Path: "/path/a", + Servers: []models.ServerScanResult{ + { + Name: "server-a", + Server: &models.StdioServer{Command: "a"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{{Name: "t1"}}, + }, + }, + { + Name: "server-b", + Server: &models.StdioServer{Command: "b"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{{Name: "t2"}, {Name: "t3"}}, + Prompts: []models.Prompt{{Name: "p1"}}, + }, + }, + }, + Issues: []models.Issue{ + {Code: "E002", Message: "cross-server ref"}, // high + {Code: "W001", Message: "suspicious trigger"}, // medium (warning) + }, + }, + } + + var buf bytes.Buffer + f := NewTextFormatter(&buf) + if err := f.FormatResults(results, FormatOptions{}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + out := buf.String() + + // Summary should show 2 servers, 4 entities (1 + 2 tools + 1 prompt) + if !strings.Contains(out, "2 server(s)") { + t.Errorf("expected '2 server(s)' in summary, got:\n%s", out) + } + if !strings.Contains(out, "4 entities") { + t.Errorf("expected '4 entities' in summary, got:\n%s", out) + } + if !strings.Contains(out, "1 issue(s) found") { + t.Errorf("expected '1 issue(s) found' in summary, got:\n%s", out) + } + if !strings.Contains(out, "1 warning(s)") { + t.Errorf("expected '1 warning(s)' in summary, got:\n%s", out) + } +} + +func TestTextFormatter_NilSignature(t *testing.T) { + results := []models.ScanPathResult{ + { + Client: "vscode", + Path: "/path/to/config", + Servers: []models.ServerScanResult{ + { + Name: "dead-server", + Server: &models.StdioServer{Command: "dead"}, + Signature: nil, + }, + }, + }, + } + + var buf bytes.Buffer + f := NewTextFormatter(&buf) + if err := f.FormatResults(results, FormatOptions{}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + out := buf.String() + + if !strings.Contains(out, "(no response)") { + t.Errorf("expected '(no response)' for nil signature, got:\n%s", out) + } +} + +func TestTextFormatter_ServerError_PrintErrors(t *testing.T) { + results := []models.ScanPathResult{ + { + Client: "windsurf", + Path: "/path/to/config", + Servers: []models.ServerScanResult{ + { + Name: "broken-server", + Server: &models.StdioServer{Command: "broken"}, + Error: &models.ScanError{ + Message: "connection refused", + Category: models.ErrCatServerStartup, + }, + }, + }, + }, + } + + // Without PrintErrors + var buf bytes.Buffer + f := NewTextFormatter(&buf) + if err := f.FormatResults(results, FormatOptions{PrintErrors: false}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + out := buf.String() + if strings.Contains(out, "connection refused") { + t.Error("expected error message to be hidden when PrintErrors=false") + } + + // With PrintErrors + buf.Reset() + f = NewTextFormatter(&buf) + if err := f.FormatResults(results, FormatOptions{PrintErrors: true}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + out = buf.String() + if !strings.Contains(out, "connection refused") { + t.Errorf("expected error message 'connection refused' when PrintErrors=true, got:\n%s", out) + } +} + +func TestTextFormatter_TruncatesLongDescriptions(t *testing.T) { + longDesc := strings.Repeat("A", 300) + results := []models.ScanPathResult{ + { + Client: "cursor", + Path: "/path/to/config", + Servers: []models.ServerScanResult{ + { + Name: "long-desc-server", + Server: &models.StdioServer{Command: "node"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{ + {Name: "verbose-tool", Description: longDesc}, + }, + }, + }, + }, + }, + } + + var buf bytes.Buffer + f := NewTextFormatter(&buf) + if err := f.FormatResults(results, FormatOptions{PrintFullDescs: false}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + out := buf.String() + + // Should be truncated to 200 chars + "..." + if !strings.Contains(out, "...") { + t.Error("expected truncation ellipsis '...' in output") + } + // The full 300-char string should NOT appear + if strings.Contains(out, longDesc) { + t.Error("expected description to be truncated, but found full 300-char string") + } + // First 200 chars should appear + truncated := longDesc[:200] + if !strings.Contains(out, truncated) { + t.Error("expected first 200 chars of description to appear") + } +} + +// --------------------------------------------------------------------------- +// JSONFormatter tests +// --------------------------------------------------------------------------- + +func TestJSONFormatter_ValidJSON(t *testing.T) { + results := []models.ScanPathResult{ + { + Client: "claude", + Path: "/path/to/config", + Servers: []models.ServerScanResult{ + { + Name: "test-server", + Server: &models.StdioServer{Command: "node"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{ + {Name: "hello", Description: "says hello"}, + }, + }, + }, + }, + Issues: []models.Issue{ + {Code: "W001", Message: "suspicious words"}, + }, + }, + } + + var buf bytes.Buffer + f := NewJSONFormatter(&buf) + if err := f.FormatResults(results, FormatOptions{}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + + var parsed []map[string]any + if err := json.Unmarshal(buf.Bytes(), &parsed); err != nil { + t.Fatalf("output is not valid JSON: %v\nraw output:\n%s", err, buf.String()) + } + if len(parsed) != 1 { + t.Errorf("expected 1 result in JSON array, got %d", len(parsed)) + } +} + +func TestJSONFormatter_EmptyResults(t *testing.T) { + var buf bytes.Buffer + f := NewJSONFormatter(&buf) + if err := f.FormatResults([]models.ScanPathResult{}, FormatOptions{}); err != nil { + t.Fatalf("FormatResults returned error: %v", err) + } + + got := strings.TrimSpace(buf.String()) + if got != "[]" { + t.Errorf("expected '[]' for empty results, got: %s", got) + } +} diff --git a/internal/pipeline/pipeline_test.go b/internal/pipeline/pipeline_test.go new file mode 100644 index 0000000..2880e0c --- /dev/null +++ b/internal/pipeline/pipeline_test.go @@ -0,0 +1,624 @@ +package pipeline + +import ( + "context" + "errors" + "testing" + + "github.com/go-authgate/agent-scanner/internal/models" + "github.com/go-authgate/agent-scanner/internal/rules" +) + +// --- Mock types --- + +type mockDiscoverer struct { + discoverClientsFn func(ctx context.Context, allUsers bool) []models.CandidateClient + resolveClientFn func(ctx context.Context, candidate models.CandidateClient) ([]*models.ClientToInspect, error) + clientFromPathFn func(ctx context.Context, path string, scanSkills bool) ([]*models.ClientToInspect, error) +} + +func (m *mockDiscoverer) DiscoverClients( + ctx context.Context, + allUsers bool, +) []models.CandidateClient { + if m.discoverClientsFn != nil { + return m.discoverClientsFn(ctx, allUsers) + } + return nil +} + +func (m *mockDiscoverer) ResolveClient( + ctx context.Context, + candidate models.CandidateClient, +) ([]*models.ClientToInspect, error) { + if m.resolveClientFn != nil { + return m.resolveClientFn(ctx, candidate) + } + return []*models.ClientToInspect{}, nil +} + +func (m *mockDiscoverer) ClientFromPath( + ctx context.Context, + path string, + scanSkills bool, +) ([]*models.ClientToInspect, error) { + if m.clientFromPathFn != nil { + return m.clientFromPathFn(ctx, path, scanSkills) + } + return []*models.ClientToInspect{}, nil +} + +type mockInspector struct { + inspectClientFn func(ctx context.Context, client *models.ClientToInspect, scanSkills bool) (*models.InspectedClient, error) +} + +func (m *mockInspector) InspectServer( + _ context.Context, + _ string, + _ models.ServerConfig, +) (*models.InspectedExtension, error) { + return &models.InspectedExtension{}, nil +} + +func (m *mockInspector) InspectSkill( + _ context.Context, + _ string, + _ *models.SkillServer, +) (*models.InspectedExtension, error) { + return &models.InspectedExtension{}, nil +} + +func (m *mockInspector) InspectClient( + ctx context.Context, + client *models.ClientToInspect, + scanSkills bool, +) (*models.InspectedClient, error) { + if m.inspectClientFn != nil { + return m.inspectClientFn(ctx, client, scanSkills) + } + return &models.InspectedClient{ + Name: client.Name, + Extensions: map[string][]models.InspectedExtension{}, + }, nil +} + +type mockRuleEngine struct { + runFn func(ctx *rules.RuleContext) []models.Issue +} + +func (m *mockRuleEngine) Register(_ rules.Rule) {} + +func (m *mockRuleEngine) Run(ctx *rules.RuleContext) []models.Issue { + if m.runFn != nil { + return m.runFn(ctx) + } + return nil +} + +type mockAnalyzer struct { + analyzeFn func(ctx context.Context, results []models.ScanPathResult) ([]models.ScanPathResult, error) + called bool +} + +func (m *mockAnalyzer) Analyze( + ctx context.Context, + results []models.ScanPathResult, +) ([]models.ScanPathResult, error) { + m.called = true + if m.analyzeFn != nil { + return m.analyzeFn(ctx, results) + } + return results, nil +} + +type mockUploader struct { + uploadFn func(ctx context.Context, results []models.ScanPathResult, server models.ControlServer) error + called bool + servers []models.ControlServer +} + +func (m *mockUploader) Upload( + ctx context.Context, + results []models.ScanPathResult, + server models.ControlServer, +) error { + m.called = true + m.servers = append(m.servers, server) + if m.uploadFn != nil { + return m.uploadFn(ctx, results, server) + } + return nil +} + +// --- Helpers --- + +// newTestClient returns a simple ClientToInspect for testing. +func newTestClient(name string) *models.ClientToInspect { + return &models.ClientToInspect{ + Name: name, + MCPConfigs: map[string]models.MCPConfigOrError{}, + } +} + +// newInspectedClient returns an InspectedClient with one extension under a config path. +func newInspectedClient(name, configPath, extName string) *models.InspectedClient { + return &models.InspectedClient{ + Name: name, + Extensions: map[string][]models.InspectedExtension{ + configPath: { + { + Name: extName, + Config: &models.StdioServer{Command: "echo"}, + Signature: &models.ServerSignature{ + Tools: []models.Tool{{Name: "test-tool", Description: "a test tool"}}, + }, + }, + }, + }, + } +} + +// --- Tests --- + +func TestRun_InspectOnly(t *testing.T) { + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + return []models.CandidateClient{{Name: "test-client"}} + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/config.json", "server-a"), nil + }, + } + analyzer := &mockAnalyzer{} + uploader := &mockUploader{} + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + Analyzer: analyzer, + Uploader: uploader, + InspectOnly: true, + ControlServers: []ControlServerConfig{ + {URL: "https://control.example.com"}, + }, + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(results) == 0 { + t.Fatal("expected at least one result from inspect stage") + } + if analyzer.called { + t.Error("analyzer should NOT be called in InspectOnly mode") + } + if uploader.called { + t.Error("uploader should NOT be called in InspectOnly mode") + } +} + +func TestRun_FullPipeline(t *testing.T) { + discoverCalled := false + resolveCalled := false + + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + discoverCalled = true + return []models.CandidateClient{{Name: "claude"}} + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + resolveCalled = true + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/config.json", "srv"), nil + }, + } + + ruleEngine := &mockRuleEngine{ + runFn: func(_ *rules.RuleContext) []models.Issue { + return []models.Issue{{Code: "W001", Message: "suspicious"}} + }, + } + + analyzer := &mockAnalyzer{ + analyzeFn: func(_ context.Context, results []models.ScanPathResult) ([]models.ScanPathResult, error) { + // Simulate adding labels + for i := range results { + results[i].Labels = [][]models.ScalarToolLabels{{{IsPublicSink: 0.9}}} + } + return results, nil + }, + } + + uploader := &mockUploader{} + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + RuleEngine: ruleEngine, + Analyzer: analyzer, + Uploader: uploader, + ControlServers: []ControlServerConfig{ + {URL: "https://ctrl.example.com", Identifier: "id-1"}, + }, + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !discoverCalled { + t.Error("DiscoverClients should be called when no Paths are set") + } + if !resolveCalled { + t.Error("ResolveClient should be called when no Paths are set") + } + if !analyzer.called { + t.Error("analyzer should be called in full pipeline") + } + if !uploader.called { + t.Error("uploader should be called in full pipeline") + } + if len(results) == 0 { + t.Fatal("expected at least one result") + } + // Check that rule engine issues were appended. + foundIssue := false + for _, r := range results { + for _, iss := range r.Issues { + if iss.Code == "W001" { + foundIssue = true + } + } + } + if !foundIssue { + t.Error("expected rule engine issue W001 in results") + } + // Check that uploader received the correct server. + if len(uploader.servers) != 1 { + t.Fatalf("expected uploader to be called once, got %d", len(uploader.servers)) + } + if uploader.servers[0].URL != "https://ctrl.example.com" { + t.Errorf("unexpected control server URL: %s", uploader.servers[0].URL) + } + if uploader.servers[0].Identifier != "id-1" { + t.Errorf("unexpected control server identifier: %s", uploader.servers[0].Identifier) + } +} + +func TestRun_WithPaths_UsesClientFromPath(t *testing.T) { + clientFromPathCalled := false + discoverClientsCalled := false + + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + discoverClientsCalled = true + return nil + }, + clientFromPathFn: func(_ context.Context, path string, _ bool) ([]*models.ClientToInspect, error) { + clientFromPathCalled = true + return []*models.ClientToInspect{newTestClient("path-client-" + path)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/some/path", "ext"), nil + }, + } + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + Paths: []string{"/path/to/config.json", "/another/config.json"}, + InspectOnly: true, + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !clientFromPathCalled { + t.Error("ClientFromPath should be called when Paths are provided") + } + if discoverClientsCalled { + t.Error("DiscoverClients should NOT be called when Paths are provided") + } + if len(results) != 2 { + t.Errorf("expected 2 results (one per path), got %d", len(results)) + } +} + +func TestRun_NoPaths_UsesDiscoverAndResolve(t *testing.T) { + discoverCalled := false + resolveCalled := false + clientFromPathCalled := false + + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + discoverCalled = true + return []models.CandidateClient{ + {Name: "client-a"}, + {Name: "client-b"}, + } + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + resolveCalled = true + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + clientFromPathFn: func(_ context.Context, _ string, _ bool) ([]*models.ClientToInspect, error) { + clientFromPathCalled = true + return nil, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/cfg", "ext"), nil + }, + } + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + InspectOnly: true, + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !discoverCalled { + t.Error("DiscoverClients should be called") + } + if !resolveCalled { + t.Error("ResolveClient should be called") + } + if clientFromPathCalled { + t.Error("ClientFromPath should NOT be called when no Paths are set") + } + if len(results) != 2 { + t.Errorf("expected 2 results (one per discovered client), got %d", len(results)) + } +} + +func TestRun_NilAnalyzer_SkipsAnalysis(t *testing.T) { + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + return []models.CandidateClient{{Name: "c"}} + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/cfg", "ext"), nil + }, + } + + ruleEngine := &mockRuleEngine{ + runFn: func(_ *rules.RuleContext) []models.Issue { + return []models.Issue{{Code: "W002", Message: "too many entities"}} + }, + } + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + RuleEngine: ruleEngine, + Analyzer: nil, // nil analyzer + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(results) == 0 { + t.Fatal("expected results") + } + // Rule engine should still have run. + foundIssue := false + for _, r := range results { + for _, iss := range r.Issues { + if iss.Code == "W002" { + foundIssue = true + } + } + } + if !foundIssue { + t.Error("expected rule engine issue W002 even with nil analyzer") + } +} + +func TestRun_NilUploader_SkipsPush(t *testing.T) { + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + return []models.CandidateClient{{Name: "c"}} + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/cfg", "ext"), nil + }, + } + + // nil Uploader, with ControlServers configured + p := New(Config{ + Discoverer: disc, + Inspector: insp, + Uploader: nil, + ControlServers: []ControlServerConfig{ + {URL: "https://ctrl.example.com"}, + }, + }) + + // Should not panic. + _, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestRun_EmptyControlServers_SkipsPush(t *testing.T) { + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + return []models.CandidateClient{{Name: "c"}} + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/cfg", "ext"), nil + }, + } + + uploader := &mockUploader{} + + // Uploader provided, but no ControlServers + p := New(Config{ + Discoverer: disc, + Inspector: insp, + Uploader: uploader, + ControlServers: nil, + }) + + _, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if uploader.called { + t.Error("uploader should NOT be called with empty ControlServers") + } +} + +func TestRun_ClientDiscoveryFailure_Continues(t *testing.T) { + resolveCallCount := 0 + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + return []models.CandidateClient{ + {Name: "failing-client"}, + {Name: "ok-client"}, + } + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + resolveCallCount++ + if c.Name == "failing-client" { + return nil, errors.New("resolution failed") + } + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/cfg", "ext"), nil + }, + } + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + InspectOnly: true, + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resolveCallCount != 2 { + t.Errorf("expected ResolveClient to be called 2 times, got %d", resolveCallCount) + } + + // Only the successful client should produce results. + if len(results) != 1 { + t.Errorf("expected 1 result from the successful client, got %d", len(results)) + } + if len(results) > 0 && results[0].Client != "ok-client" { + t.Errorf("expected result from ok-client, got %q", results[0].Client) + } +} + +func TestRun_ClientFromPathFailure_Continues(t *testing.T) { + disc := &mockDiscoverer{ + clientFromPathFn: func(_ context.Context, path string, _ bool) ([]*models.ClientToInspect, error) { + if path == "/bad/path" { + return nil, errors.New("path not found") + } + return []*models.ClientToInspect{newTestClient("good-client")}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + return newInspectedClient(client.Name, "/cfg", "ext"), nil + }, + } + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + Paths: []string{"/bad/path", "/good/path"}, + InspectOnly: true, + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Only the good path should produce results. + if len(results) != 1 { + t.Errorf("expected 1 result from the good path, got %d", len(results)) + } +} + +func TestRun_InspectClientFailure_Continues(t *testing.T) { + disc := &mockDiscoverer{ + discoverClientsFn: func(_ context.Context, _ bool) []models.CandidateClient { + return []models.CandidateClient{ + {Name: "fail-inspect"}, + {Name: "ok-inspect"}, + } + }, + resolveClientFn: func(_ context.Context, c models.CandidateClient) ([]*models.ClientToInspect, error) { + return []*models.ClientToInspect{newTestClient(c.Name)}, nil + }, + } + insp := &mockInspector{ + inspectClientFn: func(_ context.Context, client *models.ClientToInspect, _ bool) (*models.InspectedClient, error) { + if client.Name == "fail-inspect" { + return nil, errors.New("inspection failed") + } + return newInspectedClient(client.Name, "/cfg", "ext"), nil + }, + } + + p := New(Config{ + Discoverer: disc, + Inspector: insp, + InspectOnly: true, + }) + + results, err := p.Run(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(results) != 1 { + t.Errorf("expected 1 result (from ok-inspect only), got %d", len(results)) + } +} diff --git a/internal/tlsutil/tlsutil_test.go b/internal/tlsutil/tlsutil_test.go new file mode 100644 index 0000000..e1becfb --- /dev/null +++ b/internal/tlsutil/tlsutil_test.go @@ -0,0 +1,93 @@ +package tlsutil + +import ( + "crypto/tls" + "net/http" + "testing" +) + +func TestCloneTransport_ReturnsNonNil(t *testing.T) { + tr := CloneTransport() + if tr == nil { + t.Fatal("CloneTransport() returned nil") + } +} + +func TestCloneTransport_ReturnsSeparateInstance(t *testing.T) { + tr := CloneTransport() + base, ok := http.DefaultTransport.(*http.Transport) + if !ok { + t.Skip("http.DefaultTransport is not *http.Transport; cannot compare pointers") + } + if tr == base { + t.Fatal("CloneTransport() returned the same pointer as http.DefaultTransport") + } +} + +func TestCloneTransport_HasExpectedDefaults(t *testing.T) { + tr := CloneTransport() + + if tr.MaxIdleConns != 100 { + t.Errorf("MaxIdleConns = %d, want 100", tr.MaxIdleConns) + } + if tr.IdleConnTimeout != 90e9 { + t.Errorf("IdleConnTimeout = %v, want 90s", tr.IdleConnTimeout) + } + if tr.TLSHandshakeTimeout != 10e9 { + t.Errorf("TLSHandshakeTimeout = %v, want 10s", tr.TLSHandshakeTimeout) + } + if !tr.ForceAttemptHTTP2 { + t.Error("ForceAttemptHTTP2 = false, want true") + } +} + +func TestApplyInsecureSkipVerify_NilTLSConfig(t *testing.T) { + tr := &http.Transport{} + if tr.TLSClientConfig != nil { + t.Fatal("precondition failed: TLSClientConfig should be nil") + } + + ApplyInsecureSkipVerify(tr) + + if tr.TLSClientConfig == nil { + t.Fatal("TLSClientConfig is still nil after ApplyInsecureSkipVerify") + } + if !tr.TLSClientConfig.InsecureSkipVerify { + t.Error("InsecureSkipVerify = false, want true") + } +} + +func TestApplyInsecureSkipVerify_ExistingTLSConfig(t *testing.T) { + original := &tls.Config{ + ServerName: "example.com", + MinVersion: tls.VersionTLS12, + } + tr := &http.Transport{ + TLSClientConfig: original, + } + + ApplyInsecureSkipVerify(tr) + + // Must not mutate the original config. + if original.InsecureSkipVerify { + t.Error("original tls.Config was mutated; InsecureSkipVerify should still be false") + } + + // The transport must have a new config. + if tr.TLSClientConfig == original { + t.Error("TLSClientConfig is the same pointer as the original; should be a clone") + } + + // New config must have InsecureSkipVerify set. + if !tr.TLSClientConfig.InsecureSkipVerify { + t.Error("InsecureSkipVerify = false on new config, want true") + } + + // Cloned config should preserve existing fields. + if tr.TLSClientConfig.ServerName != "example.com" { + t.Errorf("ServerName = %q, want %q", tr.TLSClientConfig.ServerName, "example.com") + } + if tr.TLSClientConfig.MinVersion != tls.VersionTLS12 { + t.Errorf("MinVersion = %d, want %d", tr.TLSClientConfig.MinVersion, tls.VersionTLS12) + } +} diff --git a/internal/upload/uploader.go b/internal/upload/uploader.go index 9c65be5..bb0da35 100644 --- a/internal/upload/uploader.go +++ b/internal/upload/uploader.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -14,8 +15,19 @@ import ( "github.com/go-authgate/agent-scanner/internal/models" "github.com/go-authgate/agent-scanner/internal/redact" + "github.com/go-authgate/agent-scanner/internal/version" ) +// clientError is a non-retryable HTTP error (4xx). +type clientError struct { + StatusCode int + Body string +} + +func (e *clientError) Error() string { + return fmt.Sprintf("status %d: %s", e.StatusCode, e.Body) +} + // Uploader pushes scan results to control servers. type Uploader interface { Upload(ctx context.Context, results []models.ScanPathResult, server models.ControlServer) error @@ -56,6 +68,9 @@ func (u *uploader) Upload( Hostname: getHostname(), Username: getUsername(), }, + ScanMetadata: &models.ScanMetadata{ + Version: version.Version, + }, } body, err := json.Marshal(payload) @@ -63,7 +78,7 @@ func (u *uploader) Upload( return fmt.Errorf("marshal upload payload: %w", err) } - // Retry with exponential backoff + // Retry with exponential backoff (only retry on 5xx / network errors) maxRetries := 3 for attempt := range maxRetries { err = u.doUpload(ctx, server, body) @@ -71,6 +86,11 @@ func (u *uploader) Upload( slog.Info("upload successful", "url", server.URL) return nil } + // Do not retry client errors (4xx) + var ce *clientError + if errors.As(err, &ce) { + return err + } if attempt < maxRetries-1 { backoff := time.Duration(1<= 400 { respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode < 500 { + return &clientError{StatusCode: resp.StatusCode, Body: string(respBody)} + } return fmt.Errorf("status %d: %s", resp.StatusCode, string(respBody)) } diff --git a/internal/upload/uploader_test.go b/internal/upload/uploader_test.go new file mode 100644 index 0000000..cd33b3b --- /dev/null +++ b/internal/upload/uploader_test.go @@ -0,0 +1,267 @@ +package upload + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/go-authgate/agent-scanner/internal/models" +) + +func TestUpload_Success(t *testing.T) { + var receivedBody models.ScanPathResultsCreate + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("expected Content-Type application/json, got %s", ct) + } + if err := json.NewDecoder(r.Body).Decode(&receivedBody); err != nil { + t.Errorf("failed to decode request body: %v", err) + } + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + u := NewUploader() + results := []models.ScanPathResult{ + { + Client: "test-client", + Path: "/test/path", + Issues: []models.Issue{ + {Code: "E001", Message: "test issue"}, + }, + }, + } + server := models.ControlServer{ + URL: ts.URL, + Identifier: "test-id", + } + + err := u.Upload(context.Background(), results, server) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if len(receivedBody.ScanPathResults) != 1 { + t.Fatalf("expected 1 scan path result, got %d", len(receivedBody.ScanPathResults)) + } + if receivedBody.ScanPathResults[0].Client != "test-client" { + t.Errorf("expected client 'test-client', got %q", receivedBody.ScanPathResults[0].Client) + } + if receivedBody.ScanUserInfo.Hostname == "" { + t.Error("expected non-empty hostname") + } + if receivedBody.ScanUserInfo.Username == "" { + t.Error("expected non-empty username") + } +} + +func TestUpload_EmptyResults(t *testing.T) { + requestMade := false + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestMade = true + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + u := NewUploader() + server := models.ControlServer{URL: ts.URL} + + err := u.Upload(context.Background(), nil, server) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if requestMade { + t.Error("expected no HTTP request for empty results") + } + + err = u.Upload(context.Background(), []models.ScanPathResult{}, server) + if err != nil { + t.Fatalf("expected nil error for empty slice, got %v", err) + } + if requestMade { + t.Error("expected no HTTP request for empty slice") + } +} + +func TestUpload_4xxNoRetry(t *testing.T) { + var requestCount atomic.Int32 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount.Add(1) + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("bad request")) + })) + defer ts.Close() + + u := NewUploader() + results := []models.ScanPathResult{ + {Client: "test", Path: "/p"}, + } + server := models.ControlServer{URL: ts.URL} + + err := u.Upload(context.Background(), results, server) + if err == nil { + t.Fatal("expected error for 400 response") + } + + count := requestCount.Load() + if count != 1 { + t.Errorf("expected exactly 1 request (no retry on 4xx), got %d", count) + } + + // Verify it's a clientError + var ce *clientError + if !containsClientError(err) { + t.Errorf("expected clientError in chain, got %T: %v", err, err) + } + _ = ce +} + +func TestUpload_5xxRetries(t *testing.T) { + var requestCount atomic.Int32 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount.Add(1) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("server error")) + })) + defer ts.Close() + + u := NewUploader() + results := []models.ScanPathResult{ + {Client: "test", Path: "/p"}, + } + server := models.ControlServer{URL: ts.URL} + + // Use a context with a short deadline to avoid waiting for full backoff. + // The first request is immediate, then backoff is 1s, 2s. + // With a 1500ms deadline, we should get at least 2 attempts. + ctx, cancel := context.WithTimeout(context.Background(), 1500*time.Millisecond) + defer cancel() + + err := u.Upload(ctx, results, server) + if err == nil { + t.Fatal("expected error for 500 response") + } + + count := requestCount.Load() + if count < 2 { + t.Errorf("expected at least 2 requests (retries on 5xx), got %d", count) + } +} + +func TestUpload_CustomHeaders(t *testing.T) { + var receivedHeaders http.Header + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + u := NewUploader() + results := []models.ScanPathResult{ + {Client: "test", Path: "/p"}, + } + server := models.ControlServer{ + URL: ts.URL, + Headers: map[string]string{ + "Authorization": "Bearer test-token", + "X-Custom": "custom-value", + }, + } + + err := u.Upload(context.Background(), results, server) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if receivedHeaders.Get("Authorization") != "Bearer test-token" { + t.Errorf( + "expected Authorization header 'Bearer test-token', got %q", + receivedHeaders.Get("Authorization"), + ) + } + if receivedHeaders.Get("X-Custom") != "custom-value" { + t.Errorf("expected X-Custom header 'custom-value', got %q", receivedHeaders.Get("X-Custom")) + } +} + +func TestUpload_ScanMetadataVersionPopulated(t *testing.T) { + var receivedBody models.ScanPathResultsCreate + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&receivedBody); err != nil { + t.Errorf("failed to decode request body: %v", err) + } + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + u := NewUploader() + results := []models.ScanPathResult{ + {Client: "test", Path: "/p"}, + } + server := models.ControlServer{URL: ts.URL} + + err := u.Upload(context.Background(), results, server) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if receivedBody.ScanMetadata == nil { + t.Fatal("expected ScanMetadata to be non-nil") + } + if receivedBody.ScanMetadata.Version == "" { + t.Error("expected ScanMetadata.Version to be populated") + } +} + +func TestUpload_ContextCancellation(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simulate a slow server by sleeping longer than the context deadline + time.Sleep(2 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + u := NewUploader() + results := []models.ScanPathResult{ + {Client: "test", Path: "/p"}, + } + server := models.ControlServer{URL: ts.URL} + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + err := u.Upload(ctx, results, server) + if err == nil { + t.Fatal("expected error due to context cancellation") + } +} + +// containsClientError checks if the error chain contains a *clientError. +func containsClientError(err error) bool { + var ce *clientError + for e := err; e != nil; { + if _, ok := e.(*clientError); ok { + return true + } + // Check using errors.As which unwraps + if unwrapper, ok := e.(interface{ Unwrap() error }); ok { + e = unwrapper.Unwrap() + } else { + break + } + } + _ = ce + return false +} From fc36218585f2b0c90757965ac8f5fe9c5aec06ea Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Wed, 25 Mar 2026 19:28:29 +0800 Subject: [PATCH 02/12] fix: address copilot review feedback on tests and error handling - Use time.Duration constants instead of raw nanosecond literals in tlsutil tests - Wrap clientError with context (URL, attempt count) in upload error path - Use errors.As directly and remove redundant containsClientError helper - Increase retry test timeouts to avoid flakiness under CI/race detector Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/analysis/analyzer_test.go | 6 +++--- internal/tlsutil/tlsutil_test.go | 5 +++-- internal/upload/uploader.go | 8 +++++++- internal/upload/uploader_test.go | 30 ++++++------------------------ 4 files changed, 19 insertions(+), 30 deletions(-) diff --git a/internal/analysis/analyzer_test.go b/internal/analysis/analyzer_test.go index e758b8b..90c228c 100644 --- a/internal/analysis/analyzer_test.go +++ b/internal/analysis/analyzer_test.go @@ -233,10 +233,10 @@ func TestAnalyze_5xxRetries(t *testing.T) { }, } - // Use a context with a short deadline so we don't wait for full backoff. + // Use a context with a deadline so we don't wait for full backoff. // First request is immediate, then 1s backoff, then 2s backoff. - // With 1500ms we should get at least 2 attempts. - ctx, cancel := context.WithTimeout(context.Background(), 1500*time.Millisecond) + // With a 4s timeout we should reliably get at least 2 attempts without flakiness. + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) defer cancel() _, err := a.Analyze(ctx, results) diff --git a/internal/tlsutil/tlsutil_test.go b/internal/tlsutil/tlsutil_test.go index e1becfb..758c4b9 100644 --- a/internal/tlsutil/tlsutil_test.go +++ b/internal/tlsutil/tlsutil_test.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "net/http" "testing" + "time" ) func TestCloneTransport_ReturnsNonNil(t *testing.T) { @@ -30,10 +31,10 @@ func TestCloneTransport_HasExpectedDefaults(t *testing.T) { if tr.MaxIdleConns != 100 { t.Errorf("MaxIdleConns = %d, want 100", tr.MaxIdleConns) } - if tr.IdleConnTimeout != 90e9 { + if tr.IdleConnTimeout != 90*time.Second { t.Errorf("IdleConnTimeout = %v, want 90s", tr.IdleConnTimeout) } - if tr.TLSHandshakeTimeout != 10e9 { + if tr.TLSHandshakeTimeout != 10*time.Second { t.Errorf("TLSHandshakeTimeout = %v, want 10s", tr.TLSHandshakeTimeout) } if !tr.ForceAttemptHTTP2 { diff --git a/internal/upload/uploader.go b/internal/upload/uploader.go index bb0da35..0c16b99 100644 --- a/internal/upload/uploader.go +++ b/internal/upload/uploader.go @@ -89,7 +89,13 @@ func (u *uploader) Upload( // Do not retry client errors (4xx) var ce *clientError if errors.As(err, &ce) { - return err + return fmt.Errorf( + "upload failed due to non-retryable client error after %d attempt(s) (url=%s, status=%d): %w", + attempt+1, + server.URL, + ce.StatusCode, + err, + ) } if attempt < maxRetries-1 { backoff := time.Duration(1< Date: Wed, 25 Mar 2026 19:36:14 +0800 Subject: [PATCH 03/12] fix(upload,analysis): stop retrying non-retryable errors in retry loops - Add nonRetryableError type to both upload and analysis packages - Wrap request construction errors as non-retryable (invalid URL won't fix on retry) - Wrap JSON decode errors as non-retryable in analysis doRequest - Only 5xx responses and transient network errors are now retried Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/analysis/analyzer.go | 21 +++++++++++++++++++-- internal/upload/uploader.go | 17 +++++++++++++++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/internal/analysis/analyzer.go b/internal/analysis/analyzer.go index b32139b..438faf8 100644 --- a/internal/analysis/analyzer.go +++ b/internal/analysis/analyzer.go @@ -25,6 +25,15 @@ func (e *clientError) Error() string { return fmt.Sprintf("status %d: %s", e.StatusCode, e.Body) } +// nonRetryableError wraps errors that should not be retried +// (e.g., request construction failures, JSON decode errors). +type nonRetryableError struct { + err error +} + +func (e *nonRetryableError) Error() string { return e.err.Error() } +func (e *nonRetryableError) Unwrap() error { return e.err } + // Analyzer performs security analysis on scan results. type Analyzer interface { Analyze(ctx context.Context, results []models.ScanPathResult) ([]models.ScanPathResult, error) @@ -132,6 +141,11 @@ func (a *remoteAnalyzer) analyzePathResult( if err == nil { break } + // Do not retry non-retryable errors (bad URL, JSON decode, etc.) + var nre *nonRetryableError + if errors.As(err, &nre) { + return fmt.Errorf("analysis API: %w", err) + } // Do not retry client errors (4xx) var ce *clientError if errors.As(err, &ce) { @@ -166,7 +180,7 @@ func (a *remoteAnalyzer) doRequest(ctx context.Context, body []byte, resp *analy bytes.NewReader(body), ) if err != nil { - return err + return &nonRetryableError{err: err} } req.Header.Set("Content-Type", "application/json") @@ -184,5 +198,8 @@ func (a *remoteAnalyzer) doRequest(ctx context.Context, body []byte, resp *analy return fmt.Errorf("status %d: %s", httpResp.StatusCode, string(respBody)) } - return json.NewDecoder(httpResp.Body).Decode(resp) + if err := json.NewDecoder(httpResp.Body).Decode(resp); err != nil { + return &nonRetryableError{err: fmt.Errorf("decode response: %w", err)} + } + return nil } diff --git a/internal/upload/uploader.go b/internal/upload/uploader.go index 0c16b99..32b9c76 100644 --- a/internal/upload/uploader.go +++ b/internal/upload/uploader.go @@ -28,6 +28,15 @@ func (e *clientError) Error() string { return fmt.Sprintf("status %d: %s", e.StatusCode, e.Body) } +// nonRetryableError wraps errors that should not be retried +// (e.g., request construction failures). +type nonRetryableError struct { + err error +} + +func (e *nonRetryableError) Error() string { return e.err.Error() } +func (e *nonRetryableError) Unwrap() error { return e.err } + // Uploader pushes scan results to control servers. type Uploader interface { Upload(ctx context.Context, results []models.ScanPathResult, server models.ControlServer) error @@ -86,7 +95,11 @@ func (u *uploader) Upload( slog.Info("upload successful", "url", server.URL) return nil } - // Do not retry client errors (4xx) + // Do not retry client errors (4xx) or non-retryable errors (e.g., bad URL) + var nre *nonRetryableError + if errors.As(err, &nre) { + return fmt.Errorf("upload failed: %w", err) + } var ce *clientError if errors.As(err, &ce) { return fmt.Errorf( @@ -114,7 +127,7 @@ func (u *uploader) Upload( func (u *uploader) doUpload(ctx context.Context, server models.ControlServer, body []byte) error { req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, bytes.NewReader(body)) if err != nil { - return err + return &nonRetryableError{err: err} } req.Header.Set("Content-Type", "application/json") for k, v := range server.Headers { From 2dc191b6476252eecdba67ca4d49b6d672c630a4 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Wed, 25 Mar 2026 19:53:31 +0800 Subject: [PATCH 04/12] fix: remove URL from error strings, limit error body reads, and fix brittle test - Remove server URL from upload error message to avoid leaking sensitive data - Use structured logging (slog.Error) for URL context instead - Cap error response body reads to 4KB via io.LimitReader in both uploader and analyzer - Compare transport defaults against http.DefaultTransport instead of hard-coded values Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/analysis/analyzer.go | 2 +- internal/tlsutil/tlsutil_test.go | 28 ++++++++++++++++++---------- internal/upload/uploader.go | 12 +++++++++--- 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/internal/analysis/analyzer.go b/internal/analysis/analyzer.go index 438faf8..8e6b32d 100644 --- a/internal/analysis/analyzer.go +++ b/internal/analysis/analyzer.go @@ -191,7 +191,7 @@ func (a *remoteAnalyzer) doRequest(ctx context.Context, body []byte, resp *analy defer httpResp.Body.Close() if httpResp.StatusCode >= 400 { - respBody, _ := io.ReadAll(httpResp.Body) + respBody, _ := io.ReadAll(io.LimitReader(httpResp.Body, 4096)) if httpResp.StatusCode < 500 { return &clientError{StatusCode: httpResp.StatusCode, Body: string(respBody)} } diff --git a/internal/tlsutil/tlsutil_test.go b/internal/tlsutil/tlsutil_test.go index 758c4b9..95468be 100644 --- a/internal/tlsutil/tlsutil_test.go +++ b/internal/tlsutil/tlsutil_test.go @@ -4,7 +4,6 @@ import ( "crypto/tls" "net/http" "testing" - "time" ) func TestCloneTransport_ReturnsNonNil(t *testing.T) { @@ -25,20 +24,29 @@ func TestCloneTransport_ReturnsSeparateInstance(t *testing.T) { } } -func TestCloneTransport_HasExpectedDefaults(t *testing.T) { +func TestCloneTransport_MatchesDefaultTransport(t *testing.T) { tr := CloneTransport() - if tr.MaxIdleConns != 100 { - t.Errorf("MaxIdleConns = %d, want 100", tr.MaxIdleConns) + base, ok := http.DefaultTransport.(*http.Transport) + if !ok { + t.Skip("http.DefaultTransport is not *http.Transport") + } + + if tr.MaxIdleConns != base.MaxIdleConns { + t.Errorf("MaxIdleConns = %d, want %d", tr.MaxIdleConns, base.MaxIdleConns) } - if tr.IdleConnTimeout != 90*time.Second { - t.Errorf("IdleConnTimeout = %v, want 90s", tr.IdleConnTimeout) + if tr.IdleConnTimeout != base.IdleConnTimeout { + t.Errorf("IdleConnTimeout = %v, want %v", tr.IdleConnTimeout, base.IdleConnTimeout) } - if tr.TLSHandshakeTimeout != 10*time.Second { - t.Errorf("TLSHandshakeTimeout = %v, want 10s", tr.TLSHandshakeTimeout) + if tr.TLSHandshakeTimeout != base.TLSHandshakeTimeout { + t.Errorf( + "TLSHandshakeTimeout = %v, want %v", + tr.TLSHandshakeTimeout, + base.TLSHandshakeTimeout, + ) } - if !tr.ForceAttemptHTTP2 { - t.Error("ForceAttemptHTTP2 = false, want true") + if tr.ForceAttemptHTTP2 != base.ForceAttemptHTTP2 { + t.Errorf("ForceAttemptHTTP2 = %v, want %v", tr.ForceAttemptHTTP2, base.ForceAttemptHTTP2) } } diff --git a/internal/upload/uploader.go b/internal/upload/uploader.go index 32b9c76..b4dc347 100644 --- a/internal/upload/uploader.go +++ b/internal/upload/uploader.go @@ -102,10 +102,16 @@ func (u *uploader) Upload( } var ce *clientError if errors.As(err, &ce) { + slog.Error( + "upload failed due to non-retryable client error", + "attempts", attempt+1, + "url", server.URL, + "status", ce.StatusCode, + "err", err, + ) return fmt.Errorf( - "upload failed due to non-retryable client error after %d attempt(s) (url=%s, status=%d): %w", + "upload failed due to non-retryable client error after %d attempt(s) (status=%d): %w", attempt+1, - server.URL, ce.StatusCode, err, ) @@ -141,7 +147,7 @@ func (u *uploader) doUpload(ctx context.Context, server models.ControlServer, bo defer resp.Body.Close() if resp.StatusCode >= 400 { - respBody, _ := io.ReadAll(resp.Body) + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) if resp.StatusCode < 500 { return &clientError{StatusCode: resp.StatusCode, Body: string(respBody)} } From 150ed458a8f4988bbcc27c5eb63aa8d85e2adb45 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Wed, 25 Mar 2026 20:01:38 +0800 Subject: [PATCH 05/12] fix(test): use atomic.Bool for cross-goroutine requestMade flags - Replace plain bool with atomic.Bool in upload and analysis test handlers - Prevents potential data races between httptest handler and test goroutines Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/analysis/analyzer_test.go | 18 +++++++++--------- internal/upload/uploader_test.go | 8 ++++---- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/internal/analysis/analyzer_test.go b/internal/analysis/analyzer_test.go index 90c228c..e96604a 100644 --- a/internal/analysis/analyzer_test.go +++ b/internal/analysis/analyzer_test.go @@ -13,9 +13,9 @@ import ( ) func TestAnalyze_EmptyURL(t *testing.T) { - requestMade := false + var requestMade atomic.Bool ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestMade = true + requestMade.Store(true) w.WriteHeader(http.StatusOK) })) defer ts.Close() @@ -41,7 +41,7 @@ func TestAnalyze_EmptyURL(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if requestMade { + if requestMade.Load() { t.Error("expected no HTTP request when analysis URL is empty") } if len(out) != 1 { @@ -133,9 +133,9 @@ func TestAnalyze_Success(t *testing.T) { } func TestAnalyze_NilSignatureSkipped(t *testing.T) { - requestMade := false + var requestMade atomic.Bool ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestMade = true + requestMade.Store(true) w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(analysisResponse{}) })) @@ -160,7 +160,7 @@ func TestAnalyze_NilSignatureSkipped(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if requestMade { + if requestMade.Load() { t.Error("expected no HTTP request when all signatures are nil") } if len(out) != 1 { @@ -294,9 +294,9 @@ func TestAnalyze_FailureDoesNotFailOverall(t *testing.T) { } func TestAnalyze_AllNilSignatures(t *testing.T) { - requestMade := false + var requestMade atomic.Bool ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestMade = true + requestMade.Store(true) w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(analysisResponse{}) })) @@ -326,7 +326,7 @@ func TestAnalyze_AllNilSignatures(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if requestMade { + if requestMade.Load() { t.Error("expected no HTTP request when all signatures are nil") } if len(out) != 1 { diff --git a/internal/upload/uploader_test.go b/internal/upload/uploader_test.go index 221130a..94b9077 100644 --- a/internal/upload/uploader_test.go +++ b/internal/upload/uploader_test.go @@ -65,9 +65,9 @@ func TestUpload_Success(t *testing.T) { } func TestUpload_EmptyResults(t *testing.T) { - requestMade := false + var requestMade atomic.Bool ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestMade = true + requestMade.Store(true) w.WriteHeader(http.StatusOK) })) defer ts.Close() @@ -79,7 +79,7 @@ func TestUpload_EmptyResults(t *testing.T) { if err != nil { t.Fatalf("expected nil error, got %v", err) } - if requestMade { + if requestMade.Load() { t.Error("expected no HTTP request for empty results") } @@ -87,7 +87,7 @@ func TestUpload_EmptyResults(t *testing.T) { if err != nil { t.Fatalf("expected nil error for empty slice, got %v", err) } - if requestMade { + if requestMade.Load() { t.Error("expected no HTTP request for empty slice") } } From ad552f49e5ce64eb4a24a42cd9276b741415cfde Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Wed, 25 Mar 2026 20:12:39 +0800 Subject: [PATCH 06/12] fix(upload,analysis): remove response body from error strings to prevent data leaks - Remove Body field from clientError struct in both packages - Log response body at debug level instead of embedding in error strings - Remove duplicate slog.Error call for 4xx errors in uploader Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/analysis/analyzer.go | 11 +++++++---- internal/upload/uploader.go | 18 +++++++----------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/internal/analysis/analyzer.go b/internal/analysis/analyzer.go index 8e6b32d..3d0b291 100644 --- a/internal/analysis/analyzer.go +++ b/internal/analysis/analyzer.go @@ -18,11 +18,10 @@ import ( // clientError is a non-retryable HTTP error (4xx). type clientError struct { StatusCode int - Body string } func (e *clientError) Error() string { - return fmt.Sprintf("status %d: %s", e.StatusCode, e.Body) + return fmt.Sprintf("status %d", e.StatusCode) } // nonRetryableError wraps errors that should not be retried @@ -192,10 +191,14 @@ func (a *remoteAnalyzer) doRequest(ctx context.Context, body []byte, resp *analy if httpResp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(httpResp.Body, 4096)) + slog.Debug("analysis API returned non-2xx status", + "status", httpResp.StatusCode, + "body_bytes", len(respBody), + ) if httpResp.StatusCode < 500 { - return &clientError{StatusCode: httpResp.StatusCode, Body: string(respBody)} + return &clientError{StatusCode: httpResp.StatusCode} } - return fmt.Errorf("status %d: %s", httpResp.StatusCode, string(respBody)) + return fmt.Errorf("status %d", httpResp.StatusCode) } if err := json.NewDecoder(httpResp.Body).Decode(resp); err != nil { diff --git a/internal/upload/uploader.go b/internal/upload/uploader.go index b4dc347..fea9eee 100644 --- a/internal/upload/uploader.go +++ b/internal/upload/uploader.go @@ -21,11 +21,10 @@ import ( // clientError is a non-retryable HTTP error (4xx). type clientError struct { StatusCode int - Body string } func (e *clientError) Error() string { - return fmt.Sprintf("status %d: %s", e.StatusCode, e.Body) + return fmt.Sprintf("status %d", e.StatusCode) } // nonRetryableError wraps errors that should not be retried @@ -102,13 +101,6 @@ func (u *uploader) Upload( } var ce *clientError if errors.As(err, &ce) { - slog.Error( - "upload failed due to non-retryable client error", - "attempts", attempt+1, - "url", server.URL, - "status", ce.StatusCode, - "err", err, - ) return fmt.Errorf( "upload failed due to non-retryable client error after %d attempt(s) (status=%d): %w", attempt+1, @@ -148,10 +140,14 @@ func (u *uploader) doUpload(ctx context.Context, server models.ControlServer, bo if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + slog.Debug("upload received non-2xx response", + "status", resp.StatusCode, + "body_bytes", len(respBody), + ) if resp.StatusCode < 500 { - return &clientError{StatusCode: resp.StatusCode, Body: string(respBody)} + return &clientError{StatusCode: resp.StatusCode} } - return fmt.Errorf("status %d: %s", resp.StatusCode, string(respBody)) + return fmt.Errorf("status %d", resp.StatusCode) } return nil From 015e5f45119c47f4ec04680cc9ac17c4129a7c09 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Wed, 25 Mar 2026 20:46:34 +0800 Subject: [PATCH 07/12] fix: include truncated response body in errors and reduce test sleep - Include truncated (512B) response body snippet in error messages for debuggability - Restore Body field on clientError with truncated content - Reduce context cancellation test sleep from 2s to 500ms Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/analysis/analyzer.go | 15 ++++++++------- internal/upload/uploader.go | 15 ++++++++------- internal/upload/uploader_test.go | 4 ++-- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/internal/analysis/analyzer.go b/internal/analysis/analyzer.go index 3d0b291..52be7b0 100644 --- a/internal/analysis/analyzer.go +++ b/internal/analysis/analyzer.go @@ -18,10 +18,11 @@ import ( // clientError is a non-retryable HTTP error (4xx). type clientError struct { StatusCode int + Body string } func (e *clientError) Error() string { - return fmt.Sprintf("status %d", e.StatusCode) + return fmt.Sprintf("status %d: %s", e.StatusCode, e.Body) } // nonRetryableError wraps errors that should not be retried @@ -191,14 +192,14 @@ func (a *remoteAnalyzer) doRequest(ctx context.Context, body []byte, resp *analy if httpResp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(httpResp.Body, 4096)) - slog.Debug("analysis API returned non-2xx status", - "status", httpResp.StatusCode, - "body_bytes", len(respBody), - ) + bodySnippet := string(respBody) + if len(bodySnippet) > 512 { + bodySnippet = bodySnippet[:512] + " [truncated]" + } if httpResp.StatusCode < 500 { - return &clientError{StatusCode: httpResp.StatusCode} + return &clientError{StatusCode: httpResp.StatusCode, Body: bodySnippet} } - return fmt.Errorf("status %d", httpResp.StatusCode) + return fmt.Errorf("status %d: %s", httpResp.StatusCode, bodySnippet) } if err := json.NewDecoder(httpResp.Body).Decode(resp); err != nil { diff --git a/internal/upload/uploader.go b/internal/upload/uploader.go index fea9eee..8bb4cf6 100644 --- a/internal/upload/uploader.go +++ b/internal/upload/uploader.go @@ -21,10 +21,11 @@ import ( // clientError is a non-retryable HTTP error (4xx). type clientError struct { StatusCode int + Body string } func (e *clientError) Error() string { - return fmt.Sprintf("status %d", e.StatusCode) + return fmt.Sprintf("status %d: %s", e.StatusCode, e.Body) } // nonRetryableError wraps errors that should not be retried @@ -140,14 +141,14 @@ func (u *uploader) doUpload(ctx context.Context, server models.ControlServer, bo if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - slog.Debug("upload received non-2xx response", - "status", resp.StatusCode, - "body_bytes", len(respBody), - ) + bodySnippet := string(respBody) + if len(bodySnippet) > 512 { + bodySnippet = bodySnippet[:512] + " [truncated]" + } if resp.StatusCode < 500 { - return &clientError{StatusCode: resp.StatusCode} + return &clientError{StatusCode: resp.StatusCode, Body: bodySnippet} } - return fmt.Errorf("status %d", resp.StatusCode) + return fmt.Errorf("status %d: %s", resp.StatusCode, bodySnippet) } return nil diff --git a/internal/upload/uploader_test.go b/internal/upload/uploader_test.go index 94b9077..539484d 100644 --- a/internal/upload/uploader_test.go +++ b/internal/upload/uploader_test.go @@ -227,8 +227,8 @@ func TestUpload_ScanMetadataVersionPopulated(t *testing.T) { func TestUpload_ContextCancellation(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Simulate a slow server by sleeping longer than the context deadline - time.Sleep(2 * time.Second) + // Simulate a slow server; sleep just long enough to exceed the client context deadline + time.Sleep(500 * time.Millisecond) w.WriteHeader(http.StatusOK) })) defer ts.Close() From 7e1f61f8aaccfaf9338695f53a48599889b9a112 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Wed, 25 Mar 2026 20:54:07 +0800 Subject: [PATCH 08/12] fix(upload,analysis): sanitize response body snippets for safe logging - Add sanitizeBodySnippet helper to replace newlines and control chars - Prevents multi-line log output from raw HTTP response bodies Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/analysis/analyzer.go | 16 ++++++++++++---- internal/upload/uploader.go | 16 ++++++++++++---- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/internal/analysis/analyzer.go b/internal/analysis/analyzer.go index 52be7b0..d56b49c 100644 --- a/internal/analysis/analyzer.go +++ b/internal/analysis/analyzer.go @@ -9,6 +9,7 @@ import ( "io" "log/slog" "net/http" + "strings" "time" "github.com/go-authgate/agent-scanner/internal/models" @@ -192,10 +193,7 @@ func (a *remoteAnalyzer) doRequest(ctx context.Context, body []byte, resp *analy if httpResp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(httpResp.Body, 4096)) - bodySnippet := string(respBody) - if len(bodySnippet) > 512 { - bodySnippet = bodySnippet[:512] + " [truncated]" - } + bodySnippet := sanitizeBodySnippet(string(respBody), 512) if httpResp.StatusCode < 500 { return &clientError{StatusCode: httpResp.StatusCode, Body: bodySnippet} } @@ -207,3 +205,13 @@ func (a *remoteAnalyzer) doRequest(ctx context.Context, body []byte, resp *analy } return nil } + +// sanitizeBodySnippet truncates s to maxLen bytes and replaces +// newlines/control characters with spaces for safe single-line logging. +func sanitizeBodySnippet(s string, maxLen int) string { + if len(s) > maxLen { + s = s[:maxLen] + " [truncated]" + } + replacer := strings.NewReplacer("\r\n", " ", "\r", " ", "\n", " ", "\t", " ") + return replacer.Replace(s) +} diff --git a/internal/upload/uploader.go b/internal/upload/uploader.go index 8bb4cf6..27069d8 100644 --- a/internal/upload/uploader.go +++ b/internal/upload/uploader.go @@ -11,6 +11,7 @@ import ( "net/http" "os" "os/user" + "strings" "time" "github.com/go-authgate/agent-scanner/internal/models" @@ -141,10 +142,7 @@ func (u *uploader) doUpload(ctx context.Context, server models.ControlServer, bo if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - bodySnippet := string(respBody) - if len(bodySnippet) > 512 { - bodySnippet = bodySnippet[:512] + " [truncated]" - } + bodySnippet := sanitizeBodySnippet(string(respBody), 512) if resp.StatusCode < 500 { return &clientError{StatusCode: resp.StatusCode, Body: bodySnippet} } @@ -172,3 +170,13 @@ func getUsername() string { } return u.Username } + +// sanitizeBodySnippet truncates s to maxLen bytes and replaces +// newlines/control characters with spaces for safe single-line logging. +func sanitizeBodySnippet(s string, maxLen int) string { + if len(s) > maxLen { + s = s[:maxLen] + " [truncated]" + } + replacer := strings.NewReplacer("\r\n", " ", "\r", " ", "\n", " ", "\t", " ") + return replacer.Replace(s) +} From db64287b619a54ff98201977b3e4842cbfe16abe Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Wed, 25 Mar 2026 21:03:03 +0800 Subject: [PATCH 09/12] docs: fix sanitizeBodySnippet comment to reflect actual truncation behavior Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/analysis/analyzer.go | 5 +++-- internal/upload/uploader.go | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/internal/analysis/analyzer.go b/internal/analysis/analyzer.go index d56b49c..55c101e 100644 --- a/internal/analysis/analyzer.go +++ b/internal/analysis/analyzer.go @@ -206,8 +206,9 @@ func (a *remoteAnalyzer) doRequest(ctx context.Context, body []byte, resp *analy return nil } -// sanitizeBodySnippet truncates s to maxLen bytes and replaces -// newlines/control characters with spaces for safe single-line logging. +// sanitizeBodySnippet truncates s to approximately maxLen bytes (the +// returned string may be slightly longer due to a " [truncated]" suffix) +// and replaces newlines/control characters with spaces for safe single-line logging. func sanitizeBodySnippet(s string, maxLen int) string { if len(s) > maxLen { s = s[:maxLen] + " [truncated]" diff --git a/internal/upload/uploader.go b/internal/upload/uploader.go index 27069d8..97900cc 100644 --- a/internal/upload/uploader.go +++ b/internal/upload/uploader.go @@ -171,8 +171,9 @@ func getUsername() string { return u.Username } -// sanitizeBodySnippet truncates s to maxLen bytes and replaces -// newlines/control characters with spaces for safe single-line logging. +// sanitizeBodySnippet truncates s to approximately maxLen bytes (the +// returned string may be slightly longer due to a " [truncated]" suffix) +// and replaces newlines/control characters with spaces for safe single-line logging. func sanitizeBodySnippet(s string, maxLen int) string { if len(s) > maxLen { s = s[:maxLen] + " [truncated]" From 0aa11327ca62a18221343febfcd7ae077b850095 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Wed, 25 Mar 2026 21:22:19 +0800 Subject: [PATCH 10/12] fix(upload): drain response body on success for HTTP keep-alive reuse Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/upload/uploader.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/upload/uploader.go b/internal/upload/uploader.go index 97900cc..19cbf29 100644 --- a/internal/upload/uploader.go +++ b/internal/upload/uploader.go @@ -149,6 +149,9 @@ func (u *uploader) doUpload(ctx context.Context, server models.ControlServer, bo return fmt.Errorf("status %d: %s", resp.StatusCode, bodySnippet) } + // Drain response body on success to allow HTTP connection reuse. + _, _ = io.Copy(io.Discard, resp.Body) + return nil } From 4696da4aa6995737e1cd2646db8be2ad1297cee8 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Wed, 25 Mar 2026 23:11:57 +0800 Subject: [PATCH 11/12] fix(upload,analysis): use unicode.IsControl to sanitize all control characters - Replace strings.NewReplacer (only CR/LF/Tab) with strings.Map + unicode.IsControl - Now correctly strips all Unicode control characters from body snippets Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/analysis/analyzer.go | 11 ++++++++--- internal/upload/uploader.go | 11 ++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/internal/analysis/analyzer.go b/internal/analysis/analyzer.go index 55c101e..6c798fd 100644 --- a/internal/analysis/analyzer.go +++ b/internal/analysis/analyzer.go @@ -11,6 +11,7 @@ import ( "net/http" "strings" "time" + "unicode" "github.com/go-authgate/agent-scanner/internal/models" "github.com/go-authgate/agent-scanner/internal/tlsutil" @@ -208,11 +209,15 @@ func (a *remoteAnalyzer) doRequest(ctx context.Context, body []byte, resp *analy // sanitizeBodySnippet truncates s to approximately maxLen bytes (the // returned string may be slightly longer due to a " [truncated]" suffix) -// and replaces newlines/control characters with spaces for safe single-line logging. +// and replaces all Unicode control characters with spaces for safe single-line logging. func sanitizeBodySnippet(s string, maxLen int) string { if len(s) > maxLen { s = s[:maxLen] + " [truncated]" } - replacer := strings.NewReplacer("\r\n", " ", "\r", " ", "\n", " ", "\t", " ") - return replacer.Replace(s) + return strings.Map(func(r rune) rune { + if unicode.IsControl(r) { + return ' ' + } + return r + }, s) } diff --git a/internal/upload/uploader.go b/internal/upload/uploader.go index 19cbf29..0a7590c 100644 --- a/internal/upload/uploader.go +++ b/internal/upload/uploader.go @@ -13,6 +13,7 @@ import ( "os/user" "strings" "time" + "unicode" "github.com/go-authgate/agent-scanner/internal/models" "github.com/go-authgate/agent-scanner/internal/redact" @@ -176,11 +177,15 @@ func getUsername() string { // sanitizeBodySnippet truncates s to approximately maxLen bytes (the // returned string may be slightly longer due to a " [truncated]" suffix) -// and replaces newlines/control characters with spaces for safe single-line logging. +// and replaces all Unicode control characters with spaces for safe single-line logging. func sanitizeBodySnippet(s string, maxLen int) string { if len(s) > maxLen { s = s[:maxLen] + " [truncated]" } - replacer := strings.NewReplacer("\r\n", " ", "\r", " ", "\n", " ", "\t", " ") - return replacer.Replace(s) + return strings.Map(func(r rune) rune { + if unicode.IsControl(r) { + return ' ' + } + return r + }, s) } From 737540932536bcc424998938f6136e0e7328f297 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Wed, 25 Mar 2026 23:35:26 +0800 Subject: [PATCH 12/12] fix(upload): remove redundant status code from client error message MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The error format string included (status=%d) explicitly while also wrapping a clientError whose Error() already produces "status : …", resulting in duplicated status information in the output. Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/upload/uploader.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/upload/uploader.go b/internal/upload/uploader.go index 0a7590c..b931349 100644 --- a/internal/upload/uploader.go +++ b/internal/upload/uploader.go @@ -105,9 +105,8 @@ func (u *uploader) Upload( var ce *clientError if errors.As(err, &ce) { return fmt.Errorf( - "upload failed due to non-retryable client error after %d attempt(s) (status=%d): %w", + "upload failed due to non-retryable client error after %d attempt(s): %w", attempt+1, - ce.StatusCode, err, ) }