diff --git a/CHANGELOG.md b/CHANGELOG.md index 26d6ed4..538b41c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,12 @@ This project uses [Semantic Versioning 2.0.0](http://semver.org/), the format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## Unreleased + +### Added + +- `auth login` can authenticate in the browser via an interactive OAuth flow (OAuth 2.0 with PKCE and a loopback redirect). The feature is dark-launched and off by default: opt in per command with `--web`, or persistently by setting `oauth_login: true` in the config file (or `DNSIMPLE_OAUTH_LOGIN=1`). Without it, `auth login` keeps prompting for a pasted API token. (dnsimple/cli#57) + ## 0.9.0 - 2026-05-25 ### Added diff --git a/README.md b/README.md index 6284963..24d2d28 100644 --- a/README.md +++ b/README.md @@ -75,16 +75,19 @@ dnsimple [command] [flags] The CLI supports two authentication modes that can be combined freely. > [!NOTE] -> The CLI currently supports API token authentication only, including both classic and scoped API tokens. OAuth support may be considered in the future, but it is not currently on the roadmap. +> By default `auth login` authenticates with an API token (classic or scoped), which you paste when prompted. An interactive browser login (OAuth) is being rolled out and is off by default for now. Opt in per command with `--web`, or persistently by setting `oauth_login: true` in the config file (or `DNSIMPLE_OAUTH_LOGIN=1`). #### Stateful: stored contexts Authenticate once and the CLI remembers a named *context* (token, account, environment) on disk. Multiple contexts can coexist and you select one as active: ```shell -# Log in to production and store a context +# Log in to production and store a context (prompts for an API token) dnsimple auth login +# Authenticate in the browser instead of pasting a token +dnsimple auth login --web + # Log in to sandbox alongside it dnsimple auth login --sandbox diff --git a/go.mod b/go.mod index 8e5d7a2..bd25984 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/dnsimple/cli go 1.25.4 require ( + github.com/cli/browser v1.3.0 github.com/dnsimple/dnsimple-go/v9 v9.1.0 github.com/spf13/cobra v1.10.2 github.com/spf13/pflag v1.0.10 diff --git a/go.sum b/go.sum index 421b00c..c20efae 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/cli/browser v1.3.0 h1:LejqCrpWr+1pRqmEPDGnTZOjsMe7sehifLynZJuqJpo= +github.com/cli/browser v1.3.0/go.mod h1:HH8s+fOAxjhQoBUAsKuPCbqUuxZDhQ2/aD+SzsEfBTk= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/cli/auth.go b/internal/cli/auth.go index bceff19..aa9beb1 100644 --- a/internal/cli/auth.go +++ b/internal/cli/auth.go @@ -64,7 +64,7 @@ func (a *authStatusOutput) TemplateData() any { func newAuthCmd(f *cmdutil.Factory) *cobra.Command { cmd := &cobra.Command{ Use: "auth", - Short: "Authenticate with DNSimple", + Short: "Manage authentication contexts", } cmd.AddCommand(newAuthLoginCmd(f)) @@ -151,24 +151,37 @@ This command does not contact the DNSimple API and works without a valid token.` func newAuthLoginCmd(f *cmdutil.Factory) *cobra.Command { var withToken bool + var web bool var nameFlag string cmd := &cobra.Command{ Use: "login", - Short: "Authenticate with a DNSimple API token", - Long: `Authenticate with DNSimple by providing an API token and store it as a named context. + Short: "Authenticate with DNSimple", + Long: `Authenticate with DNSimple and store the resulting credential as a named context. + +On a terminal, this command prompts you to paste an API token. Pass --web to +authenticate in your browser instead: it opens the DNSimple authorization page +and completes the login automatically once you approve, with no token to copy. +Browser login can also be turned on persistently by setting 'oauth_login: true' +in the config file (or DNSIMPLE_OAUTH_LOGIN=1). The new context becomes the active one. To create a sandbox context, pass --sandbox. To choose a context name, pass --name; otherwise the name is derived from the environment ('production' or 'sandbox'), with the account ID appended on collision. -Get your token from: +Headless / non-interactive use: + + - Pass --with-token to pipe a pre-issued API token on stdin: + echo "$TOKEN" | dnsimple auth login --with-token + - When stdin is not a terminal (CI, redirected input), the command reads + the token from stdin without requiring --with-token. - Production: https://dnsimple.com/user - Sandbox: https://sandbox.dnsimple.com/user +With --web, if the browser cannot be launched (e.g. no display server), the +authorize URL is printed to stderr and the command keeps listening for the +callback. -See https://support.dnsimple.com/articles/api-access-token/ for instructions on -generating an API token.`, +See https://support.dnsimple.com/articles/api-access-token/ if you need to +generate an API token manually.`, RunE: func(cmd *cobra.Command, args []string) error { cfg, err := f.Config() if err != nil { @@ -176,7 +189,9 @@ generating an API token.`, } host := config.HostForSandbox(cfg.Sandbox) - token, err := readLoginToken(cmd, withToken) + useOAuth := web || cfg.OAuthLogin + warnIfWebIgnored(cmd, web, withToken) + token, err := acquireToken(cmd, cfg, withToken, useOAuth) if err != nil { return err } @@ -205,7 +220,7 @@ generating an API token.`, return err } - ctx, action, err := upsertLoginContext(creds, host, token, accountID, user, nameFlag) + ctx, _, err := upsertLoginContext(creds, host, token, accountID, user, nameFlag) if err != nil { return err } @@ -215,13 +230,24 @@ generating an API token.`, return err } - fmt.Fprintf(cmd.ErrOrStderr(), "%s context %q (%s, account %s) and set as active\n", - action, ctx.Name, config.EnvironmentName(host), ctx.AccountID) + stderr := cmd.ErrOrStderr() + if user != "" { + fmt.Fprintf(stderr, "Success! You're now logged in to DNSimple as %s.\n", user) + } else { + fmt.Fprintln(stderr, "Success! You're now logged in to DNSimple.") + } + + location := config.EnvironmentName(host) + if ctx.AccountID != "" { + location = fmt.Sprintf("%s, account %s", location, ctx.AccountID) + } + fmt.Fprintf(stderr, "Context %q (%s) is now active.\n", ctx.Name, location) return nil }, } cmd.Flags().BoolVar(&withToken, "with-token", false, "Read token from stdin") + cmd.Flags().BoolVar(&web, "web", false, "Authenticate in a browser instead of pasting a token") cmd.Flags().StringVar(&nameFlag, "name", "", "Name for the new context (auto-derived if omitted)") return cmd @@ -327,8 +353,7 @@ func resolveLoginAccount(c *dnsimple.Client, whoami *dnsimple.WhoamiResponse, in // - same (host, token) anywhere → refresh that context (re-login). // - otherwise → create with an auto-derived name. // -// The returned action is "Created" or "Refreshed" for use in the success -// message. +// The returned action is "Created" or "Refreshed". func upsertLoginContext(creds *config.Credentials, host, token, accountID, user, explicitName string) (*config.Context, string, error) { if explicitName != "" { existing := creds.Find(explicitName) diff --git a/internal/cli/auth_oauth.go b/internal/cli/auth_oauth.go new file mode 100644 index 0000000..18dd59e --- /dev/null +++ b/internal/cli/auth_oauth.go @@ -0,0 +1,89 @@ +package cli + +import ( + "context" + "errors" + "fmt" + "io" + + "github.com/cli/browser" + "github.com/dnsimple/cli/internal/config" + "github.com/dnsimple/cli/internal/oauth" + "github.com/spf13/cobra" +) + +// loginViaOAuth runs the interactive OAuth browser flow and returns an +// access token. Production wiring constructs an oauth.Client from the +// active config and delegates to its Login method. Tests override this +// var directly to skip the listener / browser / token exchange. +var loginViaOAuth = defaultLoginViaOAuth + +// isStdinTTY reports whether the command's stdin is a real terminal. +// Tests override it directly so they can drive the OAuth branch without +// faking a PTY. The underlying check delegates to isInteractiveInput +// (see confirm.go) so the OAuth branch stays in lockstep with how +// destructive-action prompts decide interactivity. +var isStdinTTY = func(cmd *cobra.Command) bool { + return isInteractiveInput(cmd.InOrStdin()) +} + +// defaultLoginViaOAuth is the production implementation of the OAuth flow. +// It is wired through `loginViaOAuth` so the integration test in +// auth_oauth_test.go can swap it for a stub. +func defaultLoginViaOAuth(ctx context.Context, cfg *config.Config, errOut io.Writer) (string, error) { + clientID := config.OAuthClientID(cfg.Sandbox) + if clientID == "" { + return "", oauth.ErrNotProvisioned + } + c := &oauth.Client{ + ClientID: clientID, + AuthorizeBase: config.AuthorizeURL(cfg.Sandbox), + TokenURL: config.OAuthTokenURL(cfg.BaseURL), + BrowserOpener: browser.OpenURL, + Stderr: errOut, + } + return c.Login(ctx) +} + +// acquireToken obtains the access token for a fresh `auth login`. It reads a +// token from stdin for --with-token or non-TTY input; on a TTY it runs the +// OAuth browser flow when useOAuth is set, otherwise it prompts for a pasted +// token. A browser-login failure is returned as-is (no paste fallback); the +// error tells the user to retry or pass --with-token. +func acquireToken(cmd *cobra.Command, cfg *config.Config, withToken, useOAuth bool) (string, error) { + switch { + case withToken: + return readLoginToken(cmd, true) + case !isStdinTTY(cmd) || !useOAuth: + return readLoginToken(cmd, false) + } + + token, err := loginViaOAuth(context.Background(), cfg, cmd.ErrOrStderr()) + switch { + case err == nil: + return token, nil + case errors.Is(err, context.Canceled): + return "", err + case errors.Is(err, oauth.ErrNotProvisioned): + return "", errors.New("interactive browser login is not available in this build\n\nRun `dnsimple auth login --with-token` to authenticate with an API token instead") + default: + return "", fmt.Errorf("browser login failed: %w\n\nRetry `dnsimple auth login`, or run `dnsimple auth login --with-token` to authenticate with an API token instead", err) + } +} + +// warnIfWebIgnored notes that an explicit --web was not honored, mirroring the +// precedence in acquireToken: --with-token wins, and the browser flow needs an +// interactive terminal. It keys off the actual flag value (not just whether it +// was set) so `--web=false` stays silent, and it ignores the persistent +// oauth_login toggle, which is meant to fall back to the prompt without noise. +func warnIfWebIgnored(cmd *cobra.Command, web, withToken bool) { + if !web { + return + } + switch { + case withToken: + fmt.Fprintln(cmd.ErrOrStderr(), "Warning: --web is ignored when --with-token is set.") + case !isStdinTTY(cmd): + fmt.Fprintln(cmd.ErrOrStderr(), "Warning: browser login (--web) needs an interactive terminal; reading the token from stdin instead.") + } +} diff --git a/internal/cli/auth_oauth_test.go b/internal/cli/auth_oauth_test.go new file mode 100644 index 0000000..d00abe3 --- /dev/null +++ b/internal/cli/auth_oauth_test.go @@ -0,0 +1,375 @@ +package cli + +import ( + "bytes" + "context" + "fmt" + "io" + "strings" + "testing" + + "github.com/dnsimple/cli/internal/cmdutil" + "github.com/dnsimple/cli/internal/config" + "github.com/dnsimple/cli/internal/oauth" + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" +) + +// stubLoginViaOAuth swaps the OAuth entry point for the duration of the +// test, restoring the original on Cleanup. +func stubLoginViaOAuth(t *testing.T, stub func(ctx context.Context, cfg *config.Config, errOut io.Writer) (string, error)) { + t.Helper() + prev := loginViaOAuth + loginViaOAuth = stub + t.Cleanup(func() { loginViaOAuth = prev }) +} + +// forceTTY makes isStdinTTY return the given value for the duration of the +// test. Tests that exercise the OAuth branch need to flip it to true +// because strings.NewReader is not an *os.File. +func forceTTY(t *testing.T, tty bool) { + t.Helper() + prev := isStdinTTY + isStdinTTY = func(*cobra.Command) bool { return tty } + t.Cleanup(func() { isStdinTTY = prev }) +} + +// --- acquireToken branching --- + +func TestAcquireTokenWithTokenFlagReadsFromStdin(t *testing.T) { + cmd := &cobra.Command{} + cmd.SetIn(strings.NewReader("tok-from-stdin\n")) + cmd.SetErr(io.Discard) + + got, err := acquireToken(cmd, &config.Config{}, true, false) + if !assert.NoError(t, err) { + return + } + assert.Equal(t, "tok-from-stdin", got) +} + +func TestAcquireTokenNonTTYReadsFromStdin(t *testing.T) { + forceTTY(t, false) + cmd := &cobra.Command{} + cmd.SetIn(strings.NewReader("tok-piped\n")) + cmd.SetErr(io.Discard) + + got, err := acquireToken(cmd, &config.Config{}, false, false) + if !assert.NoError(t, err) { + return + } + assert.Equal(t, "tok-piped", got) +} + +func TestAcquireTokenTTYRunsOAuth(t *testing.T) { + forceTTY(t, true) + + var capturedSandbox bool + stubLoginViaOAuth(t, func(_ context.Context, cfg *config.Config, _ io.Writer) (string, error) { + capturedSandbox = cfg.Sandbox + return "tok-from-oauth", nil + }) + + cmd := &cobra.Command{} + cmd.SetIn(strings.NewReader("")) // OAuth path must not consume stdin + cmd.SetErr(io.Discard) + + got, err := acquireToken(cmd, &config.Config{Sandbox: true}, false, true) + if !assert.NoError(t, err) { + return + } + assert.Equal(t, "tok-from-oauth", got) + assert.True(t, capturedSandbox, "OAuth flow should receive cfg.Sandbox=true") +} + +// TestAcquireTokenTTYWithoutOAuthPromptsForToken pins the dark-launch default: +// on a TTY with OAuth disabled, the command reads a pasted token and never +// starts the browser flow. +func TestAcquireTokenTTYWithoutOAuthPromptsForToken(t *testing.T) { + forceTTY(t, true) + stubLoginViaOAuth(t, func(context.Context, *config.Config, io.Writer) (string, error) { + t.Fatal("OAuth flow must not run when useOAuth is false") + return "", nil + }) + + cmd := &cobra.Command{} + cmd.SetIn(strings.NewReader("tok-paste\n")) + cmd.SetErr(io.Discard) + + got, err := acquireToken(cmd, &config.Config{}, false, false) + if !assert.NoError(t, err) { + return + } + assert.Equal(t, "tok-paste", got) +} + +func TestAcquireTokenErrorsOnErrNotProvisioned(t *testing.T) { + forceTTY(t, true) + stubLoginViaOAuth(t, func(context.Context, *config.Config, io.Writer) (string, error) { + return "", oauth.ErrNotProvisioned + }) + + cmd := &cobra.Command{} + cmd.SetIn(strings.NewReader("would-be-pasted\n")) // must not be consumed + cmd.SetErr(io.Discard) + + _, err := acquireToken(cmd, &config.Config{}, false, true) + if !assert.Error(t, err) { + return + } + assert.Contains(t, err.Error(), "not available in this build") + assert.Contains(t, err.Error(), "--with-token") +} + +func TestAcquireTokenErrorsOnTransientOAuthError(t *testing.T) { + forceTTY(t, true) + stubLoginViaOAuth(t, func(context.Context, *config.Config, io.Writer) (string, error) { + return "", fmt.Errorf("network: connection refused") + }) + + cmd := &cobra.Command{} + cmd.SetIn(strings.NewReader("would-be-pasted\n")) // must not be consumed + cmd.SetErr(io.Discard) + + _, err := acquireToken(cmd, &config.Config{}, false, true) + if !assert.Error(t, err) { + return + } + assert.Contains(t, err.Error(), "browser login failed") + assert.Contains(t, err.Error(), "connection refused") + assert.Contains(t, err.Error(), "--with-token") +} + +// TestAcquireTokenAbortsOnAccessDenied pins that a user who explicitly +// denied consent in the browser is NOT pestered with a paste prompt. +func TestAcquireTokenAbortsOnAccessDenied(t *testing.T) { + forceTTY(t, true) + stubLoginViaOAuth(t, func(context.Context, *config.Config, io.Writer) (string, error) { + return "", &oauth.AuthError{Code: "access_denied", Description: "user cancelled"} + }) + + cmd := &cobra.Command{} + cmd.SetIn(strings.NewReader("would-be-pasted\n")) + cmd.SetErr(io.Discard) + + _, err := acquireToken(cmd, &config.Config{}, false, true) + if !assert.Error(t, err) { + return + } + var ae *oauth.AuthError + assert.ErrorAs(t, err, &ae) + assert.Equal(t, "access_denied", ae.Code) +} + +func TestAcquireTokenAbortsOnStateMismatch(t *testing.T) { + forceTTY(t, true) + stubLoginViaOAuth(t, func(context.Context, *config.Config, io.Writer) (string, error) { + return "", oauth.ErrStateMismatch + }) + + cmd := &cobra.Command{} + cmd.SetIn(strings.NewReader("would-be-pasted\n")) + cmd.SetErr(io.Discard) + + _, err := acquireToken(cmd, &config.Config{}, false, true) + assert.ErrorIs(t, err, oauth.ErrStateMismatch) +} + +func TestAcquireTokenAbortsOnContextCancellation(t *testing.T) { + forceTTY(t, true) + stubLoginViaOAuth(t, func(context.Context, *config.Config, io.Writer) (string, error) { + return "", context.Canceled + }) + + cmd := &cobra.Command{} + cmd.SetIn(strings.NewReader("would-be-pasted\n")) + cmd.SetErr(io.Discard) + + _, err := acquireToken(cmd, &config.Config{}, false, true) + assert.ErrorIs(t, err, context.Canceled) +} + +// --- warnIfWebIgnored --- + +func TestWarnIfWebIgnored(t *testing.T) { + warnOutput := func(t *testing.T, web, withToken bool) string { + t.Helper() + cmd := &cobra.Command{} + var errb bytes.Buffer + cmd.SetErr(&errb) + warnIfWebIgnored(cmd, web, withToken) + return errb.String() + } + + t.Run("with-token wins over --web", func(t *testing.T) { + assert.Contains(t, warnOutput(t, true, true), "--with-token") + }) + + t.Run("non-TTY needs a terminal", func(t *testing.T) { + forceTTY(t, false) + assert.Contains(t, warnOutput(t, true, false), "interactive terminal") + }) + + t.Run("no --web is silent", func(t *testing.T) { + forceTTY(t, false) + assert.Empty(t, warnOutput(t, false, false)) + }) + + t.Run("--web on a TTY is silent", func(t *testing.T) { + forceTTY(t, true) + assert.Empty(t, warnOutput(t, true, false)) + }) +} + +// --- end-to-end: auth login via OAuth --- + +func TestAuthLoginViaOAuthEndToEnd(t *testing.T) { + isolateConfigHomeForCLI(t) + forceTTY(t, true) + + server := newWhoamiServer(t, `{"data":{"user":{"id":1,"email":"alice@example.com"},"account":{"id":981,"email":"acct@example.com"}}}`) + defer server.Close() + + stubLoginViaOAuth(t, func(context.Context, *config.Config, io.Writer) (string, error) { + return "tok-oauth-1", nil + }) + + f := cmdutil.NewFactory("test") + cmd := buildLoginCmdWithBaseURL(t, f, server.URL) + if !assert.NoError(t, cmd.Flags().Set("web", "true")) { // opt into the browser flow + return + } + + var stderr bytes.Buffer + cmd.SetIn(strings.NewReader("")) // OAuth path should not consume stdin + cmd.SetErr(&stderr) + cmd.SetOut(io.Discard) + + if err := cmd.RunE(cmd, nil); !assert.NoError(t, err) { + return + } + + creds, err := config.LoadCredentials() + if !assert.NoError(t, err) { + return + } + if assert.Len(t, creds.Contexts, 1) { + ctx := creds.Contexts[0] + assert.Equal(t, "production", ctx.Name) + assert.Equal(t, config.ProductionHost, ctx.Host) + assert.Equal(t, "tok-oauth-1", ctx.Token, "stored token should come from the OAuth flow") + assert.Equal(t, "981", ctx.AccountID) + assert.Equal(t, "alice@example.com", ctx.User) + } + assert.Equal(t, "production", creds.ActiveContext) + assert.Contains(t, stderr.String(), "You're now logged in to DNSimple as alice@example.com") + assert.Contains(t, stderr.String(), "is now active") +} + +// TestAuthLoginViaOAuthConfigToggle pins the persistent opt-in: with +// oauth_login enabled in config (cfg.OAuthLogin) and no --web flag, `auth +// login` on a TTY runs the browser flow. This exercises the cfg.OAuthLogin +// operand of `useOAuth := web || cfg.OAuthLogin`, which the --web tests do not. +func TestAuthLoginViaOAuthConfigToggle(t *testing.T) { + isolateConfigHomeForCLI(t) + forceTTY(t, true) + + server := newWhoamiServer(t, `{"data":{"user":{"id":1,"email":"alice@example.com"},"account":{"id":981,"email":"acct@example.com"}}}`) + defer server.Close() + + stubLoginViaOAuth(t, func(context.Context, *config.Config, io.Writer) (string, error) { + return "tok-cfg-toggle", nil + }) + + f := cmdutil.NewFactory("test") + cfg := &config.Config{BaseURL: server.URL, OAuthLogin: true} + f.Config = func() (*config.Config, error) { return cfg, nil } + cmd := newAuthLoginCmd(f) + + cmd.SetIn(strings.NewReader("would-be-pasted\n")) // OAuth path must not consume stdin + cmd.SetErr(io.Discard) + cmd.SetOut(io.Discard) + + if err := cmd.RunE(cmd, nil); !assert.NoError(t, err) { + return + } + + creds, err := config.LoadCredentials() + if !assert.NoError(t, err) { + return + } + if assert.Len(t, creds.Contexts, 1) { + assert.Equal(t, "tok-cfg-toggle", creds.Contexts[0].Token, "stored token should come from the OAuth flow enabled by oauth_login") + } + assert.Equal(t, "production", creds.ActiveContext) +} + +// TestAuthLoginDefaultPromptsForToken pins the dark-launch default: with the +// browser flow off (no --web, oauth_login unset), `auth login` on a TTY reads +// a pasted token and stores a context. The OAuth flow must not run. +func TestAuthLoginDefaultPromptsForToken(t *testing.T) { + isolateConfigHomeForCLI(t) + forceTTY(t, true) + + server := newWhoamiServer(t, `{"data":{"user":{"id":1,"email":"alice@example.com"},"account":{"id":981,"email":"acct@example.com"}}}`) + defer server.Close() + + stubLoginViaOAuth(t, func(context.Context, *config.Config, io.Writer) (string, error) { + t.Fatal("OAuth flow must not run without --web / oauth_login") + return "", nil + }) + + f := cmdutil.NewFactory("test") + cmd := buildLoginCmdWithBaseURL(t, f, server.URL) + + cmd.SetIn(strings.NewReader("tok-paste\n")) + cmd.SetErr(io.Discard) + cmd.SetOut(io.Discard) + + if err := cmd.RunE(cmd, nil); !assert.NoError(t, err) { + return + } + + creds, err := config.LoadCredentials() + if !assert.NoError(t, err) { + return + } + if assert.Len(t, creds.Contexts, 1) { + assert.Equal(t, "tok-paste", creds.Contexts[0].Token, "stored token should come from the paste prompt") + } + assert.Equal(t, "production", creds.ActiveContext) +} + +// TestAuthLoginViaOAuthNotProvisionedErrors pins that once the browser flow is +// opted into (--web) but the build is not provisioned, the command reports the +// failure and exits instead of falling back to a paste prompt. +func TestAuthLoginViaOAuthNotProvisionedErrors(t *testing.T) { + isolateConfigHomeForCLI(t) + forceTTY(t, true) + + server := newWhoamiServer(t, `{"data":{"user":{"id":1,"email":"alice@example.com"},"account":{"id":981,"email":"acct@example.com"}}}`) + defer server.Close() + + stubLoginViaOAuth(t, func(context.Context, *config.Config, io.Writer) (string, error) { + return "", oauth.ErrNotProvisioned + }) + + f := cmdutil.NewFactory("test") + cmd := buildLoginCmdWithBaseURL(t, f, server.URL) + if !assert.NoError(t, cmd.Flags().Set("web", "true")) { + return + } + + cmd.SetIn(strings.NewReader("tok-paste\n")) // must not be consumed + cmd.SetErr(io.Discard) + cmd.SetOut(io.Discard) + + err := cmd.RunE(cmd, nil) + if !assert.Error(t, err) { + return + } + assert.Contains(t, err.Error(), "--with-token") + + creds, _ := config.LoadCredentials() + assert.Empty(t, creds.Contexts, "no context should be stored when browser login fails") +} diff --git a/internal/cli/auth_test.go b/internal/cli/auth_test.go index 9c898d3..9b01157 100644 --- a/internal/cli/auth_test.go +++ b/internal/cli/auth_test.go @@ -211,7 +211,8 @@ func TestAuthLoginCreatesContextAndSetsActive(t *testing.T) { assert.Equal(t, "alice@example.com", ctx.User) } assert.Equal(t, "production", creds.ActiveContext) - assert.Contains(t, stderr.String(), "Created context") + assert.Contains(t, stderr.String(), "You're now logged in to DNSimple as alice@example.com") + assert.Contains(t, stderr.String(), "is now active") } func TestAuthLoginWithSandboxFlagCreatesSandboxContext(t *testing.T) { diff --git a/internal/config/config.go b/internal/config/config.go index 47c36bc..0a8bf1a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,7 @@ package config import ( "os" "path/filepath" + "strings" "github.com/spf13/viper" ) @@ -14,6 +15,12 @@ const ( configDirName = "dnsimple" configFileName = "config" credentialsFileName = "credentials" + + // baseURLEnvVar overrides the API base URL for local development against + // a local API server. Consulted by both Load/SetSandbox (the + // pre-auth `auth login` path) and Resolve (authenticated commands) so the + // override moves every API call, including the OAuth token endpoint. + baseURLEnvVar = "DNSIMPLE_BASE_URL" ) // Config holds the CLI configuration. @@ -31,6 +38,10 @@ type Config struct { // PerPage is the default number of items per page for list commands. PerPage int + + // OAuthLogin opts `auth login` into the interactive browser flow; off by + // default during the dark-launch rollout (see --web / DNSIMPLE_OAUTH_LOGIN). + OAuthLogin bool } // Dir returns the configuration directory path. @@ -60,6 +71,7 @@ func Load() (*Config, error) { v.SetDefault("sandbox", false) v.SetDefault("per_page", defaultPerPage) v.SetDefault("default_account", "") + v.SetDefault("oauth_login", false) // Config file is optional if err := v.ReadInConfig(); err != nil { @@ -76,13 +88,10 @@ func Load() (*Config, error) { Sandbox: v.GetBool("sandbox"), DefaultAccount: v.GetString("default_account"), PerPage: v.GetInt("per_page"), + OAuthLogin: v.GetBool("oauth_login"), } - if cfg.Sandbox { - cfg.BaseURL = sandboxBaseURL - } else { - cfg.BaseURL = defaultBaseURL - } + cfg.BaseURL = baseURLForSandbox(cfg.Sandbox) return cfg, nil } @@ -90,11 +99,22 @@ func Load() (*Config, error) { // SetSandbox overrides the sandbox setting (from --sandbox flag). func (c *Config) SetSandbox(sandbox bool) { c.Sandbox = sandbox + c.BaseURL = baseURLForSandbox(sandbox) +} + +// baseURLForSandbox returns the API base URL for the given environment, +// honoring DNSIMPLE_BASE_URL when set. The env override wins over the +// production/sandbox default, mirroring the precedence in Resolve so the +// `auth login` flow (which builds a Config, not a ResolvedContext, because no +// token exists yet) targets the same local host as authenticated commands. +func baseURLForSandbox(sandbox bool) string { + if v := strings.TrimSpace(os.Getenv(baseURLEnvVar)); v != "" { + return v + } if sandbox { - c.BaseURL = sandboxBaseURL - } else { - c.BaseURL = defaultBaseURL + return sandboxBaseURL } + return defaultBaseURL } // Save writes the current configuration to disk. @@ -111,6 +131,7 @@ func (c *Config) Save() error { c.v.Set("sandbox", c.Sandbox) c.v.Set("default_account", c.DefaultAccount) c.v.Set("per_page", c.PerPage) + c.v.Set("oauth_login", c.OAuthLogin) return c.v.WriteConfigAs(filepath.Join(dir, configFileName+".yml")) } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index c671208..bda548e 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -26,6 +26,7 @@ func TestLoadDefaults(t *testing.T) { assert.Equal(t, defaultBaseURL, cfg.BaseURL) assert.Empty(t, cfg.DefaultAccount) assert.Equal(t, defaultPerPage, cfg.PerPage) + assert.False(t, cfg.OAuthLogin) } func TestLoadFromEnvironment(t *testing.T) { @@ -33,6 +34,7 @@ func TestLoadFromEnvironment(t *testing.T) { t.Setenv("DNSIMPLE_SANDBOX", "true") t.Setenv("DNSIMPLE_DEFAULT_ACCOUNT", "1010") t.Setenv("DNSIMPLE_PER_PAGE", "75") + t.Setenv("DNSIMPLE_OAUTH_LOGIN", "true") cfg, err := Load() if !assert.NoError(t, err) { @@ -43,6 +45,7 @@ func TestLoadFromEnvironment(t *testing.T) { assert.Equal(t, sandboxBaseURL, cfg.BaseURL) assert.Equal(t, "1010", cfg.DefaultAccount) assert.Equal(t, 75, cfg.PerPage) + assert.True(t, cfg.OAuthLogin) } func TestSaveAndReload(t *testing.T) { @@ -56,6 +59,7 @@ func TestSaveAndReload(t *testing.T) { cfg.SetSandbox(true) cfg.DefaultAccount = "2020" cfg.PerPage = 99 + cfg.OAuthLogin = true if !assert.NoError(t, cfg.Save()) { return @@ -70,6 +74,7 @@ func TestSaveAndReload(t *testing.T) { assert.Equal(t, sandboxBaseURL, reloaded.BaseURL) assert.Equal(t, "2020", reloaded.DefaultAccount) assert.Equal(t, 99, reloaded.PerPage) + assert.True(t, reloaded.OAuthLogin) } func TestSetSandboxUpdatesBaseURL(t *testing.T) { diff --git a/internal/config/oauth.go b/internal/config/oauth.go new file mode 100644 index 0000000..fec2e72 --- /dev/null +++ b/internal/config/oauth.go @@ -0,0 +1,84 @@ +package config + +import ( + "os" + "strings" +) + +// OAuth client identifiers for the first-party DNSimple CLI application, one +// per environment. Public OAuth client identifiers are not secrets, so +// embedding them in the binary is standard practice (gh, gcloud, stripe all +// do this). An empty value makes OAuthClientID return "", so `auth login` +// falls back to the paste prompt instead of starting the browser flow. +const ( + oauthClientIDProduction = "902c56782e81c768" + oauthClientIDSandbox = "46bbb707b411f835" +) + +// Per-environment overrides consulted before the embedded constants. +// Useful when developing against a local API server where the CLI app has +// been bootstrapped with a different ID, or when alternating +// between sandbox and production in one shell session: a single shared +// override would route the wrong client ID into one of the two flows and +// produce an opaque `invalid_client` error. +const ( + oauthClientIDEnvVarProduction = "DNSIMPLE_OAUTH_CLIENT_ID_PRODUCTION" + oauthClientIDEnvVarSandbox = "DNSIMPLE_OAUTH_CLIENT_ID_SANDBOX" +) + +// OAuthClientID returns the OAuth client identifier to use for the given +// environment. The matching environment variable +// (DNSIMPLE_OAUTH_CLIENT_ID_PRODUCTION or _SANDBOX) takes precedence over +// the embedded constant. Leading/trailing whitespace on the env value is +// stripped before the empty check so a stray newline from a +// command-substitution export (e.g. `export ...=$(= 300 && resp.StatusCode < 400 { + return "", fmt.Errorf("token endpoint returned an unexpected redirect (HTTP %d to %q); refusing to replay credentials", + resp.StatusCode, resp.Header.Get("Location")) + } + + if resp.StatusCode >= 400 { + var er errorResponse + if json.Unmarshal(respBody, &er) == nil && er.Error != "" { + return "", newAuthError(er.Error, er.ErrorDescription) + } + return "", fmt.Errorf("token endpoint returned HTTP %d", resp.StatusCode) + } + + var tok tokenResponse + if err := json.Unmarshal(respBody, &tok); err != nil { + return "", fmt.Errorf("decode token response: %w", err) + } + if tok.AccessToken == "" { + return "", fmt.Errorf("token response did not include an access_token") + } + // token_type is REQUIRED per RFC 6749 §5.1 and tells the client how + // to use the token. The DNSimple API only accepts the access token + // as a Bearer credential, so a non-bearer response would be silently + // accepted today but rejected by every subsequent API call with an + // opaque 401. Match case-insensitively per the spec. + if !strings.EqualFold(tok.TokenType, "bearer") { + return "", fmt.Errorf("token response contained unsupported token_type %q", tok.TokenType) + } + return tok.AccessToken, nil +} + +func (c *Client) stderr() io.Writer { + if c.Stderr == nil { + return io.Discard + } + return c.Stderr +} diff --git a/internal/oauth/client_test.go b/internal/oauth/client_test.go new file mode 100644 index 0000000..e15788e --- /dev/null +++ b/internal/oauth/client_test.go @@ -0,0 +1,368 @@ +package oauth + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// openerFollowingAuthorize returns a BrowserOpener stub that simulates the +// user reaching the authorize endpoint and consenting: it parses the +// authorize URL, extracts the redirect_uri and state, and fires the +// callback request with the supplied code asynchronously. +// +// recordedAuthURL captures the URL the Client tried to open so tests can +// assert on the query string. +func openerFollowingAuthorize(t *testing.T, code string, recordedAuthURL *string) func(string) error { + t.Helper() + return func(authURL string) error { + if recordedAuthURL != nil { + *recordedAuthURL = authURL + } + u, err := url.Parse(authURL) + if err != nil { + return err + } + q := u.Query() + state := q.Get("state") + redirectURI := q.Get("redirect_uri") + + go func() { + resp, err := http.Get(redirectURI + "?" + url.Values{ + "code": {code}, + "state": {state}, + }.Encode()) + if err == nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + }() + return nil + } +} + +func TestClientLoginHappyPath(t *testing.T) { + var receivedBody tokenRequest + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + body, _ := io.ReadAll(r.Body) + if err := json.Unmarshal(body, &receivedBody); err != nil { + t.Errorf("decode body: %v", err) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"the-token","token_type":"bearer","scope":null,"account_id":981}`) + })) + defer tokenServer.Close() + + var authURL string + c := &Client{ + ClientID: "client-abc", + AuthorizeBase: "https://example.test/oauth/authorize", + TokenURL: tokenServer.URL, + BrowserOpener: openerFollowingAuthorize(t, "fake-code", &authURL), + Stderr: io.Discard, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + token, err := c.Login(ctx) + if !assert.NoError(t, err) { + return + } + assert.Equal(t, "the-token", token) + + // Authorize URL must carry all PKCE + loopback parameters. + u, _ := url.Parse(authURL) + q := u.Query() + assert.Equal(t, "code", q.Get("response_type")) + assert.Equal(t, "client-abc", q.Get("client_id")) + assert.NotEmpty(t, q.Get("state")) + assert.NotEmpty(t, q.Get("code_challenge")) + assert.Equal(t, "S256", q.Get("code_challenge_method")) + assert.True(t, strings.HasPrefix(q.Get("redirect_uri"), "http://127.0.0.1:")) + + // Token request body must carry the verifier (not a client_secret) and + // the same redirect_uri. + assert.Equal(t, "authorization_code", receivedBody.GrantType) + assert.Equal(t, "client-abc", receivedBody.ClientID) + assert.Equal(t, "fake-code", receivedBody.Code) + assert.NotEmpty(t, receivedBody.CodeVerifier) + assert.Equal(t, q.Get("redirect_uri"), receivedBody.RedirectURI) +} + +func TestClientLoginSurfacesTokenEndpointErrorResponse(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, `{"error":"invalid_grant","error_description":"bad verifier"}`) + })) + defer tokenServer.Close() + + c := &Client{ + ClientID: "client-abc", + AuthorizeBase: "https://example.test/oauth/authorize", + TokenURL: tokenServer.URL, + BrowserOpener: openerFollowingAuthorize(t, "fake-code", nil), + Stderr: io.Discard, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := c.Login(ctx) + var ae *AuthError + if !assert.ErrorAs(t, err, &ae) { + return + } + assert.Equal(t, "invalid_grant", ae.Code) + assert.Equal(t, "bad verifier", ae.Description) +} + +func TestClientLoginSurfacesAuthorizationDenied(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Errorf("token endpoint should not be called when authorize fails") + })) + defer tokenServer.Close() + + // Opener that simulates the user clicking Deny. + opener := func(authURL string) error { + u, _ := url.Parse(authURL) + redirect := u.Query().Get("redirect_uri") + state := u.Query().Get("state") + go func() { + resp, err := http.Get(redirect + "?" + url.Values{ + "error": {"access_denied"}, + "error_description": {"user cancelled"}, + "state": {state}, + }.Encode()) + if err == nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + }() + return nil + } + + c := &Client{ + ClientID: "client-abc", + AuthorizeBase: "https://example.test/oauth/authorize", + TokenURL: tokenServer.URL, + BrowserOpener: opener, + Stderr: io.Discard, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := c.Login(ctx) + var ae *AuthError + if !assert.ErrorAs(t, err, &ae) { + return + } + assert.Equal(t, "access_denied", ae.Code) +} + +func TestClientLoginRefusesEmptyClientID(t *testing.T) { + c := &Client{ + ClientID: "", + AuthorizeBase: "https://example.test/oauth/authorize", + TokenURL: "https://example.test/v2/oauth/access_token", + BrowserOpener: func(string) error { return nil }, + Stderr: io.Discard, + } + _, err := c.Login(context.Background()) + assert.ErrorIs(t, err, ErrNotProvisioned) +} + +func TestClientLoginPrintsURLWhenBrowserCannotBeOpened(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"the-token","token_type":"bearer","scope":null,"account_id":981}`) + })) + defer tokenServer.Close() + + // Opener that errors. We still need a callback to fire, otherwise the + // listener times out. Fire it from inside the opener even on error. + opener := func(authURL string) error { + u, _ := url.Parse(authURL) + redirect := u.Query().Get("redirect_uri") + state := u.Query().Get("state") + go func() { + resp, _ := http.Get(redirect + "?" + url.Values{ + "code": {"fake-code"}, + "state": {state}, + }.Encode()) + if resp != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + }() + return fmt.Errorf("could not launch browser") + } + + var stderr bytes.Buffer + c := &Client{ + ClientID: "client-abc", + AuthorizeBase: "https://example.test/oauth/authorize", + TokenURL: tokenServer.URL, + BrowserOpener: opener, + Stderr: &stderr, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + token, err := c.Login(ctx) + if !assert.NoError(t, err) { + return + } + assert.Equal(t, "the-token", token) + assert.Contains(t, stderr.String(), "Could not open a browser") + assert.Contains(t, stderr.String(), "https://example.test/oauth/authorize") +} + +// TestClientLoginRefusesTokenEndpointRedirect pins the defense against a +// 307/308 from the token endpoint replaying the POST body -- containing +// the code and PKCE verifier -- to a redirect target. The redirect +// target server here intentionally records every request so the test +// can assert it is never reached. +func TestClientLoginRefusesTokenEndpointRedirect(t *testing.T) { + var redirectTargetHits int32 + redirectTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&redirectTargetHits, 1) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"would-have-been-leaked","token_type":"bearer"}`) + })) + defer redirectTarget.Close() + + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 307 preserves method + body on resend; this is the exact case + // the CheckRedirect guard protects against. + http.Redirect(w, r, redirectTarget.URL, http.StatusTemporaryRedirect) + })) + defer tokenServer.Close() + + c := &Client{ + ClientID: "client-abc", + AuthorizeBase: "https://example.test/oauth/authorize", + TokenURL: tokenServer.URL, + BrowserOpener: openerFollowingAuthorize(t, "fake-code", nil), + Stderr: io.Discard, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := c.Login(ctx) + if !assert.Error(t, err) { + return + } + assert.Contains(t, err.Error(), "unexpected redirect") + assert.Equal(t, int32(0), atomic.LoadInt32(&redirectTargetHits), + "redirect target must not receive the POST body") +} + +func TestClientLoginRejectsNonBearerTokenType(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"x","token_type":"mac","scope":null,"account_id":981}`) + })) + defer tokenServer.Close() + + c := &Client{ + ClientID: "client-abc", + AuthorizeBase: "https://example.test/oauth/authorize", + TokenURL: tokenServer.URL, + BrowserOpener: openerFollowingAuthorize(t, "fake-code", nil), + Stderr: io.Discard, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := c.Login(ctx) + if !assert.Error(t, err) { + return + } + assert.Contains(t, err.Error(), "unsupported token_type") + assert.Contains(t, err.Error(), "mac") +} + +// TestClientLoginAcceptsBearerCaseInsensitively pins the case-insensitive +// match required by RFC 6749 §5.1 ("the token type is a string ... +// case insensitive"). +func TestClientLoginAcceptsBearerCaseInsensitively(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"the-token","token_type":"Bearer","scope":null,"account_id":981}`) + })) + defer tokenServer.Close() + + c := &Client{ + ClientID: "client-abc", + AuthorizeBase: "https://example.test/oauth/authorize", + TokenURL: tokenServer.URL, + BrowserOpener: openerFollowingAuthorize(t, "fake-code", nil), + Stderr: io.Discard, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + token, err := c.Login(ctx) + if !assert.NoError(t, err) { + return + } + assert.Equal(t, "the-token", token) +} + +// TestClientLoginPropagatesContextCancellation pins that a cancelled +// parent context produces a context.Canceled error, NOT the "timed out" +// message used for deadline expiry. Callers can errors.Is on the +// result to decide whether to surface the cancellation or treat it as +// a transient failure. +func TestClientLoginPropagatesContextCancellation(t *testing.T) { + c := &Client{ + ClientID: "client-abc", + AuthorizeBase: "https://example.test/oauth/authorize", + TokenURL: "https://example.test/v2/oauth/access_token", + // BrowserOpener that never delivers a callback; the deadline is + // generous so the cancellation, not the timeout, is the trigger. + BrowserOpener: func(string) error { return nil }, + Stderr: io.Discard, + Deadline: 30 * time.Second, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel before Login even runs + _, err := c.Login(ctx) + if !assert.Error(t, err) { + return + } + assert.NotContains(t, err.Error(), "timed out", + "cancellation must not be reported as a timeout") + assert.ErrorIs(t, err, context.Canceled) +} + +func TestClientLoginTimesOutWhenNoCallback(t *testing.T) { + c := &Client{ + ClientID: "client-abc", + AuthorizeBase: "https://example.test/oauth/authorize", + TokenURL: "https://example.test/v2/oauth/access_token", + // Opener that does nothing: no callback ever fires. + BrowserOpener: func(string) error { return nil }, + Stderr: io.Discard, + Deadline: 150 * time.Millisecond, + } + _, err := c.Login(context.Background()) + if !assert.Error(t, err) { + return + } + assert.Contains(t, err.Error(), "timed out") +} diff --git a/internal/oauth/errors.go b/internal/oauth/errors.go new file mode 100644 index 0000000..73e96fc --- /dev/null +++ b/internal/oauth/errors.go @@ -0,0 +1,76 @@ +// Package oauth implements the interactive OAuth 2.0 Authorization Code +// flow with PKCE (RFC 7636) and a loopback redirect (RFC 8252 §7.3) for +// `dnsimple auth login`. +// +// The flow is: +// +// 1. Generate a PKCE verifier and SHA-256 challenge. +// 2. Generate a random state. +// 3. Bind 127.0.0.1 on an OS-assigned port and listen for the callback. +// 4. Open the user's browser to the authorize URL. +// 5. Receive ?code=…&state=… on the loopback, validate state, exchange +// code + verifier at the token endpoint for an access token. +// +// The token endpoint accepts JSON and does not require a client_secret +// for public clients (PKCE replaces the secret). +package oauth + +import "errors" + +// ErrNotProvisioned is returned by Login when the caller did not supply a +// client ID. Callers (typically the CLI `auth login` command) treat this as +// "OAuth is not enabled for this build", and fall back to the manual token +// paste prompt. +var ErrNotProvisioned = errors.New("oauth client id not configured") + +// ErrStateMismatch is returned when the state parameter on the callback +// does not match the one generated for the request. It indicates either +// a stale browser tab or an attempt to inject a forged callback. +var ErrStateMismatch = errors.New("oauth: state mismatch on callback") + +// AuthError carries an RFC 6749 §4.1.2.1 or §5.2 error response from the +// authorization server. The CLI surfaces both Code and Description so the +// user sees a useful message ("access_denied: user cancelled the request"). +// +// Construct via newAuthError so attacker-controlled bytes are stripped +// before they can reach the terminal via Error(). +type AuthError struct { + Code string + Description string +} + +// newAuthError builds an AuthError, stripping C0 control bytes and DEL +// from both fields. The values flow from either an OAuth-server JSON +// body or a query-string parameter on the loopback redirect, and either +// source can carry attacker-controlled content -- left raw they would +// emit ANSI escape sequences when the CLI prints err.Error() to the +// terminal. The HTML error page goes through htmlEscape; this is the +// terminal-side equivalent. +func newAuthError(code, description string) *AuthError { + return &AuthError{ + Code: sanitizeForTerminal(code), + Description: sanitizeForTerminal(description), + } +} + +func (e *AuthError) Error() string { + if e.Description != "" { + return e.Code + ": " + e.Description + } + return e.Code +} + +// sanitizeForTerminal drops every C0 control byte (0x00-0x1F) and DEL +// (0x7F) from s. UTF-8 continuation and lead bytes are all >= 0x80, so +// non-ASCII text passes through unchanged. +func sanitizeForTerminal(s string) string { + out := make([]byte, 0, len(s)) + for i := 0; i < len(s); i++ { + b := s[i] + if b < 0x20 || b == 0x7f { + continue + } + out = append(out, b) + } + return string(out) +} diff --git a/internal/oauth/errors_test.go b/internal/oauth/errors_test.go new file mode 100644 index 0000000..09baaaa --- /dev/null +++ b/internal/oauth/errors_test.go @@ -0,0 +1,45 @@ +package oauth + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSanitizeForTerminalStripsC0AndDEL(t *testing.T) { + cases := []struct { + name string + in string + want string + }{ + {"ansi color", "\x1b[31mred\x1b[0m", "[31mred[0m"}, + {"bel and cr", "alert\x07\rback", "alertback"}, + {"del", "abc\x7fdef", "abcdef"}, + {"plain ascii passes through", "no escapes here", "no escapes here"}, + {"utf-8 multibyte preserved", "café — résumé", "café — résumé"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, sanitizeForTerminal(tc.in)) + }) + } +} + +// TestNewAuthErrorStripsControlBytes guards the boundary where attacker- +// controlled query-string / JSON content lands in AuthError. Error() must +// not expose any byte that a terminal would interpret as a control or +// escape sequence. +func TestNewAuthErrorStripsControlBytes(t *testing.T) { + err := newAuthError("\x1b[31mFAKE\x07", "evil\x1b[2Jcleared") + assert.Equal(t, "[31mFAKE", err.Code) + assert.Equal(t, "evil[2Jcleared", err.Description) + + msg := err.Error() + for _, b := range []byte(msg) { + assert.Falsef(t, b < 0x20 || b == 0x7f, "AuthError.Error() leaked control byte 0x%02x", b) + } + // And the human-readable join still works. + assert.True(t, strings.Contains(msg, "FAKE")) + assert.True(t, strings.Contains(msg, "cleared")) +} diff --git a/internal/oauth/listener.go b/internal/oauth/listener.go new file mode 100644 index 0000000..5ff11bd --- /dev/null +++ b/internal/oauth/listener.go @@ -0,0 +1,251 @@ +package oauth + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "sync" + "time" +) + +// Loopback server timeouts. The loopback only ever serves one fast GET +// from the local browser, so the limits are deliberately tight: a stalled +// connection (security scanner, slowloris-style local probe, browser +// pre-fetcher that opens but never sends) would otherwise pin the +// listener until the outer 5-minute flow deadline fires, and would block +// Shutdown indefinitely on close. +const ( + loopbackReadHeaderTimeout = 5 * time.Second + loopbackReadTimeout = 10 * time.Second + loopbackWriteTimeout = 10 * time.Second + loopbackShutdownTimeout = 5 * time.Second +) + +// callbackResult is the outcome of the OAuth redirect: either a code + state +// or an authorization-server error. State is included so the caller (Client) +// can perform a defense-in-depth check, even though the listener also +// validates it before sending the result. +type callbackResult struct { + code string + state string + err error +} + +// loopback owns the local HTTP listener that catches the OAuth redirect. +// One loopback corresponds to one login attempt: it accepts a single valid +// callback and is then shut down. +type loopback struct { + listener net.Listener + server *http.Server + port int + redirectURL string + expectedState string + + result chan callbackResult + once sync.Once +} + +// startLoopback binds 127.0.0.1 on an OS-assigned port and serves the OAuth +// redirect endpoint at /callback. The handler is single-shot: the first +// valid callback (or the first error) is delivered to the result channel, +// and subsequent requests are ignored. +// +// The expectedState parameter is the state value we sent on the authorize +// step; it is validated against the callback's `state` query parameter. +// A mismatch is reported via the result channel as ErrStateMismatch. +func startLoopback(expectedState string) (*loopback, error) { + // 127.0.0.1 (not 0.0.0.0, not "localhost"). "localhost" can resolve to + // ::1 on systems that prefer IPv6, which would break the redirect + // matcher on the server side: the registered URI is the IPv4 literal. + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, fmt.Errorf("bind loopback listener: %w", err) + } + + port := listener.Addr().(*net.TCPAddr).Port + lb := &loopback{ + listener: listener, + port: port, + redirectURL: fmt.Sprintf("http://127.0.0.1:%d/callback", port), + expectedState: expectedState, + result: make(chan callbackResult, 1), + } + + mux := http.NewServeMux() + mux.HandleFunc("/callback", lb.handleCallback) + // Any other path is not part of the OAuth dance. Don't reveal that this + // is a CLI listener; just 404. + mux.HandleFunc("/", http.NotFound) + + lb.server = &http.Server{ + Handler: mux, + ReadHeaderTimeout: loopbackReadHeaderTimeout, + ReadTimeout: loopbackReadTimeout, + WriteTimeout: loopbackWriteTimeout, + } + go func() { + // Serve returns http.ErrServerClosed on shutdown, which is the + // normal exit path. Anything else means the listener died before + // we got a callback (e.g., FD exhaustion, the listener was + // closed externally) -- surface that via the result channel so + // await fails fast instead of timing out at the outer deadline. + // deliver() is single-shot, so if a real callback was already + // delivered this is a no-op. + if err := lb.server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + lb.deliver(callbackResult{err: fmt.Errorf("oauth: loopback listener died: %w", err)}) + } + }() + + return lb, nil +} + +// await blocks until the callback fires or ctx is cancelled / times out. +func (l *loopback) await(ctx context.Context) (string, string, error) { + select { + case r := <-l.result: + return r.code, r.state, r.err + case <-ctx.Done(): + return "", "", ctx.Err() + } +} + +// close shuts down the listener. Safe to call multiple times. +func (l *loopback) close() { + // Bounded shutdown deadline so a hung in-flight handler cannot pin + // the CLI after the OAuth flow has otherwise completed (e.g. a local + // connection that never sends bytes -- the request-level read + // timeouts cover the read side, but Shutdown still waits for those + // timeouts to fire and we don't want to inherit their full duration + // when we already have what we need). + ctx, cancel := context.WithTimeout(context.Background(), loopbackShutdownTimeout) + defer cancel() + _ = l.server.Shutdown(ctx) +} + +// handleCallback parses the OAuth redirect query, renders a small HTML +// status page so the browser tab is not left blank, and pushes exactly one +// result into the result channel. +func (l *loopback) handleCallback(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.NotFound(w, r) + return + } + + q := r.URL.Query() + errCode := q.Get("error") + errDesc := q.Get("error_description") + code := q.Get("code") + state := q.Get("state") + + // Validate state before branching on anything else, including the + // `error` response shape. Otherwise a forged callback (e.g. an + // on a page the user is also viewing) can deliver an + // attacker-chosen ?error=... result to the in-flight login without + // holding a valid state token -- exactly the attack the state + // parameter exists to defend against (RFC 6749 §10.12, OAuth 2.0 + // Security BCP §4.7). + if state != l.expectedState { + renderError(w, "state_mismatch", "The login flow could not be verified. Please run `dnsimple auth login` again.") + l.deliver(callbackResult{err: ErrStateMismatch}) + return + } + + switch { + case errCode != "": + renderError(w, errCode, errDesc) + l.deliver(callbackResult{err: newAuthError(errCode, errDesc)}) + case code == "": + renderError(w, "invalid_callback", "The authorization response did not include a code.") + l.deliver(callbackResult{err: fmt.Errorf("oauth: callback missing code")}) + default: + renderSuccess(w) + l.deliver(callbackResult{code: code, state: state}) + } +} + +// deliver sends a result exactly once, ignoring duplicates if the browser +// fires multiple requests (e.g. through a prefetcher). +func (l *loopback) deliver(r callbackResult) { + l.once.Do(func() { l.result <- r }) +} + +// renderSuccess writes a minimal HTML page telling the user the CLI now has +// what it needs and the tab can be closed. +func renderSuccess(w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(w, successHTML) +} + +// renderError writes a minimal HTML error page including the OAuth error +// code and description from the authorization server. +func renderError(w http.ResponseWriter, code, description string) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintf(w, errorHTMLFmt, htmlEscape(code), htmlEscape(description)) +} + +// htmlEscape is a tiny inline escaper to avoid a dependency on html/template +// for these two short pages. It covers the characters that matter when the +// content is OAuth error codes and free-form descriptions. +func htmlEscape(s string) string { + r := make([]byte, 0, len(s)) + for i := 0; i < len(s); i++ { + switch s[i] { + case '&': + r = append(r, []byte("&")...) + case '<': + r = append(r, []byte("<")...) + case '>': + r = append(r, []byte(">")...) + case '"': + r = append(r, []byte(""")...) + case '\'': + r = append(r, []byte("'")...) + default: + r = append(r, s[i]) + } + } + return string(r) +} + +const successHTML = ` + + + +DNSimple CLI: signed in + + + +

You are signed in to the DNSimple CLI.

+

You can close this tab and return to your terminal.

+ + +` + +const errorHTMLFmt = ` + + + +DNSimple CLI: login failed + + + +

DNSimple CLI login did not complete.

+

Error: %s

+

%s

+

Return to your terminal for further details and try again.

+ + +` diff --git a/internal/oauth/listener_test.go b/internal/oauth/listener_test.go new file mode 100644 index 0000000..a650b71 --- /dev/null +++ b/internal/oauth/listener_test.go @@ -0,0 +1,222 @@ +package oauth + +import ( + "context" + "errors" + "io" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestLoopbackHappyPath(t *testing.T) { + lb, err := startLoopback("expected-state") + if !assert.NoError(t, err) { + return + } + defer lb.close() + + go func() { + resp, err := http.Get(lb.redirectURL + "?code=abc&state=expected-state") + if err == nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + code, state, err := lb.await(ctx) + if !assert.NoError(t, err) { + return + } + assert.Equal(t, "abc", code) + assert.Equal(t, "expected-state", state) +} + +func TestLoopbackReturnsAuthErrorWhenErrorParamPresent(t *testing.T) { + lb, err := startLoopback("expected-state") + if !assert.NoError(t, err) { + return + } + defer lb.close() + + go func() { + u := lb.redirectURL + "?" + url.Values{ + "error": {"access_denied"}, + "error_description": {"user said no"}, + "state": {"expected-state"}, + }.Encode() + resp, err := http.Get(u) + if err == nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, _, err = lb.await(ctx) + var ae *AuthError + if !assert.ErrorAs(t, err, &ae) { + return + } + assert.Equal(t, "access_denied", ae.Code) + assert.Equal(t, "user said no", ae.Description) +} + +func TestLoopbackRejectsStateMismatch(t *testing.T) { + lb, err := startLoopback("expected-state") + if !assert.NoError(t, err) { + return + } + defer lb.close() + + go func() { + resp, err := http.Get(lb.redirectURL + "?code=abc&state=wrong-state") + if err == nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, _, err = lb.await(ctx) + assert.ErrorIs(t, err, ErrStateMismatch) +} + +// TestLoopbackRejectsForgedErrorWithoutState pins the fix for the CSRF +// vulnerability where a forged ?error=... callback (with no state, or a +// wrong state) was accepted because the error branch ran before the +// state check. State must be validated first, even on the error path. +func TestLoopbackRejectsForgedErrorWithoutState(t *testing.T) { + cases := []struct { + name string + query string + }{ + {name: "no state", query: "?error=access_denied&error_description=spoof"}, + {name: "wrong state", query: "?error=access_denied&error_description=spoof&state=forged"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + lb, err := startLoopback("expected-state") + if !assert.NoError(t, err) { + return + } + defer lb.close() + + go func() { + resp, err := http.Get(lb.redirectURL + tc.query) + if err == nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, _, err = lb.await(ctx) + assert.ErrorIs(t, err, ErrStateMismatch) + + // And specifically, the AuthError path must NOT be reached + // for these forged inputs. + var ae *AuthError + assert.False(t, errors.As(err, &ae), "forged error must not surface as AuthError") + }) + } +} + +// TestLoopbackSurfacesServeErrors verifies that a Serve failure (other +// than the normal ErrServerClosed) is delivered through the result +// channel instead of being silently dropped, so callers see the real +// cause rather than timing out at the outer 5-minute deadline. +func TestLoopbackSurfacesServeErrors(t *testing.T) { + lb, err := startLoopback("expected-state") + if !assert.NoError(t, err) { + return + } + + // Close the listener out from under Serve. Serve's Accept loop + // returns the wrapped "use of closed network connection" error, + // which is NOT http.ErrServerClosed. + _ = lb.listener.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, _, err = lb.await(ctx) + if !assert.Error(t, err) { + return + } + assert.Contains(t, err.Error(), "loopback listener died") +} + +func TestLoopbackTimesOutWhenNoCallback(t *testing.T) { + lb, err := startLoopback("expected-state") + if !assert.NoError(t, err) { + return + } + defer lb.close() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + _, _, err = lb.await(ctx) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) +} + +// TestLoopbackConfiguresHTTPServerTimeouts pins the slowloris-defense +// configuration on the embedded http.Server. If a future refactor drops +// any of these timeouts, a stalled local connection can hold the +// listener open and block Shutdown until the outer 5-minute flow +// deadline fires. +func TestLoopbackConfiguresHTTPServerTimeouts(t *testing.T) { + lb, err := startLoopback("expected-state") + if !assert.NoError(t, err) { + return + } + defer lb.close() + + assert.Greater(t, lb.server.ReadHeaderTimeout, time.Duration(0), "ReadHeaderTimeout must be set") + assert.Greater(t, lb.server.ReadTimeout, time.Duration(0), "ReadTimeout must be set") + assert.Greater(t, lb.server.WriteTimeout, time.Duration(0), "WriteTimeout must be set") +} + +func TestLoopbackIgnoresNonCallbackPaths(t *testing.T) { + lb, err := startLoopback("expected-state") + if !assert.NoError(t, err) { + return + } + defer lb.close() + + // A probe to /, /favicon.ico, /anything must not consume the single-shot + // result. The actual callback fired afterwards should be delivered. + for _, path := range []string{"/", "/favicon.ico", "/probe"} { + base := strings.TrimSuffix(lb.redirectURL, "/callback") + resp, err := http.Get(base + path) + if err == nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + } + + go func() { + resp, err := http.Get(lb.redirectURL + "?code=abc&state=expected-state") + if err == nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + code, _, err := lb.await(ctx) + if !assert.NoError(t, err) { + return + } + assert.Equal(t, "abc", code) +} diff --git a/internal/oauth/pkce.go b/internal/oauth/pkce.go new file mode 100644 index 0000000..1808e4f --- /dev/null +++ b/internal/oauth/pkce.go @@ -0,0 +1,30 @@ +package oauth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" +) + +// verifierBytes is the entropy size for the PKCE code verifier. 32 random +// bytes encode to 43 unreserved base64url characters, which sits at the low +// end of RFC 7636 §4.1's [43, 128] range and gives ~256 bits of entropy. +const verifierBytes = 32 + +// newVerifier returns a fresh PKCE code verifier per RFC 7636 §4.1: a +// base64url-encoded random byte string with no padding. +func newVerifier() (string, error) { + buf := make([]byte, verifierBytes) + if _, err := rand.Read(buf); err != nil { + return "", fmt.Errorf("read random bytes for verifier: %w", err) + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +// challenge returns the S256 code challenge for the given verifier per RFC +// 7636 §4.2: base64url(SHA-256(verifier)) with no padding. +func challenge(verifier string) string { + sum := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(sum[:]) +} diff --git a/internal/oauth/pkce_test.go b/internal/oauth/pkce_test.go new file mode 100644 index 0000000..c2e74f6 --- /dev/null +++ b/internal/oauth/pkce_test.go @@ -0,0 +1,66 @@ +package oauth + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestChallengeRFC7636Vector pins the S256 derivation against the published +// vector from RFC 7636 §4.6. Any drift here means the implementation no +// longer interoperates with the server's PKCE check. +func TestChallengeRFC7636Vector(t *testing.T) { + const ( + verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + ) + got := challengeFor(verifier) + assert.Equal(t, challenge, got) +} + +// challengeFor is a one-line trampoline so the test name reads naturally +// alongside the unexported `challenge` function. +func challengeFor(v string) string { return challenge(v) } + +func TestNewVerifierIsBase64URLAndLongEnough(t *testing.T) { + v, err := newVerifier() + if !assert.NoError(t, err) { + return + } + // 32 bytes base64url-no-padding-encoded → 43 characters. + assert.Len(t, v, 43) + // Only RFC 7636 §4.1 unreserved chars: ALPHA / DIGIT / "-" / "." / "_" / "~". + // base64url uses ALPHA / DIGIT / "-" / "_"; that's a strict subset. + for _, r := range v { + isAllowed := (r >= 'A' && r <= 'Z') || + (r >= 'a' && r <= 'z') || + (r >= '0' && r <= '9') || + r == '-' || r == '_' + assert.Truef(t, isAllowed, "verifier contains disallowed char %q", r) + } +} + +func TestNewVerifierIsRandom(t *testing.T) { + a, err := newVerifier() + if !assert.NoError(t, err) { + return + } + b, err := newVerifier() + if !assert.NoError(t, err) { + return + } + assert.NotEqual(t, a, b) +} + +func TestNewStateIsRandomAndBase64URL(t *testing.T) { + a, err := newState() + if !assert.NoError(t, err) { + return + } + b, err := newState() + if !assert.NoError(t, err) { + return + } + assert.NotEqual(t, a, b) + assert.Len(t, a, 43) +} diff --git a/internal/oauth/state.go b/internal/oauth/state.go new file mode 100644 index 0000000..10d1771 --- /dev/null +++ b/internal/oauth/state.go @@ -0,0 +1,22 @@ +package oauth + +import ( + "crypto/rand" + "encoding/base64" + "fmt" +) + +// stateBytes is the entropy size for the OAuth `state` parameter. 32 bytes +// (~256 bits) is well above the typical recommendation for CSRF protection. +const stateBytes = 32 + +// newState returns a fresh OAuth state token: a base64url-encoded random +// byte string with no padding. The CLI generates one per login attempt and +// validates it on the callback to detect forged or stale redirects. +func newState() (string, error) { + buf := make([]byte, stateBytes) + if _, err := rand.Read(buf); err != nil { + return "", fmt.Errorf("read random bytes for state: %w", err) + } + return base64.RawURLEncoding.EncodeToString(buf), nil +}