diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 1908fbd..1c7bbaf 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -32,10 +32,6 @@ jobs: with: go-version: ${{ matrix.go }} - - name: Run Generate - run: | - make generate - - name: Setup golangci-lint uses: golangci/golangci-lint-action@v9 with: diff --git a/main.go b/main.go index f8c3595..3585854 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "os/signal" "strconv" "strings" + "sync" "syscall" "time" @@ -28,18 +29,18 @@ import ( ) var ( - serverURL string - clientID string - clientSecret string - redirectURI string - callbackPort int - scope string - tokenFile string - tokenStoreMode string - tokenStore credstore.Store[credstore.Token] - configInitialized bool - retryClient *retry.Client - configWarnings []string + serverURL string + clientID string + clientSecret string + redirectURI string + callbackPort int + scope string + tokenFile string + tokenStoreMode string + tokenStore credstore.Store[credstore.Token] + configOnce sync.Once + retryClient *retry.Client + configWarnings []string flagServerURL *string flagClientID *string @@ -55,6 +56,7 @@ const ( tokenExchangeTimeout = 10 * time.Second tokenVerificationTimeout = 10 * time.Second refreshTokenTimeout = 10 * time.Second + maxResponseSize = 1 << 20 // 1 MiB ) func init() { @@ -96,16 +98,21 @@ func init() { // initConfig parses flags and initializes all configuration. func initConfig() { - if configInitialized { - return - } - configInitialized = true + configOnce.Do(doInitConfig) +} +func doInitConfig() { flag.Parse() serverURL = getConfig(*flagServerURL, "SERVER_URL", "http://localhost:8080") clientID = getConfig(*flagClientID, "CLIENT_ID", "") clientSecret = getConfig(*flagClientSecret, "CLIENT_SECRET", "") + if *flagClientSecret != "" { + configWarnings = append(configWarnings, + "Client secret passed via command-line flag. "+ + "This may be visible in process listings. "+ + "Consider using CLIENT_SECRET env var or .env file instead.") + } scope = getConfig(*flagScope, "SCOPE", "read write") tokenFile = getConfig(*flagTokenFile, "TOKEN_FILE", ".authgate-tokens.json") @@ -262,6 +269,25 @@ type tokenResponse struct { Scope string `json:"scope"` } +// errResponseTooLarge is returned when a server response exceeds maxResponseSize. +var errResponseTooLarge = fmt.Errorf( + "response body exceeds maximum allowed size of %d bytes", + maxResponseSize, +) + +// readResponseBody reads up to maxResponseSize bytes from r and returns an +// explicit error when the response is too large (rather than silently truncating). +func readResponseBody(r io.Reader) ([]byte, error) { + body, err := io.ReadAll(io.LimitReader(r, maxResponseSize+1)) + if err != nil { + return nil, err + } + if int64(len(body)) > maxResponseSize { + return nil, errResponseTooLarge + } + return body, nil +} + // parseOAuthError attempts to extract a structured OAuth error from a non-200 // response body. Falls back to including the raw body in the error message. func parseOAuthError(statusCode int, body []byte, action string) error { @@ -341,7 +367,7 @@ func exchangeCode(ctx context.Context, code, codeVerifier string) (*tui.TokenSto } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := readResponseBody(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } @@ -405,7 +431,7 @@ func refreshAccessToken(ctx context.Context, refreshToken string) (*tui.TokenSto } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := readResponseBody(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } @@ -472,7 +498,7 @@ func verifyToken(ctx context.Context, accessToken string) (string, error) { } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := readResponseBody(resp.Body) if err != nil { return "", fmt.Errorf("failed to read response: %w", err) } @@ -531,7 +557,7 @@ func makeAPICallWithAutoRefresh(ctx context.Context, storage *tui.TokenStorage) defer resp.Body.Close() } - body, err := io.ReadAll(resp.Body) + body, err := readResponseBody(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) } diff --git a/main_test.go b/main_test.go index dc799af..5479c3d 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,8 @@ package main import ( + "bytes" + "errors" "path/filepath" "testing" "time" @@ -276,6 +278,38 @@ func TestInitTokenStore_Invalid(t *testing.T) { } } +func TestReadResponseBody(t *testing.T) { + t.Run("within limit", func(t *testing.T) { + data := bytes.Repeat([]byte("a"), 100) + body, err := readResponseBody(bytes.NewReader(data)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(body) != 100 { + t.Errorf("expected 100 bytes, got %d", len(body)) + } + }) + + t.Run("exactly at limit", func(t *testing.T) { + data := bytes.Repeat([]byte("a"), maxResponseSize) + body, err := readResponseBody(bytes.NewReader(data)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(body) != maxResponseSize { + t.Errorf("expected %d bytes, got %d", maxResponseSize, len(body)) + } + }) + + t.Run("exceeds limit", func(t *testing.T) { + data := bytes.Repeat([]byte("a"), maxResponseSize+1) + _, err := readResponseBody(bytes.NewReader(data)) + if !errors.Is(err, errResponseTooLarge) { + t.Errorf("expected errResponseTooLarge, got: %v", err) + } + }) +} + // containsSubstring is a helper to avoid importing strings in tests. func containsSubstring(s, sub string) bool { return len(s) >= len(sub) && findSubstring(s, sub)