From 63b8e1252e8a20c1db85bd51a723a40db251cea3 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Thu, 2 Apr 2026 01:29:50 +0800 Subject: [PATCH 1/5] feat(cli): add configurable timeouts, OIDC discovery, and token revocation - Make 7 hardcoded timeouts configurable via CLI flags and env vars with graceful fallback to defaults on invalid input - Integrate OIDC Discovery to auto-resolve OAuth endpoints from .well-known/openid-configuration with hardcoded path fallback - Add server-side token revocation (RFC 7009) to `token delete` with --local-only flag to skip remote revocation - Add getDurationConfig and getInt64Config config resolution helpers - Add ResolvedEndpoints struct and resolveEndpoints function - Add discovery_test.go with success and fallback scenario tests - Add revocation tests covering success, graceful degradation, and local-only mode Co-Authored-By: Claude Opus 4.6 (1M context) --- auth.go | 16 ++-- browser_flow.go | 12 +-- callback.go | 12 +-- callback_test.go | 2 +- config.go | 168 +++++++++++++++++++++++++++++++-- config_test.go | 65 +++++++++++++ device_flow.go | 14 +-- discovery_test.go | 175 +++++++++++++++++++++++++++++++++++ main.go | 1 + main_test.go | 18 +++- token_cmd.go | 108 ++++++++++++++++++++-- token_cmd_test.go | 230 +++++++++++++++++++++++++++++++++++++++++++++- tokens.go | 6 +- userinfo.go | 8 +- 14 files changed, 777 insertions(+), 58 deletions(-) create mode 100644 discovery_test.go diff --git a/auth.go b/auth.go index 3c1017c..646f998 100644 --- a/auth.go +++ b/auth.go @@ -63,7 +63,7 @@ func refreshAccessToken( cfg *AppConfig, refreshToken string, ) (*credstore.Token, error) { - ctx, cancel := context.WithTimeout(ctx, refreshTokenTimeout) + ctx, cancel := context.WithTimeout(ctx, cfg.RefreshTokenTimeout) defer cancel() data := url.Values{} @@ -74,7 +74,7 @@ func refreshAccessToken( data.Set("client_secret", cfg.ClientSecret) } - tokenResp, err := doTokenExchange(ctx, cfg, cfg.ServerURL+"/oauth/token", data, + tokenResp, err := doTokenExchange(ctx, cfg, cfg.Endpoints.TokenURL, data, func(errResp ErrorResponse, _ []byte) error { if errResp.Error == "invalid_grant" || errResp.Error == "invalid_token" { return ErrRefreshTokenExpired @@ -101,10 +101,10 @@ func refreshAccessToken( // verifyToken verifies an access token with the OAuth server. func verifyToken(ctx context.Context, cfg *AppConfig, accessToken string) (string, error) { - ctx, cancel := context.WithTimeout(ctx, tokenVerificationTimeout) + ctx, cancel := context.WithTimeout(ctx, cfg.TokenVerificationTimeout) defer cancel() - resp, err := cfg.RetryClient.Get(ctx, cfg.ServerURL+"/oauth/tokeninfo", + resp, err := cfg.RetryClient.Get(ctx, cfg.Endpoints.TokenInfoURL, retry.WithHeader("Authorization", "Bearer "+accessToken), ) if err != nil { @@ -112,7 +112,7 @@ func verifyToken(ctx context.Context, cfg *AppConfig, accessToken string) (strin } defer resp.Body.Close() - body, err := readResponseBody(resp) + body, err := readResponseBody(resp, cfg.MaxResponseBodySize) if err != nil { return "", err } @@ -131,7 +131,7 @@ func makeAPICallWithAutoRefresh( storage *credstore.Token, ui tui.Manager, ) error { - resp, err := cfg.RetryClient.Get(ctx, cfg.ServerURL+"/oauth/tokeninfo", + resp, err := cfg.RetryClient.Get(ctx, cfg.Endpoints.TokenInfoURL, retry.WithHeader("Authorization", "Bearer "+storage.AccessToken), ) if err != nil { @@ -156,7 +156,7 @@ func makeAPICallWithAutoRefresh( ui.ShowStatus(tui.StatusUpdate{Event: tui.EventTokenRefreshedRetrying}) - resp, err = cfg.RetryClient.Get(ctx, cfg.ServerURL+"/oauth/tokeninfo", + resp, err = cfg.RetryClient.Get(ctx, cfg.Endpoints.TokenInfoURL, retry.WithHeader("Authorization", "Bearer "+storage.AccessToken), ) if err != nil { @@ -165,7 +165,7 @@ func makeAPICallWithAutoRefresh( defer resp.Body.Close() } - body, err := readResponseBody(resp) + body, err := readResponseBody(resp, cfg.MaxResponseBodySize) if err != nil { return err } diff --git a/browser_flow.go b/browser_flow.go index 7707ca8..8e29583 100644 --- a/browser_flow.go +++ b/browser_flow.go @@ -21,7 +21,7 @@ func buildAuthURL(cfg *AppConfig, state string, pkce *PKCEParams) string { params.Set("state", state) params.Set("code_challenge", pkce.Challenge) params.Set("code_challenge_method", pkce.Method) - return cfg.ServerURL + "/oauth/authorize?" + params.Encode() + return cfg.Endpoints.AuthorizeURL + "?" + params.Encode() } // exchangeCode exchanges an authorization code for access + refresh tokens. @@ -30,7 +30,7 @@ func exchangeCode( cfg *AppConfig, code, codeVerifier string, ) (*credstore.Token, error) { - ctx, cancel := context.WithTimeout(ctx, tokenExchangeTimeout) + ctx, cancel := context.WithTimeout(ctx, cfg.TokenExchangeTimeout) defer cancel() data := url.Values{} @@ -44,7 +44,7 @@ func exchangeCode( data.Set("client_secret", cfg.ClientSecret) } - tokenResp, err := doTokenExchange(ctx, cfg, cfg.ServerURL+"/oauth/token", data, nil) + tokenResp, err := doTokenExchange(ctx, cfg, cfg.Endpoints.TokenURL, data, nil) if err != nil { return nil, err } @@ -144,7 +144,7 @@ func performBrowserFlowWithUpdates( return case <-ticker.C: elapsed := time.Since(startTime) - progress := float64(elapsed) / float64(callbackTimeout) + progress := float64(elapsed) / float64(cfg.CallbackTimeout) if progress > 1.0 { progress = 1.0 } @@ -153,7 +153,7 @@ func performBrowserFlowWithUpdates( Progress: progress, Data: map[string]any{ "elapsed": elapsed, - "timeout": callbackTimeout, + "timeout": cfg.CallbackTimeout, }, } select { @@ -167,7 +167,7 @@ func performBrowserFlowWithUpdates( } }() - storage, err := startCallbackServer(ctx, cfg.CallbackPort, state, + storage, err := startCallbackServer(ctx, cfg.CallbackPort, state, cfg.CallbackTimeout, func(callbackCtx context.Context, code string) (*credstore.Token, error) { updates <- tui.FlowUpdate{ Type: tui.StepStart, diff --git a/callback.go b/callback.go index ee65348..42d6c45 100644 --- a/callback.go +++ b/callback.go @@ -51,12 +51,7 @@ func sanitizeTokenExchangeError(_ error) string { return "Token exchange failed. Please try again." } -const ( - // callbackTimeout is how long we wait for the browser to deliver the code. - callbackTimeout = 2 * time.Minute -) - -// ErrCallbackTimeout is returned when no browser callback is received within callbackTimeout. +// ErrCallbackTimeout is returned when no browser callback is received within the callback timeout. // Callers can use errors.Is to distinguish a timeout from other authorization errors // and decide whether to fall back to Device Code Flow. var ErrCallbackTimeout = errors.New("browser authorization timed out") @@ -76,6 +71,7 @@ type callbackResult struct { // // The server shuts itself down after the first request. func startCallbackServer(ctx context.Context, port int, expectedState string, + cbTimeout time.Duration, exchangeFn func(context.Context, string) (*credstore.Token, error), ) (*credstore.Token, error) { resultCh := make(chan callbackResult, 1) @@ -158,7 +154,7 @@ func startCallbackServer(ctx context.Context, port int, expectedState string, _ = srv.Shutdown(shutdownCtx) }() - timer := time.NewTimer(callbackTimeout) + timer := time.NewTimer(cbTimeout) defer timer.Stop() select { @@ -172,7 +168,7 @@ func startCallbackServer(ctx context.Context, port int, expectedState string, return result.Storage, nil case <-timer.C: - return nil, fmt.Errorf("%w after %s", ErrCallbackTimeout, callbackTimeout) + return nil, fmt.Errorf("%w after %s", ErrCallbackTimeout, cbTimeout) case <-ctx.Done(): return nil, ctx.Err() diff --git a/callback_test.go b/callback_test.go index 4cf4195..f7c05e4 100644 --- a/callback_test.go +++ b/callback_test.go @@ -158,7 +158,7 @@ func startCallbackServerAsync( t.Helper() ch := make(chan callbackServerResult, 1) go func() { - storage, err := startCallbackServer(ctx, port, state, exchangeFn) + storage, err := startCallbackServer(ctx, port, state, defaultCallbackTimeout, exchangeFn) ch <- callbackServerResult{storage, err} }() // Give the server a moment to bind. diff --git a/config.go b/config.go index d017ed4..7442302 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/tls" "errors" "fmt" @@ -13,6 +14,7 @@ import ( "time" "github.com/go-authgate/sdk-go/credstore" + "github.com/go-authgate/sdk-go/discovery" retry "github.com/appleboy/go-httpretry" "github.com/google/uuid" @@ -34,17 +36,38 @@ var ( flagTokenFile string flagTokenStore string flagDevice bool + + flagTokenExchangeTimeout string + flagTokenVerificationTimeout string + flagRefreshTokenTimeout string + flagDeviceCodeRequestTimeout string + flagCallbackTimeout string + flagUserInfoTimeout string + flagMaxResponseBodySize string ) const ( - tokenExchangeTimeout = 10 * time.Second - tokenVerificationTimeout = 10 * time.Second - refreshTokenTimeout = 10 * time.Second - deviceCodeRequestTimeout = 10 * time.Second - maxResponseBodySize = 1 * 1024 * 1024 // 1 MB — guards against oversized server responses - defaultKeyringService = "authgate-cli" + defaultTokenExchangeTimeout = 10 * time.Second + defaultTokenVerificationTimeout = 10 * time.Second + defaultRefreshTokenTimeout = 10 * time.Second + defaultDeviceCodeRequestTimeout = 10 * time.Second + defaultCallbackTimeout = 2 * time.Minute + defaultUserInfoTimeout = 10 * time.Second + defaultMaxResponseBodySize = 1 * 1024 * 1024 // 1 MB — guards against oversized server responses + defaultKeyringService = "authgate-cli" ) +// ResolvedEndpoints holds the absolute URLs for all OAuth endpoints. +// Populated from OIDC Discovery or from hardcoded fallback paths. +type ResolvedEndpoints struct { + AuthorizeURL string + TokenURL string + DeviceAuthorizationURL string + TokenInfoURL string + UserinfoURL string + RevocationURL string +} + // AppConfig holds all resolved configuration for the CLI application. type AppConfig struct { ServerURL string @@ -57,6 +80,20 @@ type AppConfig struct { TokenStoreMode string // "auto", "file", or "keyring" RetryClient *retry.Client Store credstore.Store[credstore.Token] + + // Endpoints holds the resolved OAuth endpoint URLs. + // Populated by resolveEndpoints after loadConfig. + Endpoints ResolvedEndpoints + + // Timeout configuration (resolved from flag → env → default). + // Only populated by loadConfig; zero in loadStoreConfig paths. + TokenExchangeTimeout time.Duration + TokenVerificationTimeout time.Duration + RefreshTokenTimeout time.Duration + DeviceCodeRequestTimeout time.Duration + CallbackTimeout time.Duration + UserInfoTimeout time.Duration + MaxResponseBodySize int64 } // IsPublicClient returns true when no client secret is configured — @@ -85,6 +122,20 @@ func registerFlags(cmd *cobra.Command) { StringVar(&flagTokenStore, "token-store", "", "Token storage backend: auto, file, keyring (default: auto or TOKEN_STORE env)") cmd.PersistentFlags(). BoolVar(&flagDevice, "device", false, "Force Device Code Flow (skip browser detection)") + cmd.PersistentFlags(). + StringVar(&flagTokenExchangeTimeout, "token-exchange-timeout", "", "Timeout for token exchange requests (e.g. 10s, 1m)") + cmd.PersistentFlags(). + StringVar(&flagTokenVerificationTimeout, "token-verification-timeout", "", "Timeout for token verification requests (e.g. 10s, 1m)") + cmd.PersistentFlags(). + StringVar(&flagRefreshTokenTimeout, "refresh-token-timeout", "", "Timeout for token refresh requests (e.g. 10s, 1m)") + cmd.PersistentFlags(). + StringVar(&flagDeviceCodeRequestTimeout, "device-code-request-timeout", "", "Timeout for device code requests (e.g. 10s, 1m)") + cmd.PersistentFlags(). + StringVar(&flagCallbackTimeout, "callback-timeout", "", "Timeout waiting for browser callback (e.g. 2m, 5m)") + cmd.PersistentFlags(). + StringVar(&flagUserInfoTimeout, "userinfo-timeout", "", "Timeout for UserInfo requests (e.g. 10s, 1m)") + cmd.PersistentFlags(). + StringVar(&flagMaxResponseBodySize, "max-response-body-size", "", "Maximum response body size in bytes (e.g. 1048576)") } // loadStoreConfig initialises only the token store and client ID — the minimum @@ -180,6 +231,25 @@ func loadConfig() *AppConfig { panic(fmt.Sprintf("failed to create retry client: %v", err)) } + // Resolve timeout configuration. + cfg.TokenExchangeTimeout = getDurationConfig( + flagTokenExchangeTimeout, "TOKEN_EXCHANGE_TIMEOUT", defaultTokenExchangeTimeout) + cfg.TokenVerificationTimeout = getDurationConfig( + flagTokenVerificationTimeout, "TOKEN_VERIFICATION_TIMEOUT", defaultTokenVerificationTimeout) + cfg.RefreshTokenTimeout = getDurationConfig( + flagRefreshTokenTimeout, "REFRESH_TOKEN_TIMEOUT", defaultRefreshTokenTimeout) + cfg.DeviceCodeRequestTimeout = getDurationConfig( + flagDeviceCodeRequestTimeout, + "DEVICE_CODE_REQUEST_TIMEOUT", + defaultDeviceCodeRequestTimeout, + ) + cfg.CallbackTimeout = getDurationConfig( + flagCallbackTimeout, "CALLBACK_TIMEOUT", defaultCallbackTimeout) + cfg.UserInfoTimeout = getDurationConfig( + flagUserInfoTimeout, "USERINFO_TIMEOUT", defaultUserInfoTimeout) + cfg.MaxResponseBodySize = getInt64Config( + flagMaxResponseBodySize, "MAX_RESPONSE_BODY_SIZE", defaultMaxResponseBodySize) + if cfg.TokenStoreMode == "auto" { if ss, ok := cfg.Store.(*credstore.SecureStore[credstore.Token]); ok && !ss.UseKeyring() { fmt.Fprintln( @@ -239,6 +309,92 @@ func validateServerURL(rawURL string) error { return nil } +// defaultEndpoints returns hardcoded endpoint paths appended to serverURL. +// Used as fallback when OIDC Discovery is unavailable. +func defaultEndpoints(serverURL string) ResolvedEndpoints { + return ResolvedEndpoints{ + AuthorizeURL: serverURL + "/oauth/authorize", + TokenURL: serverURL + "/oauth/token", + DeviceAuthorizationURL: serverURL + "/oauth/device/code", + TokenInfoURL: serverURL + "/oauth/tokeninfo", + UserinfoURL: serverURL + "/oauth/userinfo", + RevocationURL: serverURL + "/oauth/revoke", + } +} + +// resolveEndpoints attempts OIDC Discovery and falls back to hardcoded paths. +func resolveEndpoints(ctx context.Context, cfg *AppConfig) { + disco, err := discovery.NewClient( + cfg.ServerURL, + discovery.WithHTTPClient(cfg.RetryClient), + ) + if err != nil { + fmt.Fprintf(os.Stderr, + "WARNING: OIDC Discovery init failed: %v (using default endpoints)\n", err) + cfg.Endpoints = defaultEndpoints(cfg.ServerURL) + return + } + + fetchCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + meta, err := disco.Fetch(fetchCtx) + if err != nil { + fmt.Fprintf(os.Stderr, + "WARNING: OIDC Discovery fetch failed: %v (using default endpoints)\n", err) + cfg.Endpoints = defaultEndpoints(cfg.ServerURL) + return + } + + ep := meta.Endpoints() + cfg.Endpoints = ResolvedEndpoints{ + AuthorizeURL: ep.AuthorizeURL, + TokenURL: ep.TokenURL, + DeviceAuthorizationURL: ep.DeviceAuthorizationURL, + TokenInfoURL: ep.TokenInfoURL, + UserinfoURL: ep.UserinfoURL, + RevocationURL: ep.RevocationURL, + } +} + +// getDurationConfig resolves a time.Duration from flag → env → default. +// The value is parsed with time.ParseDuration (e.g. "10s", "2m", "1m30s"). +// On parse error or non-positive value, it falls back to the default and prints a warning. +func getDurationConfig(flagValue, envKey string, defaultValue time.Duration) time.Duration { + raw := getConfig(flagValue, envKey, "") + if raw == "" { + return defaultValue + } + d, err := time.ParseDuration(raw) + if err != nil { + fmt.Fprintf(os.Stderr, "WARNING: invalid duration %q for %s, using default %s\n", + raw, envKey, defaultValue) + return defaultValue + } + if d <= 0 { + fmt.Fprintf(os.Stderr, "WARNING: %s must be positive, got %s, using default %s\n", + envKey, d, defaultValue) + return defaultValue + } + return d +} + +// getInt64Config resolves an int64 from flag → env → default. +// On parse error or non-positive value, it falls back to the default and prints a warning. +func getInt64Config(flagValue, envKey string, defaultValue int64) int64 { + raw := getConfig(flagValue, envKey, "") + if raw == "" { + return defaultValue + } + v, err := strconv.ParseInt(raw, 10, 64) + if err != nil || v <= 0 { + fmt.Fprintf(os.Stderr, "WARNING: invalid value %q for %s, using default %d\n", + raw, envKey, defaultValue) + return defaultValue + } + return v +} + // getVersion returns the build version, preferring the ldflags-injected value // and falling back to debug.ReadBuildInfo(). func getVersion() string { diff --git a/config_test.go b/config_test.go index 87bf7b2..abe8b77 100644 --- a/config_test.go +++ b/config_test.go @@ -3,6 +3,7 @@ package main import ( "strings" "testing" + "time" "github.com/go-authgate/sdk-go/credstore" ) @@ -79,3 +80,67 @@ func TestNewTokenStore(t *testing.T) { }) } } + +func TestGetDurationConfig(t *testing.T) { + tests := []struct { + name string + flag string + envVal string + def time.Duration + want time.Duration + }{ + {"default when empty", "", "", 10 * time.Second, 10 * time.Second}, + {"flag value", "30s", "", 10 * time.Second, 30 * time.Second}, + {"env value", "", "1m", 10 * time.Second, 1 * time.Minute}, + {"flag takes precedence over env", "5s", "1m", 10 * time.Second, 5 * time.Second}, + {"invalid falls back to default", "notaduration", "", 10 * time.Second, 10 * time.Second}, + {"negative falls back to default", "-5s", "", 10 * time.Second, 10 * time.Second}, + {"zero falls back to default", "0s", "", 10 * time.Second, 10 * time.Second}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.envVal != "" { + t.Setenv("TEST_DUR_CFG", tc.envVal) + } else { + t.Setenv("TEST_DUR_CFG", "") + } + got := getDurationConfig(tc.flag, "TEST_DUR_CFG", tc.def) + if got != tc.want { + t.Errorf("getDurationConfig() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestGetInt64Config(t *testing.T) { + tests := []struct { + name string + flag string + envVal string + def int64 + want int64 + }{ + {"default when empty", "", "", 1024, 1024}, + {"flag value", "2048", "", 1024, 2048}, + {"env value", "", "4096", 1024, 4096}, + {"flag takes precedence", "512", "4096", 1024, 512}, + {"invalid falls back to default", "abc", "", 1024, 1024}, + {"negative falls back to default", "-100", "", 1024, 1024}, + {"zero falls back to default", "0", "", 1024, 1024}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.envVal != "" { + t.Setenv("TEST_INT_CFG", tc.envVal) + } else { + t.Setenv("TEST_INT_CFG", "") + } + got := getInt64Config(tc.flag, "TEST_INT_CFG", tc.def) + if got != tc.want { + t.Errorf("getInt64Config() = %v, want %v", got, tc.want) + } + }) + } +} diff --git a/device_flow.go b/device_flow.go index 669a20d..b861091 100644 --- a/device_flow.go +++ b/device_flow.go @@ -33,14 +33,14 @@ const ( // requestDeviceCode requests a device code from the OAuth server. func requestDeviceCode(ctx context.Context, cfg *AppConfig) (*oauth2.DeviceAuthResponse, error) { - reqCtx, cancel := context.WithTimeout(ctx, deviceCodeRequestTimeout) + reqCtx, cancel := context.WithTimeout(ctx, cfg.DeviceCodeRequestTimeout) defer cancel() data := url.Values{} data.Set("client_id", cfg.ClientID) data.Set("scope", cfg.Scope) - resp, err := cfg.RetryClient.Post(reqCtx, cfg.ServerURL+"/oauth/device/code", + resp, err := cfg.RetryClient.Post(reqCtx, cfg.Endpoints.DeviceAuthorizationURL, retry.WithBody("application/x-www-form-urlencoded", strings.NewReader(data.Encode())), ) if err != nil { @@ -48,7 +48,7 @@ func requestDeviceCode(ctx context.Context, cfg *AppConfig) (*oauth2.DeviceAuthR } defer resp.Body.Close() - body, err := readResponseBody(resp) + body, err := readResponseBody(resp, cfg.MaxResponseBodySize) if err != nil { return nil, err } @@ -146,7 +146,7 @@ func exchangeDeviceCode( cfg *AppConfig, tokenURL, cID, deviceCode string, ) (*oauth2.Token, error) { - reqCtx, cancel := context.WithTimeout(ctx, tokenExchangeTimeout) + reqCtx, cancel := context.WithTimeout(ctx, cfg.TokenExchangeTimeout) defer cancel() data := url.Values{} @@ -162,7 +162,7 @@ func exchangeDeviceCode( } defer resp.Body.Close() - body, err := readResponseBody(resp) + body, err := readResponseBody(resp, cfg.MaxResponseBodySize) if err != nil { return nil, err } @@ -205,8 +205,8 @@ func performDeviceFlowWithUpdates( config := &oauth2.Config{ ClientID: cfg.ClientID, Endpoint: oauth2.Endpoint{ - DeviceAuthURL: cfg.ServerURL + "/oauth/device/code", - TokenURL: cfg.ServerURL + "/oauth/token", + DeviceAuthURL: cfg.Endpoints.DeviceAuthorizationURL, + TokenURL: cfg.Endpoints.TokenURL, }, Scopes: strings.Fields(cfg.Scope), } diff --git a/discovery_test.go b/discovery_test.go new file mode 100644 index 0000000..8f79963 --- /dev/null +++ b/discovery_test.go @@ -0,0 +1,175 @@ +package main + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + retry "github.com/appleboy/go-httpretry" +) + +func TestDefaultEndpoints(t *testing.T) { + ep := defaultEndpoints("http://example.com") + + want := map[string]string{ + "AuthorizeURL": "http://example.com/oauth/authorize", + "TokenURL": "http://example.com/oauth/token", + "DeviceAuthorizationURL": "http://example.com/oauth/device/code", + "TokenInfoURL": "http://example.com/oauth/tokeninfo", + "UserinfoURL": "http://example.com/oauth/userinfo", + "RevocationURL": "http://example.com/oauth/revoke", + } + + got := map[string]string{ + "AuthorizeURL": ep.AuthorizeURL, + "TokenURL": ep.TokenURL, + "DeviceAuthorizationURL": ep.DeviceAuthorizationURL, + "TokenInfoURL": ep.TokenInfoURL, + "UserinfoURL": ep.UserinfoURL, + "RevocationURL": ep.RevocationURL, + } + + for field, wantVal := range want { + if got[field] != wantVal { + t.Errorf("%s = %q, want %q", field, got[field], wantVal) + } + } +} + +func TestResolveEndpoints_Success(t *testing.T) { + // Use a pointer so the handler closure can reference the final URL. + var srvURL string + srv := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/.well-known/openid-configuration" { + http.NotFound(w, r) + return + } + meta := map[string]any{ + "issuer": srvURL, + "authorization_endpoint": srvURL + "/auth", + "token_endpoint": srvURL + "/token", + "device_authorization_endpoint": srvURL + "/device", + "userinfo_endpoint": srvURL + "/userinfo", + "revocation_endpoint": srvURL + "/revoke", + "response_types_supported": []string{"code"}, + "subject_types_supported": []string{"public"}, + "id_token_signing_alg_values_supported": []string{"RS256"}, + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(meta); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + }), + ) + defer srv.Close() + srvURL = srv.URL + + rc, err := retry.NewClient() + if err != nil { + t.Fatal(err) + } + cfg := &AppConfig{ + ServerURL: srv.URL, + RetryClient: rc, + } + + resolveEndpoints(context.Background(), cfg) + + if cfg.Endpoints.AuthorizeURL != srv.URL+"/auth" { + t.Errorf("AuthorizeURL = %q, want %q", cfg.Endpoints.AuthorizeURL, srv.URL+"/auth") + } + if cfg.Endpoints.TokenURL != srv.URL+"/token" { + t.Errorf("TokenURL = %q, want %q", cfg.Endpoints.TokenURL, srv.URL+"/token") + } + if cfg.Endpoints.DeviceAuthorizationURL != srv.URL+"/device" { + t.Errorf( + "DeviceAuthorizationURL = %q, want %q", + cfg.Endpoints.DeviceAuthorizationURL, + srv.URL+"/device", + ) + } + if cfg.Endpoints.UserinfoURL != srv.URL+"/userinfo" { + t.Errorf("UserinfoURL = %q, want %q", cfg.Endpoints.UserinfoURL, srv.URL+"/userinfo") + } + if cfg.Endpoints.RevocationURL != srv.URL+"/revoke" { + t.Errorf("RevocationURL = %q, want %q", cfg.Endpoints.RevocationURL, srv.URL+"/revoke") + } +} + +func TestResolveEndpoints_FallbackOn404(t *testing.T) { + srv := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "not found", http.StatusNotFound) + }), + ) + defer srv.Close() + + rc, err := retry.NewClient() + if err != nil { + t.Fatal(err) + } + cfg := &AppConfig{ + ServerURL: srv.URL, + RetryClient: rc, + } + + resolveEndpoints(context.Background(), cfg) + + // Should fall back to defaults + want := defaultEndpoints(srv.URL) + if cfg.Endpoints != want { + t.Errorf("expected default endpoints on 404 fallback\ngot: %+v\nwant: %+v", + cfg.Endpoints, want) + } +} + +func TestResolveEndpoints_FallbackOnTimeout(t *testing.T) { + srv := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + time.Sleep(5 * time.Second) + }), + ) + defer srv.Close() + + rc, err := retry.NewClient() + if err != nil { + t.Fatal(err) + } + cfg := &AppConfig{ + ServerURL: srv.URL, + RetryClient: rc, + } + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + resolveEndpoints(ctx, cfg) + + want := defaultEndpoints(srv.URL) + if cfg.Endpoints != want { + t.Errorf("expected default endpoints on timeout fallback\ngot: %+v\nwant: %+v", + cfg.Endpoints, want) + } +} + +func TestResolveEndpoints_FallbackOnNetworkError(t *testing.T) { + rc, err := retry.NewClient() + if err != nil { + t.Fatal(err) + } + cfg := &AppConfig{ + ServerURL: "http://127.0.0.1:1", // port 1 is unreachable + RetryClient: rc, + } + + resolveEndpoints(context.Background(), cfg) + + want := defaultEndpoints(cfg.ServerURL) + if cfg.Endpoints != want { + t.Errorf("expected default endpoints on network error\ngot: %+v\nwant: %+v", + cfg.Endpoints, want) + } +} diff --git a/main.go b/main.go index 29dc220..83110f4 100644 --- a/main.go +++ b/main.go @@ -45,6 +45,7 @@ func buildRootCmd() *cobra.Command { uiManager := tui.SelectManager() ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() + resolveEndpoints(ctx, cfg) if code := run(ctx, uiManager, cfg); code != 0 { return exitCodeError(code) } diff --git a/main_test.go b/main_test.go index 123305b..80e9e8e 100644 --- a/main_test.go +++ b/main_test.go @@ -24,12 +24,23 @@ func testConfig(t *testing.T) *AppConfig { if err != nil { t.Fatalf("failed to create retry client: %v", err) } + serverURL := "http://localhost:8080" return &AppConfig{ - ServerURL: "http://localhost:8080", + ServerURL: serverURL, ClientID: "test-client", Scope: "read write", RetryClient: rc, - Store: credstore.NewTokenFileStore(filepath.Join(t.TempDir(), "tokens.json")), + Store: credstore.NewTokenFileStore( + filepath.Join(t.TempDir(), "tokens.json"), + ), + Endpoints: defaultEndpoints(serverURL), + TokenExchangeTimeout: defaultTokenExchangeTimeout, + TokenVerificationTimeout: defaultTokenVerificationTimeout, + RefreshTokenTimeout: defaultRefreshTokenTimeout, + DeviceCodeRequestTimeout: defaultDeviceCodeRequestTimeout, + CallbackTimeout: defaultCallbackTimeout, + UserInfoTimeout: defaultUserInfoTimeout, + MaxResponseBodySize: defaultMaxResponseBodySize, } } @@ -223,6 +234,7 @@ func TestBuildAuthURL_ContainsRequiredParams(t *testing.T) { ClientID: "my-client-id", RedirectURI: "http://localhost:8888/callback", Scope: "read write", + Endpoints: defaultEndpoints("http://localhost:8080"), } pkce := &PKCEParams{ @@ -300,6 +312,7 @@ func TestRefreshAccessToken_RotationMode(t *testing.T) { cfg := testConfig(t) cfg.ServerURL = srv.URL + cfg.Endpoints = defaultEndpoints(srv.URL) cfg.ClientID = "test-client-rotation" storage, err := refreshAccessToken(context.Background(), cfg, tt.oldRefreshToken) @@ -347,6 +360,7 @@ func TestRequestDeviceCode_WithRetry(t *testing.T) { cfg := testConfig(t) cfg.ServerURL = testServer.URL + cfg.Endpoints = defaultEndpoints(testServer.URL) resp, err := requestDeviceCode(context.Background(), cfg) if err != nil { diff --git a/token_cmd.go b/token_cmd.go index 450f61a..ba3102b 100644 --- a/token_cmd.go +++ b/token_cmd.go @@ -1,12 +1,17 @@ package main import ( + "context" "encoding/json" "errors" "fmt" "io" + "net/http" + "net/url" + "strings" "time" + retry "github.com/appleboy/go-httpretry" "github.com/go-authgate/sdk-go/credstore" "github.com/spf13/cobra" ) @@ -57,15 +62,30 @@ func buildTokenGetCmd() *cobra.Command { } func buildTokenDeleteCmd() *cobra.Command { + var localOnly bool cmd := &cobra.Command{ Use: "delete", Short: "Delete the stored token", - Args: cobra.NoArgs, + Long: `Delete the stored token. + +By default, the token is first revoked on the OAuth server before being +deleted locally. If the server is unreachable, the local token is still +deleted (graceful degradation). + +Use --local-only to skip server revocation and only delete the local token.`, + Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - cfg := loadStoreConfig() + var cfg *AppConfig + if localOnly { + cfg = loadStoreConfig() + } else { + cfg = loadConfig() + resolveEndpoints(cmd.Context(), cfg) + } if code := runTokenDelete( - cfg.Store, - cfg.ClientID, + cmd.Context(), + cfg, + localOnly, cmd.OutOrStdout(), cmd.ErrOrStderr(), ); code != 0 { @@ -74,6 +94,8 @@ func buildTokenDeleteCmd() *cobra.Command { return nil }, } + cmd.Flags(). + BoolVar(&localOnly, "local-only", false, "Skip server-side token revocation; only delete the local token") return cmd } @@ -98,26 +120,94 @@ func loadTokenOrFail( return tok, 0 } +const revokeTokenTimeout = 10 * time.Second + // runTokenDelete is the testable core of `token delete`. func runTokenDelete( - store credstore.Store[credstore.Token], - id string, + ctx context.Context, + cfg *AppConfig, + localOnly bool, stdout io.Writer, stderr io.Writer, ) int { // Check existence first — Delete is idempotent and silently succeeds // even when the key is absent. - if _, code := loadTokenOrFail(store, id, stderr); code != 0 { + tok, code := loadTokenOrFail(cfg.Store, cfg.ClientID, stderr) + if code != 0 { return code } - if err := store.Delete(id); err != nil { + + if !localOnly { + if err := revokeTokenOnServer(ctx, cfg, tok, stderr); err != nil { + fmt.Fprintf(stderr, "Warning: server-side revocation failed: %v\n", err) + fmt.Fprintln(stderr, "Proceeding with local token deletion.") + } else { + fmt.Fprintln(stdout, "Token revoked on server.") + } + } + + if err := cfg.Store.Delete(cfg.ClientID); err != nil { fmt.Fprintf(stderr, "Error: failed to delete token: %v\n", err) return 1 } - fmt.Fprintf(stdout, "Token for client-id %q deleted.\n", id) + fmt.Fprintf(stdout, "Token for client-id %q deleted.\n", cfg.ClientID) return 0 } +// revokeTokenOnServer attempts to revoke tokens on the OAuth server (RFC 7009). +// It revokes the refresh token first (long-lived), then the access token. +func revokeTokenOnServer( + ctx context.Context, + cfg *AppConfig, + tok credstore.Token, + stderr io.Writer, +) error { + revokeURL := cfg.Endpoints.RevocationURL + + // Revoke the refresh token first (more important — it's long-lived). + if tok.RefreshToken != "" { + if err := doRevoke(ctx, cfg.RetryClient, revokeURL, tok.RefreshToken); err != nil { + fmt.Fprintf(stderr, "Warning: failed to revoke refresh token: %v\n", err) + } + } + + if tok.AccessToken != "" { + if err := doRevoke(ctx, cfg.RetryClient, revokeURL, tok.AccessToken); err != nil { + return fmt.Errorf("access token revocation: %w", err) + } + } + + return nil +} + +// doRevoke posts a single token to the revocation endpoint (RFC 7009). +func doRevoke( + ctx context.Context, + client *retry.Client, + revokeURL string, + token string, +) error { + ctx, cancel := context.WithTimeout(ctx, revokeTokenTimeout) + defer cancel() + + data := url.Values{"token": {token}} + resp, err := client.Post(ctx, revokeURL, + retry.WithBody( + "application/x-www-form-urlencoded", + strings.NewReader(data.Encode()), + ), + ) + if err != nil { + return fmt.Errorf("revoke request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("revoke returned status %d", resp.StatusCode) + } + return nil +} + // runTokenGet is the testable core of `token get`. func runTokenGet( store credstore.Store[credstore.Token], diff --git a/token_cmd_test.go b/token_cmd_test.go index 17b1393..f9de6b3 100644 --- a/token_cmd_test.go +++ b/token_cmd_test.go @@ -2,13 +2,18 @@ package main import ( "bytes" + "context" "encoding/json" "errors" + "net/http" + "net/http/httptest" "path/filepath" "strings" + "sync" "testing" "time" + retry "github.com/appleboy/go-httpretry" "github.com/go-authgate/sdk-go/credstore" ) @@ -72,10 +77,15 @@ func TestRunTokenDelete(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - store := credstore.NewTokenFileStore(filepath.Join(t.TempDir(), "tokens.json")) + store := credstore.NewTokenFileStore( + filepath.Join(t.TempDir(), "tokens.json"), + ) tc.setup(store) + cfg := &AppConfig{ClientID: "test-id", Store: store} var stdout, stderr bytes.Buffer - code := runTokenDelete(store, "test-id", &stdout, &stderr) + code := runTokenDelete( + context.Background(), cfg, true, &stdout, &stderr, + ) if code != tc.wantCode { t.Errorf("exit code: got %d, want %d", code, tc.wantCode) } @@ -92,6 +102,218 @@ func TestRunTokenDelete(t *testing.T) { } } +func TestRunTokenDelete_ServerRevocation(t *testing.T) { + t.Run("successful revocation and local delete", func(t *testing.T) { + var revokedTokens []string + var mu sync.Mutex + srv := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "bad form", http.StatusBadRequest) + return + } + mu.Lock() + revokedTokens = append(revokedTokens, r.FormValue("token")) + mu.Unlock() + w.WriteHeader(http.StatusOK) + }), + ) + defer srv.Close() + + rc, err := retry.NewClient() + if err != nil { + t.Fatal(err) + } + store := credstore.NewTokenFileStore( + filepath.Join(t.TempDir(), "tokens.json"), + ) + if err := store.Save("test-id", credstore.Token{ + AccessToken: "access-123", + RefreshToken: "refresh-456", + ExpiresAt: time.Now().Add(time.Hour), + ClientID: "test-id", + }); err != nil { + t.Fatal(err) + } + + cfg := &AppConfig{ + ClientID: "test-id", + ServerURL: srv.URL, + Endpoints: ResolvedEndpoints{RevocationURL: srv.URL + "/oauth/revoke"}, + RetryClient: rc, + Store: store, + } + + var stdout, stderr bytes.Buffer + code := runTokenDelete( + context.Background(), cfg, false, &stdout, &stderr, + ) + if code != 0 { + t.Fatalf("exit code: got %d, want 0; stderr: %s", code, stderr.String()) + } + if !strings.Contains(stdout.String(), "revoked on server") { + t.Errorf("expected 'revoked on server' in stdout, got: %q", stdout.String()) + } + if !strings.Contains(stdout.String(), "deleted") { + t.Errorf("expected 'deleted' in stdout, got: %q", stdout.String()) + } + + mu.Lock() + defer mu.Unlock() + if len(revokedTokens) != 2 { + t.Fatalf("expected 2 revoke calls, got %d", len(revokedTokens)) + } + if revokedTokens[0] != "refresh-456" { + t.Errorf("expected refresh token revoked first, got %q", revokedTokens[0]) + } + if revokedTokens[1] != "access-123" { + t.Errorf("expected access token revoked second, got %q", revokedTokens[1]) + } + }) + + t.Run("server error graceful degradation", func(t *testing.T) { + srv := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }), + ) + defer srv.Close() + + rc, err := retry.NewClient() + if err != nil { + t.Fatal(err) + } + store := credstore.NewTokenFileStore( + filepath.Join(t.TempDir(), "tokens.json"), + ) + if err := store.Save("test-id", credstore.Token{ + AccessToken: "access-123", + ExpiresAt: time.Now().Add(time.Hour), + ClientID: "test-id", + }); err != nil { + t.Fatal(err) + } + + cfg := &AppConfig{ + ClientID: "test-id", + ServerURL: srv.URL, + Endpoints: ResolvedEndpoints{RevocationURL: srv.URL + "/oauth/revoke"}, + RetryClient: rc, + Store: store, + } + + var stdout, stderr bytes.Buffer + code := runTokenDelete( + context.Background(), cfg, false, &stdout, &stderr, + ) + if code != 0 { + t.Fatalf("exit code: got %d, want 0", code) + } + if !strings.Contains(stderr.String(), "Warning") { + t.Errorf("expected warning in stderr, got: %q", stderr.String()) + } + if !strings.Contains(stdout.String(), "deleted") { + t.Errorf("token should still be deleted locally, got: %q", stdout.String()) + } + }) + + t.Run("local-only skips server call", func(t *testing.T) { + serverCalled := false + srv := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + serverCalled = true + w.WriteHeader(http.StatusOK) + }), + ) + defer srv.Close() + + store := credstore.NewTokenFileStore( + filepath.Join(t.TempDir(), "tokens.json"), + ) + if err := store.Save("test-id", credstore.Token{ + AccessToken: "access-123", + ExpiresAt: time.Now().Add(time.Hour), + ClientID: "test-id", + }); err != nil { + t.Fatal(err) + } + + cfg := &AppConfig{ + ClientID: "test-id", + Store: store, + } + + var stdout, stderr bytes.Buffer + code := runTokenDelete( + context.Background(), cfg, true, &stdout, &stderr, + ) + if code != 0 { + t.Fatalf("exit code: got %d, want 0", code) + } + if serverCalled { + t.Error("server should not have been called with --local-only") + } + }) + + t.Run("only access token no refresh token", func(t *testing.T) { + var revokedTokens []string + var mu sync.Mutex + srv := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "bad form", http.StatusBadRequest) + return + } + mu.Lock() + revokedTokens = append(revokedTokens, r.FormValue("token")) + mu.Unlock() + w.WriteHeader(http.StatusOK) + }), + ) + defer srv.Close() + + rc, err := retry.NewClient() + if err != nil { + t.Fatal(err) + } + store := credstore.NewTokenFileStore( + filepath.Join(t.TempDir(), "tokens.json"), + ) + if err := store.Save("test-id", credstore.Token{ + AccessToken: "access-only", + ExpiresAt: time.Now().Add(time.Hour), + ClientID: "test-id", + }); err != nil { + t.Fatal(err) + } + + cfg := &AppConfig{ + ClientID: "test-id", + ServerURL: srv.URL, + Endpoints: ResolvedEndpoints{RevocationURL: srv.URL + "/oauth/revoke"}, + RetryClient: rc, + Store: store, + } + + var stdout, stderr bytes.Buffer + code := runTokenDelete( + context.Background(), cfg, false, &stdout, &stderr, + ) + if code != 0 { + t.Fatalf("exit code: got %d, want 0", code) + } + + mu.Lock() + defer mu.Unlock() + if len(revokedTokens) != 1 { + t.Fatalf("expected 1 revoke call (access only), got %d", len(revokedTokens)) + } + if revokedTokens[0] != "access-only" { + t.Errorf("expected access token, got %q", revokedTokens[0]) + } + }) +} + func TestRunTokenGet(t *testing.T) { tests := []struct { name string @@ -161,7 +383,9 @@ func TestRunTokenGet(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - store := credstore.NewTokenFileStore(filepath.Join(t.TempDir(), "tokens.json")) + store := credstore.NewTokenFileStore( + filepath.Join(t.TempDir(), "tokens.json"), + ) tc.setup(store) var stdout, stderr bytes.Buffer code := runTokenGet(store, "test-id", tc.jsonOut, &stdout, &stderr) diff --git a/tokens.go b/tokens.go index 2686cf0..47ce761 100644 --- a/tokens.go +++ b/tokens.go @@ -42,8 +42,8 @@ func parseOAuthError(body []byte) (ErrorResponse, bool) { } // readResponseBody reads the response body with a size limit to guard against oversized responses. -func readResponseBody(resp *http.Response) ([]byte, error) { - body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) +func readResponseBody(resp *http.Response, maxSize int64) ([]byte, error) { + body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize)) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } @@ -88,7 +88,7 @@ func doTokenExchange( } defer resp.Body.Close() - body, err := readResponseBody(resp) + body, err := readResponseBody(resp, cfg.MaxResponseBodySize) if err != nil { return nil, err } diff --git a/userinfo.go b/userinfo.go index 7adce98..0413b94 100644 --- a/userinfo.go +++ b/userinfo.go @@ -34,15 +34,13 @@ type UserInfo struct { EmailVerified *bool `json:"email_verified,omitempty"` } -const userInfoTimeout = 10 * time.Second - // fetchUserInfo calls GET /oauth/userinfo with a Bearer token and returns the // parsed claims per OIDC Core §5.3. func fetchUserInfo(ctx context.Context, cfg *AppConfig, accessToken string) (*UserInfo, error) { - ctx, cancel := context.WithTimeout(ctx, userInfoTimeout) + ctx, cancel := context.WithTimeout(ctx, cfg.UserInfoTimeout) defer cancel() - resp, err := cfg.RetryClient.Get(ctx, cfg.ServerURL+"/oauth/userinfo", + resp, err := cfg.RetryClient.Get(ctx, cfg.Endpoints.UserinfoURL, retry.WithHeader("Authorization", "Bearer "+accessToken), retry.WithHeader("Accept", "application/json"), ) @@ -51,7 +49,7 @@ func fetchUserInfo(ctx context.Context, cfg *AppConfig, accessToken string) (*Us } defer resp.Body.Close() - body, err := readResponseBody(resp) + body, err := readResponseBody(resp, cfg.MaxResponseBodySize) if err != nil { return nil, err } From 0403731b6bfbd042f639f10bf842124d33a8f62e Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Thu, 2 Apr 2026 11:56:34 +0800 Subject: [PATCH 2/5] refactor(cli): use SDK types, parallelize revocation, and harden timeouts - Replace ResolvedEndpoints with oauth.Endpoints from SDK - Simplify resolveEndpoints to assign meta.Endpoints directly - Parallelize refresh and access token revocation with WaitGroup.Go - Drain response body in doRevoke for proper HTTP connection reuse - Accept 2xx status range in doRevoke instead of only 200 - Cap user-supplied timeout durations at 10 minutes - Add configurable --discovery-timeout and --revocation-timeout flags Co-Authored-By: Claude Opus 4.6 (1M context) --- config.go | 53 +++++++++++++++++++++------------------ config_test.go | 1 + discovery_test.go | 20 +++++++++------ main_test.go | 2 ++ token_cmd.go | 63 +++++++++++++++++++++++++++++++++++++---------- token_cmd_test.go | 44 +++++++++++++++++++-------------- 6 files changed, 119 insertions(+), 64 deletions(-) diff --git a/config.go b/config.go index 7442302..7d8a61e 100644 --- a/config.go +++ b/config.go @@ -15,6 +15,7 @@ import ( "github.com/go-authgate/sdk-go/credstore" "github.com/go-authgate/sdk-go/discovery" + "github.com/go-authgate/sdk-go/oauth" retry "github.com/appleboy/go-httpretry" "github.com/google/uuid" @@ -43,6 +44,8 @@ var ( flagDeviceCodeRequestTimeout string flagCallbackTimeout string flagUserInfoTimeout string + flagDiscoveryTimeout string + flagRevocationTimeout string flagMaxResponseBodySize string ) @@ -53,20 +56,15 @@ const ( defaultDeviceCodeRequestTimeout = 10 * time.Second defaultCallbackTimeout = 2 * time.Minute defaultUserInfoTimeout = 10 * time.Second + defaultDiscoveryTimeout = 10 * time.Second + defaultRevocationTimeout = 10 * time.Second defaultMaxResponseBodySize = 1 * 1024 * 1024 // 1 MB — guards against oversized server responses defaultKeyringService = "authgate-cli" -) -// ResolvedEndpoints holds the absolute URLs for all OAuth endpoints. -// Populated from OIDC Discovery or from hardcoded fallback paths. -type ResolvedEndpoints struct { - AuthorizeURL string - TokenURL string - DeviceAuthorizationURL string - TokenInfoURL string - UserinfoURL string - RevocationURL string -} + // maxDurationConfig caps user-supplied timeout values to prevent the CLI + // from hanging indefinitely on misconfiguration. + maxDurationConfig = 10 * time.Minute +) // AppConfig holds all resolved configuration for the CLI application. type AppConfig struct { @@ -83,7 +81,7 @@ type AppConfig struct { // Endpoints holds the resolved OAuth endpoint URLs. // Populated by resolveEndpoints after loadConfig. - Endpoints ResolvedEndpoints + Endpoints oauth.Endpoints // Timeout configuration (resolved from flag → env → default). // Only populated by loadConfig; zero in loadStoreConfig paths. @@ -93,6 +91,8 @@ type AppConfig struct { DeviceCodeRequestTimeout time.Duration CallbackTimeout time.Duration UserInfoTimeout time.Duration + DiscoveryTimeout time.Duration + RevocationTimeout time.Duration MaxResponseBodySize int64 } @@ -134,6 +134,10 @@ func registerFlags(cmd *cobra.Command) { StringVar(&flagCallbackTimeout, "callback-timeout", "", "Timeout waiting for browser callback (e.g. 2m, 5m)") cmd.PersistentFlags(). StringVar(&flagUserInfoTimeout, "userinfo-timeout", "", "Timeout for UserInfo requests (e.g. 10s, 1m)") + cmd.PersistentFlags(). + StringVar(&flagDiscoveryTimeout, "discovery-timeout", "", "Timeout for OIDC Discovery requests (e.g. 10s, 30s)") + cmd.PersistentFlags(). + StringVar(&flagRevocationTimeout, "revocation-timeout", "", "Timeout for token revocation requests (e.g. 10s, 1m)") cmd.PersistentFlags(). StringVar(&flagMaxResponseBodySize, "max-response-body-size", "", "Maximum response body size in bytes (e.g. 1048576)") } @@ -247,6 +251,10 @@ func loadConfig() *AppConfig { flagCallbackTimeout, "CALLBACK_TIMEOUT", defaultCallbackTimeout) cfg.UserInfoTimeout = getDurationConfig( flagUserInfoTimeout, "USERINFO_TIMEOUT", defaultUserInfoTimeout) + cfg.DiscoveryTimeout = getDurationConfig( + flagDiscoveryTimeout, "DISCOVERY_TIMEOUT", defaultDiscoveryTimeout) + cfg.RevocationTimeout = getDurationConfig( + flagRevocationTimeout, "REVOCATION_TIMEOUT", defaultRevocationTimeout) cfg.MaxResponseBodySize = getInt64Config( flagMaxResponseBodySize, "MAX_RESPONSE_BODY_SIZE", defaultMaxResponseBodySize) @@ -311,8 +319,8 @@ func validateServerURL(rawURL string) error { // defaultEndpoints returns hardcoded endpoint paths appended to serverURL. // Used as fallback when OIDC Discovery is unavailable. -func defaultEndpoints(serverURL string) ResolvedEndpoints { - return ResolvedEndpoints{ +func defaultEndpoints(serverURL string) oauth.Endpoints { + return oauth.Endpoints{ AuthorizeURL: serverURL + "/oauth/authorize", TokenURL: serverURL + "/oauth/token", DeviceAuthorizationURL: serverURL + "/oauth/device/code", @@ -335,7 +343,7 @@ func resolveEndpoints(ctx context.Context, cfg *AppConfig) { return } - fetchCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + fetchCtx, cancel := context.WithTimeout(ctx, cfg.DiscoveryTimeout) defer cancel() meta, err := disco.Fetch(fetchCtx) @@ -346,15 +354,7 @@ func resolveEndpoints(ctx context.Context, cfg *AppConfig) { return } - ep := meta.Endpoints() - cfg.Endpoints = ResolvedEndpoints{ - AuthorizeURL: ep.AuthorizeURL, - TokenURL: ep.TokenURL, - DeviceAuthorizationURL: ep.DeviceAuthorizationURL, - TokenInfoURL: ep.TokenInfoURL, - UserinfoURL: ep.UserinfoURL, - RevocationURL: ep.RevocationURL, - } + cfg.Endpoints = meta.Endpoints() } // getDurationConfig resolves a time.Duration from flag → env → default. @@ -376,6 +376,11 @@ func getDurationConfig(flagValue, envKey string, defaultValue time.Duration) tim envKey, d, defaultValue) return defaultValue } + if d > maxDurationConfig { + fmt.Fprintf(os.Stderr, "WARNING: %s exceeds maximum %s, capping at %s\n", + envKey, maxDurationConfig, maxDurationConfig) + return maxDurationConfig + } return d } diff --git a/config_test.go b/config_test.go index abe8b77..6a62ddb 100644 --- a/config_test.go +++ b/config_test.go @@ -96,6 +96,7 @@ func TestGetDurationConfig(t *testing.T) { {"invalid falls back to default", "notaduration", "", 10 * time.Second, 10 * time.Second}, {"negative falls back to default", "-5s", "", 10 * time.Second, 10 * time.Second}, {"zero falls back to default", "0s", "", 10 * time.Second, 10 * time.Second}, + {"exceeds max capped", "20m", "", 10 * time.Second, maxDurationConfig}, } for _, tc := range tests { diff --git a/discovery_test.go b/discovery_test.go index 8f79963..2d1bbf9 100644 --- a/discovery_test.go +++ b/discovery_test.go @@ -73,8 +73,9 @@ func TestResolveEndpoints_Success(t *testing.T) { t.Fatal(err) } cfg := &AppConfig{ - ServerURL: srv.URL, - RetryClient: rc, + ServerURL: srv.URL, + RetryClient: rc, + DiscoveryTimeout: defaultDiscoveryTimeout, } resolveEndpoints(context.Background(), cfg) @@ -113,8 +114,9 @@ func TestResolveEndpoints_FallbackOn404(t *testing.T) { t.Fatal(err) } cfg := &AppConfig{ - ServerURL: srv.URL, - RetryClient: rc, + ServerURL: srv.URL, + RetryClient: rc, + DiscoveryTimeout: defaultDiscoveryTimeout, } resolveEndpoints(context.Background(), cfg) @@ -140,8 +142,9 @@ func TestResolveEndpoints_FallbackOnTimeout(t *testing.T) { t.Fatal(err) } cfg := &AppConfig{ - ServerURL: srv.URL, - RetryClient: rc, + ServerURL: srv.URL, + RetryClient: rc, + DiscoveryTimeout: defaultDiscoveryTimeout, } ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) @@ -161,8 +164,9 @@ func TestResolveEndpoints_FallbackOnNetworkError(t *testing.T) { t.Fatal(err) } cfg := &AppConfig{ - ServerURL: "http://127.0.0.1:1", // port 1 is unreachable - RetryClient: rc, + ServerURL: "http://127.0.0.1:1", // port 1 is unreachable + RetryClient: rc, + DiscoveryTimeout: defaultDiscoveryTimeout, } resolveEndpoints(context.Background(), cfg) diff --git a/main_test.go b/main_test.go index 80e9e8e..eb7c3d4 100644 --- a/main_test.go +++ b/main_test.go @@ -40,6 +40,8 @@ func testConfig(t *testing.T) *AppConfig { DeviceCodeRequestTimeout: defaultDeviceCodeRequestTimeout, CallbackTimeout: defaultCallbackTimeout, UserInfoTimeout: defaultUserInfoTimeout, + DiscoveryTimeout: defaultDiscoveryTimeout, + RevocationTimeout: defaultRevocationTimeout, MaxResponseBodySize: defaultMaxResponseBodySize, } } diff --git a/token_cmd.go b/token_cmd.go index ba3102b..37ca8d5 100644 --- a/token_cmd.go +++ b/token_cmd.go @@ -6,9 +6,9 @@ import ( "errors" "fmt" "io" - "net/http" "net/url" "strings" + "sync" "time" retry "github.com/appleboy/go-httpretry" @@ -120,8 +120,6 @@ func loadTokenOrFail( return tok, 0 } -const revokeTokenTimeout = 10 * time.Second - // runTokenDelete is the testable core of `token delete`. func runTokenDelete( ctx context.Context, @@ -155,7 +153,7 @@ func runTokenDelete( } // revokeTokenOnServer attempts to revoke tokens on the OAuth server (RFC 7009). -// It revokes the refresh token first (long-lived), then the access token. +// It revokes the refresh and access tokens concurrently. func revokeTokenOnServer( ctx context.Context, cfg *AppConfig, @@ -163,20 +161,55 @@ func revokeTokenOnServer( stderr io.Writer, ) error { revokeURL := cfg.Endpoints.RevocationURL + timeout := cfg.RevocationTimeout + + var ( + mu sync.Mutex + refreshErr error + accessErr error + wg sync.WaitGroup + ) - // Revoke the refresh token first (more important — it's long-lived). if tok.RefreshToken != "" { - if err := doRevoke(ctx, cfg.RetryClient, revokeURL, tok.RefreshToken); err != nil { - fmt.Fprintf(stderr, "Warning: failed to revoke refresh token: %v\n", err) - } + wg.Go(func() { + if err := doRevoke( + ctx, + cfg.RetryClient, + revokeURL, + tok.RefreshToken, + timeout, + ); err != nil { + mu.Lock() + refreshErr = err + mu.Unlock() + } + }) } if tok.AccessToken != "" { - if err := doRevoke(ctx, cfg.RetryClient, revokeURL, tok.AccessToken); err != nil { - return fmt.Errorf("access token revocation: %w", err) - } + wg.Go(func() { + if err := doRevoke( + ctx, + cfg.RetryClient, + revokeURL, + tok.AccessToken, + timeout, + ); err != nil { + mu.Lock() + accessErr = err + mu.Unlock() + } + }) } + wg.Wait() + + if refreshErr != nil { + fmt.Fprintf(stderr, "Warning: failed to revoke refresh token: %v\n", refreshErr) + } + if accessErr != nil { + return fmt.Errorf("access token revocation: %w", accessErr) + } return nil } @@ -186,8 +219,9 @@ func doRevoke( client *retry.Client, revokeURL string, token string, + timeout time.Duration, ) error { - ctx, cancel := context.WithTimeout(ctx, revokeTokenTimeout) + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() data := url.Values{"token": {token}} @@ -202,7 +236,10 @@ func doRevoke( } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { + // Drain body for proper HTTP connection reuse. + _, _ = io.Copy(io.Discard, resp.Body) + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { return fmt.Errorf("revoke returned status %d", resp.StatusCode) } return nil diff --git a/token_cmd_test.go b/token_cmd_test.go index f9de6b3..ff5c8cd 100644 --- a/token_cmd_test.go +++ b/token_cmd_test.go @@ -15,6 +15,7 @@ import ( retry "github.com/appleboy/go-httpretry" "github.com/go-authgate/sdk-go/credstore" + "github.com/go-authgate/sdk-go/oauth" ) func TestRunTokenDelete(t *testing.T) { @@ -137,11 +138,12 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) { } cfg := &AppConfig{ - ClientID: "test-id", - ServerURL: srv.URL, - Endpoints: ResolvedEndpoints{RevocationURL: srv.URL + "/oauth/revoke"}, - RetryClient: rc, - Store: store, + ClientID: "test-id", + ServerURL: srv.URL, + Endpoints: oauth.Endpoints{RevocationURL: srv.URL + "/oauth/revoke"}, + RevocationTimeout: defaultRevocationTimeout, + RetryClient: rc, + Store: store, } var stdout, stderr bytes.Buffer @@ -163,11 +165,13 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) { if len(revokedTokens) != 2 { t.Fatalf("expected 2 revoke calls, got %d", len(revokedTokens)) } - if revokedTokens[0] != "refresh-456" { - t.Errorf("expected refresh token revoked first, got %q", revokedTokens[0]) + // Revocations run concurrently, so order is non-deterministic. + got := map[string]bool{revokedTokens[0]: true, revokedTokens[1]: true} + if !got["refresh-456"] { + t.Errorf("expected refresh token to be revoked, got %v", revokedTokens) } - if revokedTokens[1] != "access-123" { - t.Errorf("expected access token revoked second, got %q", revokedTokens[1]) + if !got["access-123"] { + t.Errorf("expected access token to be revoked, got %v", revokedTokens) } }) @@ -195,11 +199,12 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) { } cfg := &AppConfig{ - ClientID: "test-id", - ServerURL: srv.URL, - Endpoints: ResolvedEndpoints{RevocationURL: srv.URL + "/oauth/revoke"}, - RetryClient: rc, - Store: store, + ClientID: "test-id", + ServerURL: srv.URL, + Endpoints: oauth.Endpoints{RevocationURL: srv.URL + "/oauth/revoke"}, + RevocationTimeout: defaultRevocationTimeout, + RetryClient: rc, + Store: store, } var stdout, stderr bytes.Buffer @@ -288,11 +293,12 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) { } cfg := &AppConfig{ - ClientID: "test-id", - ServerURL: srv.URL, - Endpoints: ResolvedEndpoints{RevocationURL: srv.URL + "/oauth/revoke"}, - RetryClient: rc, - Store: store, + ClientID: "test-id", + ServerURL: srv.URL, + Endpoints: oauth.Endpoints{RevocationURL: srv.URL + "/oauth/revoke"}, + RevocationTimeout: defaultRevocationTimeout, + RetryClient: rc, + Store: store, } var stdout, stderr bytes.Buffer From 190df167960afe82496d9169123e6668a0b9d2a1 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Thu, 2 Apr 2026 12:24:03 +0800 Subject: [PATCH 3/5] fix(cli): add client auth to revocation, fix partial failure messaging - Include client_id and client_secret in RFC 7009 revocation requests - Return error from revokeTokenOnServer on refresh-only failure to prevent misleading "Token revoked on server." message - Wire test server into local-only test config for meaningful assertion - Assert TokenInfoURL in OIDC discovery success test Co-Authored-By: Claude Opus 4.6 (1M context) --- discovery_test.go | 5 +++++ token_cmd.go | 42 +++++++++++++++++++++--------------------- token_cmd_test.go | 11 +++++++++-- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/discovery_test.go b/discovery_test.go index 2d1bbf9..4390ce0 100644 --- a/discovery_test.go +++ b/discovery_test.go @@ -99,6 +99,11 @@ func TestResolveEndpoints_Success(t *testing.T) { if cfg.Endpoints.RevocationURL != srv.URL+"/revoke" { t.Errorf("RevocationURL = %q, want %q", cfg.Endpoints.RevocationURL, srv.URL+"/revoke") } + // TokenInfoURL is derived from issuer by the SDK (not a standard OIDC field). + wantTokenInfo := srv.URL + "/oauth/tokeninfo" + if cfg.Endpoints.TokenInfoURL != wantTokenInfo { + t.Errorf("TokenInfoURL = %q, want %q", cfg.Endpoints.TokenInfoURL, wantTokenInfo) + } } func TestResolveEndpoints_FallbackOn404(t *testing.T) { diff --git a/token_cmd.go b/token_cmd.go index 37ca8d5..3abf30f 100644 --- a/token_cmd.go +++ b/token_cmd.go @@ -172,13 +172,7 @@ func revokeTokenOnServer( if tok.RefreshToken != "" { wg.Go(func() { - if err := doRevoke( - ctx, - cfg.RetryClient, - revokeURL, - tok.RefreshToken, - timeout, - ); err != nil { + if err := doRevoke(ctx, cfg, revokeURL, tok.RefreshToken, timeout); err != nil { mu.Lock() refreshErr = err mu.Unlock() @@ -188,13 +182,7 @@ func revokeTokenOnServer( if tok.AccessToken != "" { wg.Go(func() { - if err := doRevoke( - ctx, - cfg.RetryClient, - revokeURL, - tok.AccessToken, - timeout, - ); err != nil { + if err := doRevoke(ctx, cfg, revokeURL, tok.AccessToken, timeout); err != nil { mu.Lock() accessErr = err mu.Unlock() @@ -204,19 +192,25 @@ func revokeTokenOnServer( wg.Wait() - if refreshErr != nil { + switch { + case accessErr != nil && refreshErr != nil: fmt.Fprintf(stderr, "Warning: failed to revoke refresh token: %v\n", refreshErr) - } - if accessErr != nil { return fmt.Errorf("access token revocation: %w", accessErr) + case accessErr != nil: + return fmt.Errorf("access token revocation: %w", accessErr) + case refreshErr != nil: + return fmt.Errorf("refresh token revocation: %w", refreshErr) + default: + return nil } - return nil } // doRevoke posts a single token to the revocation endpoint (RFC 7009). +// It includes client_id (and client_secret for confidential clients) as +// required by most OAuth servers for client authentication. func doRevoke( ctx context.Context, - client *retry.Client, + cfg *AppConfig, revokeURL string, token string, timeout time.Duration, @@ -224,8 +218,14 @@ func doRevoke( ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - data := url.Values{"token": {token}} - resp, err := client.Post(ctx, revokeURL, + data := url.Values{ + "token": {token}, + "client_id": {cfg.ClientID}, + } + if !cfg.IsPublicClient() { + data.Set("client_secret", cfg.ClientSecret) + } + resp, err := cfg.RetryClient.Post(ctx, revokeURL, retry.WithBody( "application/x-www-form-urlencoded", strings.NewReader(data.Encode()), diff --git a/token_cmd_test.go b/token_cmd_test.go index ff5c8cd..1019824 100644 --- a/token_cmd_test.go +++ b/token_cmd_test.go @@ -232,6 +232,10 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) { ) defer srv.Close() + rc, err := retry.NewClient() + if err != nil { + t.Fatal(err) + } store := credstore.NewTokenFileStore( filepath.Join(t.TempDir(), "tokens.json"), ) @@ -244,8 +248,11 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) { } cfg := &AppConfig{ - ClientID: "test-id", - Store: store, + ClientID: "test-id", + Endpoints: oauth.Endpoints{RevocationURL: srv.URL + "/oauth/revoke"}, + RevocationTimeout: defaultRevocationTimeout, + RetryClient: rc, + Store: store, } var stdout, stderr bytes.Buffer From 0612d4ebae74c30f4bc73b187d5b49ea9192454d Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Thu, 2 Apr 2026 12:31:00 +0800 Subject: [PATCH 4/5] fix(test): use atomic.Bool for race-safe server call assertion Co-Authored-By: Claude Opus 4.6 (1M context) --- token_cmd_test.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/token_cmd_test.go b/token_cmd_test.go index 1019824..57fced5 100644 --- a/token_cmd_test.go +++ b/token_cmd_test.go @@ -10,6 +10,7 @@ import ( "path/filepath" "strings" "sync" + "sync/atomic" "testing" "time" @@ -223,10 +224,10 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) { }) t.Run("local-only skips server call", func(t *testing.T) { - serverCalled := false + var serverCalled atomic.Bool srv := httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - serverCalled = true + serverCalled.Store(true) w.WriteHeader(http.StatusOK) }), ) @@ -262,7 +263,7 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) { if code != 0 { t.Fatalf("exit code: got %d, want 0", code) } - if serverCalled { + if serverCalled.Load() { t.Error("server should not have been called with --local-only") } }) From 0152654e65064e50d4b8f2500a36a9fca710f3e4 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Thu, 2 Apr 2026 12:37:46 +0800 Subject: [PATCH 5/5] fix(cli): bound response body drain and cap max response body size - Limit io.Copy drain in doRevoke to 4KB to prevent unbounded reads - Cap MAX_RESPONSE_BODY_SIZE at 100MB to prevent OOM via io.ReadAll Co-Authored-By: Claude Opus 4.6 (1M context) --- config.go | 9 +++++++++ token_cmd.go | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 7d8a61e..6c09575 100644 --- a/config.go +++ b/config.go @@ -64,6 +64,10 @@ const ( // maxDurationConfig caps user-supplied timeout values to prevent the CLI // from hanging indefinitely on misconfiguration. maxDurationConfig = 10 * time.Minute + + // maxResponseBodySizeCap prevents users from setting an excessively large + // response body limit that could cause OOM via io.ReadAll. + maxResponseBodySizeCap int64 = 100 * 1024 * 1024 // 100 MB ) // AppConfig holds all resolved configuration for the CLI application. @@ -257,6 +261,11 @@ func loadConfig() *AppConfig { flagRevocationTimeout, "REVOCATION_TIMEOUT", defaultRevocationTimeout) cfg.MaxResponseBodySize = getInt64Config( flagMaxResponseBodySize, "MAX_RESPONSE_BODY_SIZE", defaultMaxResponseBodySize) + if cfg.MaxResponseBodySize > maxResponseBodySizeCap { + fmt.Fprintf(os.Stderr, + "WARNING: MAX_RESPONSE_BODY_SIZE exceeds %d, capping\n", maxResponseBodySizeCap) + cfg.MaxResponseBodySize = maxResponseBodySizeCap + } if cfg.TokenStoreMode == "auto" { if ss, ok := cfg.Store.(*credstore.SecureStore[credstore.Token]); ok && !ss.UseKeyring() { diff --git a/token_cmd.go b/token_cmd.go index 3abf30f..2d1af12 100644 --- a/token_cmd.go +++ b/token_cmd.go @@ -236,8 +236,8 @@ func doRevoke( } defer resp.Body.Close() - // Drain body for proper HTTP connection reuse. - _, _ = io.Copy(io.Discard, resp.Body) + // Drain a bounded amount of the body for proper HTTP connection reuse. + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4096)) if resp.StatusCode < 200 || resp.StatusCode >= 300 { return fmt.Errorf("revoke returned status %d", resp.StatusCode)