Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ jobs:
with:
go-version: ${{ matrix.go }}

- name: Run Generate
run: |
make generate

- name: Setup golangci-lint
uses: golangci/golangci-lint-action@v9
with:
Expand Down
66 changes: 46 additions & 20 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"

Expand All @@ -28,18 +29,18 @@ import (
)

var (
serverURL string
clientID string
clientSecret string
redirectURI string
callbackPort int
scope string
tokenFile string
tokenStoreMode string
tokenStore credstore.Store[credstore.Token]
configInitialized bool
retryClient *retry.Client
configWarnings []string
serverURL string
clientID string
clientSecret string
redirectURI string
callbackPort int
scope string
tokenFile string
tokenStoreMode string
tokenStore credstore.Store[credstore.Token]
configOnce sync.Once
retryClient *retry.Client
configWarnings []string

flagServerURL *string
flagClientID *string
Expand All @@ -55,6 +56,7 @@ const (
tokenExchangeTimeout = 10 * time.Second
tokenVerificationTimeout = 10 * time.Second
refreshTokenTimeout = 10 * time.Second
maxResponseSize = 1 << 20 // 1 MiB
)

func init() {
Expand Down Expand Up @@ -96,16 +98,21 @@ func init() {

// initConfig parses flags and initializes all configuration.
func initConfig() {
if configInitialized {
return
}
configInitialized = true
configOnce.Do(doInitConfig)
}

func doInitConfig() {
flag.Parse()

serverURL = getConfig(*flagServerURL, "SERVER_URL", "http://localhost:8080")
clientID = getConfig(*flagClientID, "CLIENT_ID", "")
clientSecret = getConfig(*flagClientSecret, "CLIENT_SECRET", "")
if *flagClientSecret != "" {
configWarnings = append(configWarnings,
"Client secret passed via command-line flag. "+
"This may be visible in process listings. "+
"Consider using CLIENT_SECRET env var or .env file instead.")
}
Comment on lines +104 to +115
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A new warning is added when -client-secret is provided, but there are no tests validating that configWarnings includes this message when the flag is set. Consider factoring the warning decision into a small helper (or resetting flag.CommandLine in tests) so this behavior is covered and doesn’t regress.

Suggested change
func doInitConfig() {
flag.Parse()
serverURL = getConfig(*flagServerURL, "SERVER_URL", "http://localhost:8080")
clientID = getConfig(*flagClientID, "CLIENT_ID", "")
clientSecret = getConfig(*flagClientSecret, "CLIENT_SECRET", "")
if *flagClientSecret != "" {
configWarnings = append(configWarnings,
"Client secret passed via command-line flag. "+
"This may be visible in process listings. "+
"Consider using CLIENT_SECRET env var or .env file instead.")
}
// computeConfigWarningsFromFlags returns configuration warnings derived solely
// from command-line flag values. This is separated from doInitConfig so that
// the warning behavior can be unit tested without relying on global flag state.
func computeConfigWarningsFromFlags(clientSecretFlag string) []string {
var warnings []string
if clientSecretFlag != "" {
warnings = append(warnings,
"Client secret passed via command-line flag. "+
"This may be visible in process listings. "+
"Consider using CLIENT_SECRET env var or .env file instead.")
}
return warnings
}
func doInitConfig() {
flag.Parse()
serverURL = getConfig(*flagServerURL, "SERVER_URL", "http://localhost:8080")
clientID = getConfig(*flagClientID, "CLIENT_ID", "")
clientSecret = getConfig(*flagClientSecret, "CLIENT_SECRET", "")
configWarnings = append(configWarnings, computeConfigWarningsFromFlags(*flagClientSecret)...)

Copilot uses AI. Check for mistakes.
scope = getConfig(*flagScope, "SCOPE", "read write")
tokenFile = getConfig(*flagTokenFile, "TOKEN_FILE", ".authgate-tokens.json")

Expand Down Expand Up @@ -262,6 +269,25 @@ type tokenResponse struct {
Scope string `json:"scope"`
}

// errResponseTooLarge is returned when a server response exceeds maxResponseSize.
var errResponseTooLarge = fmt.Errorf(
"response body exceeds maximum allowed size of %d bytes",
maxResponseSize,
)

// readResponseBody reads up to maxResponseSize bytes from r and returns an
// explicit error when the response is too large (rather than silently truncating).
func readResponseBody(r io.Reader) ([]byte, error) {
body, err := io.ReadAll(io.LimitReader(r, maxResponseSize+1))
if err != nil {
return nil, err
}
if int64(len(body)) > maxResponseSize {
return nil, errResponseTooLarge
}
return body, nil
}

// parseOAuthError attempts to extract a structured OAuth error from a non-200
// response body. Falls back to including the raw body in the error message.
func parseOAuthError(statusCode int, body []byte, action string) error {
Expand Down Expand Up @@ -341,7 +367,7 @@ func exchangeCode(ctx context.Context, code, codeVerifier string) (*tui.TokenSto
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
body, err := readResponseBody(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
Expand Down Expand Up @@ -405,7 +431,7 @@ func refreshAccessToken(ctx context.Context, refreshToken string) (*tui.TokenSto
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
body, err := readResponseBody(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
Expand Down Expand Up @@ -472,7 +498,7 @@ func verifyToken(ctx context.Context, accessToken string) (string, error) {
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
body, err := readResponseBody(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
Expand Down Expand Up @@ -531,7 +557,7 @@ func makeAPICallWithAutoRefresh(ctx context.Context, storage *tui.TokenStorage)
defer resp.Body.Close()
}

body, err := io.ReadAll(resp.Body)
body, err := readResponseBody(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
Expand Down
34 changes: 34 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package main

import (
"bytes"
"errors"
"path/filepath"
"testing"
"time"
Expand Down Expand Up @@ -276,6 +278,38 @@ func TestInitTokenStore_Invalid(t *testing.T) {
}
}

func TestReadResponseBody(t *testing.T) {
t.Run("within limit", func(t *testing.T) {
data := bytes.Repeat([]byte("a"), 100)
body, err := readResponseBody(bytes.NewReader(data))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(body) != 100 {
t.Errorf("expected 100 bytes, got %d", len(body))
}
})

t.Run("exactly at limit", func(t *testing.T) {
data := bytes.Repeat([]byte("a"), maxResponseSize)
body, err := readResponseBody(bytes.NewReader(data))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(body) != maxResponseSize {
t.Errorf("expected %d bytes, got %d", maxResponseSize, len(body))
}
})

t.Run("exceeds limit", func(t *testing.T) {
data := bytes.Repeat([]byte("a"), maxResponseSize+1)
_, err := readResponseBody(bytes.NewReader(data))
if !errors.Is(err, errResponseTooLarge) {
t.Errorf("expected errResponseTooLarge, got: %v", err)
}
})
}

// containsSubstring is a helper to avoid importing strings in tests.
func containsSubstring(s, sub string) bool {
return len(s) >= len(sub) && findSubstring(s, sub)
Expand Down
Loading