diff --git a/callback.go b/callback.go index a03c348..b965548 100644 --- a/callback.go +++ b/callback.go @@ -2,6 +2,7 @@ package main import ( "context" + "crypto/subtle" "fmt" "html" "net" @@ -71,9 +72,10 @@ func startCallbackServer( return } - // Validate state (CSRF protection). + // Validate state (CSRF protection) using constant-time comparison. state := q.Get("state") - if state != expectedState { + if len(state) != len(expectedState) || + subtle.ConstantTimeCompare([]byte(state), []byte(expectedState)) != 1 { writeCallbackPage(w, false, "state_mismatch", "State parameter does not match. Possible CSRF attack.") sendResult(callbackResult{ diff --git a/main.go b/main.go index 3585854..718348a 100644 --- a/main.go +++ b/main.go @@ -36,7 +36,6 @@ var ( callbackPort int scope string tokenFile string - tokenStoreMode string tokenStore credstore.Store[credstore.Token] configOnce sync.Once retryClient *retry.Client @@ -170,16 +169,16 @@ func doInitConfig() { var err error retryClient, err = retry.NewBackgroundClient(retry.WithHTTPClient(baseHTTPClient)) if err != nil { - panic(fmt.Sprintf("failed to create retry client: %v", err)) + fmt.Fprintf(os.Stderr, "Error: failed to create retry client: %v\n", err) + os.Exit(1) } const defaultKeyringService = "authgate-oauth-cli" - tokenStoreMode = getConfig(*flagTokenStore, "TOKEN_STORE", "auto") + tokenStoreMode := getConfig(*flagTokenStore, "TOKEN_STORE", "auto") var warnings []string - var err2 error - tokenStore, warnings, err2 = initTokenStore(tokenStoreMode, tokenFile, defaultKeyringService) - if err2 != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err2) + tokenStore, warnings, err = initTokenStore(tokenStoreMode, tokenFile, defaultKeyringService) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } configWarnings = append(configWarnings, warnings...) @@ -298,6 +297,16 @@ func parseOAuthError(statusCode int, body []byte, action string) error { return fmt.Errorf("%s failed with status %d: %s", action, statusCode, string(body)) } +// isRefreshTokenError checks whether the response body indicates an expired +// or invalid refresh token (invalid_grant / invalid_token). +func isRefreshTokenError(body []byte) bool { + var errResp ErrorResponse + if err := json.Unmarshal(body, &errResp); err == nil { + return errResp.Error == "invalid_grant" || errResp.Error == "invalid_token" + } + return false +} + // validateTokenResponse performs basic sanity checks on a token response. func validateTokenResponse(accessToken, tokenType string, expiresIn int) error { if accessToken == "" { @@ -438,12 +447,8 @@ func refreshAccessToken(ctx context.Context, refreshToken string) (*tui.TokenSto if resp.StatusCode != http.StatusOK { // Check for expired/invalid refresh token before general error handling. - var errResp ErrorResponse - if jsonErr := json.Unmarshal(body, &errResp); jsonErr == nil && errResp.Error != "" { - if errResp.Error == "invalid_grant" || errResp.Error == "invalid_token" { - return nil, tui.ErrRefreshTokenExpired - } - return nil, fmt.Errorf("%s: %s", errResp.Error, errResp.ErrorDescription) + if isRefreshTokenError(body) { + return nil, tui.ErrRefreshTokenExpired } return nil, parseOAuthError(resp.StatusCode, body, "refresh") } @@ -522,10 +527,9 @@ func makeAPICallWithAutoRefresh(ctx context.Context, storage *tui.TokenStorage) if err != nil { return fmt.Errorf("API request failed: %w", err) } - defer resp.Body.Close() if resp.StatusCode == http.StatusUnauthorized { - // Drain and close body immediately so the HTTP transport can reuse the connection. + // Drain and close body so the HTTP transport can reuse the connection. _, _ = io.Copy(io.Discard, resp.Body) resp.Body.Close() @@ -554,8 +558,8 @@ func makeAPICallWithAutoRefresh(ctx context.Context, storage *tui.TokenStorage) if err != nil { return fmt.Errorf("retry failed: %w", err) } - defer resp.Body.Close() } + defer resp.Body.Close() body, err := readResponseBody(resp.Body) if err != nil { diff --git a/main_test.go b/main_test.go index 5479c3d..76215a4 100644 --- a/main_test.go +++ b/main_test.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "path/filepath" + "strings" "testing" "time" @@ -179,7 +180,7 @@ func TestBuildAuthURL_ContainsRequiredParams(t *testing.T) { "code_challenge=test-challenge", "code_challenge_method=S256", } { - if !containsSubstring(u, want) { + if !strings.Contains(u, want) { t.Errorf("auth URL missing %q\nURL: %s", want, u) } } @@ -273,7 +274,7 @@ func TestInitTokenStore_Invalid(t *testing.T) { if store != nil { t.Errorf("expected nil store on error, got %T", store) } - if !containsSubstring(err.Error(), "invalid token-store value") { + if !strings.Contains(err.Error(), "invalid token-store value") { t.Errorf("unexpected error message: %v", err) } } @@ -309,17 +310,3 @@ func TestReadResponseBody(t *testing.T) { } }) } - -// containsSubstring is a helper to avoid importing strings in tests. -func containsSubstring(s, sub string) bool { - return len(s) >= len(sub) && findSubstring(s, sub) -} - -func findSubstring(s, sub string) bool { - for i := 0; i <= len(s)-len(sub); i++ { - if s[i:i+len(sub)] == sub { - return true - } - } - return false -}