Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 7 additions & 42 deletions internal/analysis/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,13 @@ import (
"io"
"log/slog"
"net/http"
"strings"
"time"
"unicode"

"github.com/go-authgate/agent-scanner/internal/httperrors"
"github.com/go-authgate/agent-scanner/internal/models"
"github.com/go-authgate/agent-scanner/internal/tlsutil"
)

// clientError is a non-retryable HTTP error (4xx).
type clientError struct {
StatusCode int
Body string
}

func (e *clientError) Error() string {
return fmt.Sprintf("status %d: %s", e.StatusCode, e.Body)
}

// nonRetryableError wraps errors that should not be retried
// (e.g., request construction failures, JSON decode errors).
type nonRetryableError struct {
err error
}

func (e *nonRetryableError) Error() string { return e.err.Error() }
func (e *nonRetryableError) Unwrap() error { return e.err }

// Analyzer performs security analysis on scan results.
type Analyzer interface {
Analyze(ctx context.Context, results []models.ScanPathResult) ([]models.ScanPathResult, error)
Expand Down Expand Up @@ -144,12 +124,12 @@ func (a *remoteAnalyzer) analyzePathResult(
break
}
// Do not retry non-retryable errors (bad URL, JSON decode, etc.)
var nre *nonRetryableError
var nre *httperrors.NonRetryableError
if errors.As(err, &nre) {
return fmt.Errorf("analysis API: %w", err)
}
// Do not retry client errors (4xx)
var ce *clientError
var ce *httperrors.ClientError
if errors.As(err, &ce) {
return fmt.Errorf("analysis API: %w", err)
}
Expand Down Expand Up @@ -182,7 +162,7 @@ func (a *remoteAnalyzer) doRequest(ctx context.Context, body []byte, resp *analy
bytes.NewReader(body),
)
if err != nil {
return &nonRetryableError{err: err}
return &httperrors.NonRetryableError{Err: err}
}
req.Header.Set("Content-Type", "application/json")

Expand All @@ -194,30 +174,15 @@ func (a *remoteAnalyzer) doRequest(ctx context.Context, body []byte, resp *analy

if httpResp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(httpResp.Body, 4096))
bodySnippet := sanitizeBodySnippet(string(respBody), 512)
bodySnippet := httperrors.SanitizeBodySnippet(string(respBody), 512)
if httpResp.StatusCode < 500 {
return &clientError{StatusCode: httpResp.StatusCode, Body: bodySnippet}
return &httperrors.ClientError{StatusCode: httpResp.StatusCode, Body: bodySnippet}
}
return fmt.Errorf("status %d: %s", httpResp.StatusCode, bodySnippet)
}

if err := json.NewDecoder(httpResp.Body).Decode(resp); err != nil {
return &nonRetryableError{err: fmt.Errorf("decode response: %w", err)}
return &httperrors.NonRetryableError{Err: fmt.Errorf("decode response: %w", err)}
}
return nil
}

// sanitizeBodySnippet truncates s to approximately maxLen bytes (the
// returned string may be slightly longer due to a " [truncated]" suffix)
// and replaces all Unicode control characters with spaces for safe single-line logging.
func sanitizeBodySnippet(s string, maxLen int) string {
if len(s) > maxLen {
s = s[:maxLen] + " [truncated]"
}
return strings.Map(func(r rune) rune {
if unicode.IsControl(r) {
return ' '
}
return r
}, s)
}
80 changes: 80 additions & 0 deletions internal/cli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,83 @@ func TestParseControlServers_MoreServersThanIdentifiers(t *testing.T) {
t.Errorf("server[2] Identifier = %q, want empty", servers[2].Identifier)
}
}

func TestParseControlServers_WithHeaders(t *testing.T) {
orig := scanFlags
defer func() { scanFlags = orig }()

scanFlags = ScanFlags{
ControlServers: []string{"https://a.example.com", "https://b.example.com"},
ControlHeaders: []string{"Authorization: Bearer token123; X-Custom: value"},
}

servers := parseControlServers()
if len(servers) != 2 {
t.Fatalf("expected 2 servers, got %d", len(servers))
}

if servers[0].Headers == nil {
t.Fatal("expected headers for server[0]")
}
if servers[0].Headers["Authorization"] != "Bearer token123" {
t.Errorf(
"Authorization = %q, want %q",
servers[0].Headers["Authorization"],
"Bearer token123",
)
}
if servers[0].Headers["X-Custom"] != "value" {
t.Errorf("X-Custom = %q, want %q", servers[0].Headers["X-Custom"], "value")
}

// server[1] has no matching header entry
if servers[1].Headers != nil {
t.Errorf("expected nil headers for server[1], got %v", servers[1].Headers)
}
}

func TestParseHeaders(t *testing.T) {
tests := []struct {
name string
raw string
want map[string]string
}{
{
name: "single header",
raw: "Authorization: Bearer abc",
want: map[string]string{"Authorization": "Bearer abc"},
},
{
name: "multiple headers",
raw: "Key1: Val1; Key2: Val2",
want: map[string]string{"Key1": "Val1", "Key2": "Val2"},
},
{
name: "empty string",
raw: "",
want: nil,
},
{
name: "whitespace only",
raw: " ; ",
want: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := parseHeaders(tt.raw)
if tt.want == nil {
if got != nil {
t.Errorf("expected nil, got %v", got)
}
return
}
for k, v := range tt.want {
if got[k] != v {
t.Errorf("header %q = %q, want %q", k, got[k], v)
}
}
})
}
}
3 changes: 0 additions & 3 deletions internal/cli/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ type CommonFlags struct {
StorageFile string
AnalysisURL string
VerificationH []string
OAuthTokensPath string
Verbose bool
PrintErrors bool
PrintFullDescs bool
Expand Down Expand Up @@ -48,8 +47,6 @@ func addCommonFlags(cmd *cobra.Command) {
cmd.Flags().StringVar(&commonFlags.AnalysisURL, "analysis-url", "", "Verification server URL")
cmd.Flags().
StringSliceVar(&commonFlags.VerificationH, "verification-H", nil, "Additional headers for verification API")
cmd.Flags().
StringVar(&commonFlags.OAuthTokensPath, "mcp-oauth-tokens-path", "", "OAuth token storage path")
cmd.Flags().BoolVar(&commonFlags.Verbose, "verbose", false, "Enable verbose logging")
cmd.Flags().
BoolVar(&commonFlags.PrintErrors, "print-errors", false, "Print server startup errors/tracebacks")
Expand Down
24 changes: 24 additions & 0 deletions internal/cli/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"os"
"strings"
"time"

"github.com/go-authgate/agent-scanner/internal/analysis"
Expand Down Expand Up @@ -108,7 +109,30 @@ func parseControlServers() []pipeline.ControlServerConfig {
if i < len(scanFlags.ControlIdentifier) {
cs.Identifier = scanFlags.ControlIdentifier[i]
}
if i < len(scanFlags.ControlHeaders) {
cs.Headers = parseHeaders(scanFlags.ControlHeaders[i])
}
servers = append(servers, cs)
}
return servers
}

// parseHeaders parses a semicolon-separated header string into a map.
// Each header is in "Key: Value" format.
func parseHeaders(raw string) map[string]string {
headers := make(map[string]string)
for part := range strings.SplitSeq(raw, ";") {
part = strings.TrimSpace(part)
if key, value, ok := strings.Cut(part, ":"); ok {
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key != "" {
headers[key] = value
}
}
}
if len(headers) == 0 {
return nil
}
return headers
}
41 changes: 41 additions & 0 deletions internal/httperrors/httperrors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package httperrors

import (
"fmt"
"strings"
"unicode"
)

// ClientError is a non-retryable HTTP error (4xx).
type ClientError struct {
StatusCode int
Body string
}

func (e *ClientError) Error() string {
return fmt.Sprintf("status %d: %s", e.StatusCode, e.Body)
}

// NonRetryableError wraps errors that should not be retried
// (e.g., request construction failures, JSON decode errors).
type NonRetryableError struct {
Err error
}

func (e *NonRetryableError) Error() string { return e.Err.Error() }
func (e *NonRetryableError) Unwrap() error { return e.Err }

// SanitizeBodySnippet truncates s to approximately maxLen bytes (the
// returned string may be slightly longer due to a " [truncated]" suffix)
// and replaces all Unicode control characters with spaces for safe single-line logging.
func SanitizeBodySnippet(s string, maxLen int) string {
if len(s) > maxLen {
s = s[:maxLen] + " [truncated]"
}
return strings.Map(func(r rune) rune {
if unicode.IsControl(r) {
return ' '
}
return r
}, s)
}
77 changes: 77 additions & 0 deletions internal/httperrors/httperrors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package httperrors

import (
"errors"
"strings"
"testing"
)

func TestClientError_Error(t *testing.T) {
err := &ClientError{StatusCode: 403, Body: "forbidden"}
got := err.Error()
if got != "status 403: forbidden" {
t.Errorf("got %q, want %q", got, "status 403: forbidden")
}
}

func TestNonRetryableError_Unwrap(t *testing.T) {
inner := errors.New("bad request")
err := &NonRetryableError{Err: inner}

if err.Error() != "bad request" {
t.Errorf("Error() = %q, want %q", err.Error(), "bad request")
}
if !errors.Is(err, inner) {
t.Error("expected errors.Is to find inner error")
}
}

func TestSanitizeBodySnippet(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
checks func(t *testing.T, result string)
}{
{
name: "short string unchanged",
input: "hello world",
maxLen: 100,
checks: func(t *testing.T, result string) {
if result != "hello world" {
t.Errorf("got %q", result)
}
},
},
{
name: "truncated",
input: "abcdefghij",
maxLen: 5,
checks: func(t *testing.T, result string) {
if !strings.HasPrefix(result, "abcde") {
t.Errorf("got %q, want prefix 'abcde'", result)
}
if !strings.Contains(result, "[truncated]") {
t.Error("expected [truncated] suffix")
}
},
},
{
name: "control chars replaced",
input: "line1\nline2\ttab\x00null",
maxLen: 100,
checks: func(t *testing.T, result string) {
if strings.ContainsAny(result, "\n\t\x00") {
t.Errorf("control characters not replaced: %q", result)
}
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeBodySnippet(tt.input, tt.maxLen)
tt.checks(t, result)
})
}
}
Loading
Loading