Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/url"
"os"

retry "github.com/appleboy/go-httpretry"
"github.com/go-authgate/cli/tui"
Expand Down Expand Up @@ -94,7 +95,7 @@ func refreshAccessToken(
}

if err := cfg.Store.Save(cfg.ClientID, *storage); err != nil {
fmt.Printf("Warning: Failed to save refreshed tokens: %v\n", err)
fmt.Fprintf(os.Stderr, "Warning: Failed to save refreshed tokens: %v\n", err)
}
return storage, nil
}
Expand Down
25 changes: 15 additions & 10 deletions browser.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,25 @@ import (
"runtime"
)

// openBrowser attempts to open url in the user's default browser.
// Returns an error if launching the browser fails, but callers should
// always print the URL as a fallback regardless of the error.
func openBrowser(ctx context.Context, url string) error {
var cmd *exec.Cmd

switch runtime.GOOS {
// browserCommand returns the executable name and arguments for opening a URL
// on the given OS. This is extracted for testability.
func browserCommand(goos, url string) (name string, args []string) {
switch goos {
case "darwin":
cmd = exec.CommandContext(ctx, "open", url)
return "open", []string{url}
case "windows":
cmd = exec.CommandContext(ctx, "cmd", "/c", "start", url)
return "rundll32", []string{"url.dll,FileProtocolHandler", url}
default:
cmd = exec.CommandContext(ctx, "xdg-open", url)
return "xdg-open", []string{url}
}
}

// openBrowser attempts to open url in the user's default browser.
// Returns an error if launching the browser fails, but callers should
// always print the URL as a fallback regardless of the error.
func openBrowser(ctx context.Context, url string) error {
name, args := browserCommand(runtime.GOOS, url)
cmd := exec.CommandContext(ctx, name, args...)

if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to open browser: %w", err)
Expand Down
57 changes: 57 additions & 0 deletions browser_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package main

import "testing"

func TestBrowserCommand(t *testing.T) {
tests := []struct {
goos string
url string
wantName string
wantArgs []string
}{
{
goos: "darwin",
url: "https://example.com/auth",
wantName: "open",
wantArgs: []string{"https://example.com/auth"},
},
{
goos: "windows",
url: "https://example.com/auth",
wantName: "rundll32",
wantArgs: []string{"url.dll,FileProtocolHandler", "https://example.com/auth"},
},
{
goos: "linux",
url: "https://example.com/auth",
wantName: "xdg-open",
wantArgs: []string{"https://example.com/auth"},
},
{
goos: "windows",
url: "https://example.com/auth?foo=bar&baz=qux",
wantName: "rundll32",
wantArgs: []string{
"url.dll,FileProtocolHandler",
"https://example.com/auth?foo=bar&baz=qux",
},
},
}

for _, tt := range tests {
t.Run(tt.goos, func(t *testing.T) {
name, args := browserCommand(tt.goos, tt.url)
if name != tt.wantName {
t.Errorf("name: got %q, want %q", name, tt.wantName)
}
if len(args) != len(tt.wantArgs) {
t.Fatalf("args length: got %d, want %d", len(args), len(tt.wantArgs))
}
for i, arg := range args {
if arg != tt.wantArgs[i] {
t.Errorf("args[%d]: got %q, want %q", i, arg, tt.wantArgs[i])
}
}
})
}
}
3 changes: 2 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ func loadConfig() *AppConfig {
var err error
cfg.RetryClient, err = retry.NewBackgroundClient(retry.WithHTTPClient(baseHTTPClient))
if err != nil {
panic(fmt.Sprintf("failed to create retry client: %v", err))
fmt.Fprintf(os.Stderr, "Error: failed to create retry HTTP client: %v\n", err)
os.Exit(1)
}

// Resolve timeout configuration.
Expand Down
22 changes: 20 additions & 2 deletions token_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,14 @@ func revokeTokenOnServer(

if tok.RefreshToken != "" {
wg.Go(func() {
if err := doRevoke(ctx, cfg, revokeURL, tok.RefreshToken, timeout); err != nil {
if err := doRevoke(
ctx,
cfg,
revokeURL,
tok.RefreshToken,
"refresh_token",
timeout,
); err != nil {
mu.Lock()
refreshErr = err
mu.Unlock()
Expand All @@ -182,7 +189,14 @@ func revokeTokenOnServer(

if tok.AccessToken != "" {
wg.Go(func() {
if err := doRevoke(ctx, cfg, revokeURL, tok.AccessToken, timeout); err != nil {
if err := doRevoke(
ctx,
cfg,
revokeURL,
tok.AccessToken,
"access_token",
timeout,
); err != nil {
mu.Lock()
accessErr = err
mu.Unlock()
Expand Down Expand Up @@ -213,6 +227,7 @@ func doRevoke(
cfg *AppConfig,
revokeURL string,
token string,
tokenTypeHint string,
timeout time.Duration,
) error {
ctx, cancel := context.WithTimeout(ctx, timeout)
Expand All @@ -222,6 +237,9 @@ func doRevoke(
"token": {token},
"client_id": {cfg.ClientID},
}
if tokenTypeHint != "" {
data.Set("token_type_hint", tokenTypeHint)
}
if !cfg.IsPublicClient() {
data.Set("client_secret", cfg.ClientSecret)
}
Expand Down
56 changes: 40 additions & 16 deletions token_cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ func TestRunTokenDelete(t *testing.T) {

func TestRunTokenDelete_ServerRevocation(t *testing.T) {
t.Run("successful revocation and local delete", func(t *testing.T) {
var revokedTokens []string
type revokeCall struct {
token string
tokenTypeHint string
}
var revokeCalls []revokeCall
var mu sync.Mutex
srv := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -115,7 +119,10 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) {
return
}
mu.Lock()
revokedTokens = append(revokedTokens, r.FormValue("token"))
revokeCalls = append(revokeCalls, revokeCall{
token: r.FormValue("token"),
tokenTypeHint: r.FormValue("token_type_hint"),
})
mu.Unlock()
w.WriteHeader(http.StatusOK)
}),
Expand Down Expand Up @@ -163,16 +170,24 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) {

mu.Lock()
defer mu.Unlock()
if len(revokedTokens) != 2 {
t.Fatalf("expected 2 revoke calls, got %d", len(revokedTokens))
if len(revokeCalls) != 2 {
t.Fatalf("expected 2 revoke calls, got %d", len(revokeCalls))
}
// Revocations run concurrently, so order is non-deterministic.
got := map[string]bool{revokedTokens[0]: true, revokedTokens[1]: true}
if !got["refresh-456"] {
t.Errorf("expected refresh token to be revoked, got %v", revokedTokens)
// Build a map from token to its type hint for assertion.
hintByToken := make(map[string]string, len(revokeCalls))
for _, c := range revokeCalls {
hintByToken[c.token] = c.tokenTypeHint
}
if hint, ok := hintByToken["refresh-456"]; !ok {
t.Errorf("expected refresh token to be revoked, got %v", revokeCalls)
} else if hint != "refresh_token" {
t.Errorf("refresh token_type_hint: got %q, want %q", hint, "refresh_token")
}
if !got["access-123"] {
t.Errorf("expected access token to be revoked, got %v", revokedTokens)
if hint, ok := hintByToken["access-123"]; !ok {
t.Errorf("expected access token to be revoked, got %v", revokeCalls)
} else if hint != "access_token" {
t.Errorf("access token_type_hint: got %q, want %q", hint, "access_token")
}
})

Expand Down Expand Up @@ -269,16 +284,22 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) {
})

t.Run("only access token no refresh token", func(t *testing.T) {
var revokedTokens []string
var mu sync.Mutex
var (
callCount int
gotToken string
gotTokenTypeHint string
mu sync.Mutex
)
srv := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "bad form", http.StatusBadRequest)
return
}
mu.Lock()
revokedTokens = append(revokedTokens, r.FormValue("token"))
callCount++
gotToken = r.FormValue("token")
gotTokenTypeHint = r.FormValue("token_type_hint")
mu.Unlock()
w.WriteHeader(http.StatusOK)
}),
Expand Down Expand Up @@ -319,11 +340,14 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) {

mu.Lock()
defer mu.Unlock()
if len(revokedTokens) != 1 {
t.Fatalf("expected 1 revoke call (access only), got %d", len(revokedTokens))
if callCount != 1 {
t.Fatalf("expected 1 revoke call (access only), got %d", callCount)
}
if gotToken != "access-only" {
t.Errorf("token: got %q, want %q", gotToken, "access-only")
}
if revokedTokens[0] != "access-only" {
t.Errorf("expected access token, got %q", revokedTokens[0])
if gotTokenTypeHint != "access_token" {
t.Errorf("token_type_hint: got %q, want %q", gotTokenTypeHint, "access_token")
}
})
}
Expand Down
Loading