Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 44 additions & 23 deletions pkg/oauth2ac/oauth2ac.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,11 @@ type credentialsProvider struct {
clientSecret string
secretError error

statesMu sync.RWMutex
authStates map[string]authState
statesMu sync.RWMutex
authStates map[string]authState
authStateTTL time.Duration
authStateCleanupInterval time.Duration
reauthorizeIfAuthorized bool

syncGate *util.SyncGate

Expand Down Expand Up @@ -190,12 +193,12 @@ type TokenStorage interface {
// - logger: Logger for logging credential operations
//
// Returns a credential provider and an error if creation fails
func NewCredentialsProvider(id string, cache cache.Cache, tokenStore TokenStorage, cfg *CredentialsConfig, logger logr.Logger) (*credentialsProvider, error) {
func NewCredentialsProvider(id string, cache cache.Cache, tokenStore TokenStorage, cfg *CredentialsConfig, logger logr.Logger, opts ...option.Option) (*credentialsProvider, error) {
if err := validateConfig(cfg); err != nil {
return nil, err
}

return &credentialsProvider{
cp := &credentialsProvider{
id: id,
cache: cache,
tokenStore: tokenStore,
Expand All @@ -209,7 +212,19 @@ func NewCredentialsProvider(id string, cache cache.Cache, tokenStore TokenStorag
authStates: map[string]authState{},
syncGate: util.NewSyncGate(),
pipe: eventbus.NewPipe[Credential](),
}, nil

authStateTTL: time.Minute * 5,
authStateCleanupInterval: time.Second * 5,
reauthorizeIfAuthorized: true,
}

for _, o := range opts {
if o, ok := isCredentialsProviderOption(o); ok {
o.Apply(cp)
}
}

return cp, nil
}

func (cp *credentialsProvider) ID() string {
Expand All @@ -233,9 +248,7 @@ func (cp *credentialsProvider) Start(ctx context.Context) (<-chan StatusEvent, e
return nil, errors.WithStack(ErrAlreadyRunning)
}

if err := cp.authorizeIfTokenExists(ctx); err != nil {
return nil, err
}
cp.authorizeIfTokenExists(ctx)

handlerStopFunc, err := cp.startInformer(ctx)
if err != nil {
Expand All @@ -245,7 +258,7 @@ func (cp *credentialsProvider) Start(ctx context.Context) (<-chan StatusEvent, e
cp.running.Store(true)

// periodic cleanup of expired states
go util.RunFuncAtInterval(ctx, time.Second*5, func(_ context.Context) error {
go util.RunFuncAtInterval(ctx, cp.authStateCleanupInterval, func(_ context.Context) error { //nolint:unparam
cp.cleanupExpiredStates()

return nil
Expand All @@ -270,7 +283,7 @@ func (cp *credentialsProvider) Start(ctx context.Context) (<-chan StatusEvent, e
}

func (cp *credentialsProvider) Authorize(ctx context.Context, authState, code string) (*oauth2.Token, error) {
if cp.syncGate.IsOpen() {
if !cp.reauthorizeIfAuthorized && cp.syncGate.IsOpen() {
return nil, errors.WithStack(ErrAlreadyAuthorized)
}

Expand All @@ -291,6 +304,10 @@ func (cp *credentialsProvider) Authorize(ctx context.Context, authState, code st
return nil, err
}

if cp.syncGate.IsOpen() {
cp.signalRefresh()
}

return token, nil
}

Expand Down Expand Up @@ -409,7 +426,7 @@ func (cp *credentialsProvider) storeTokenAndAuthorize(ctx context.Context, token
return nil
}

func (cp *credentialsProvider) authorizeIfTokenExists(ctx context.Context) error {
func (cp *credentialsProvider) authorizeIfTokenExists(ctx context.Context) {
token, err := cp.tokenStore.Get(ctx, cp.id)
if err == nil && token.RefreshToken != "" {
cp.logger.Info("token exists")
Expand All @@ -418,8 +435,6 @@ func (cp *credentialsProvider) authorizeIfTokenExists(ctx context.Context) error
} else {
cp.setUnauthorizedStatus(errors.WithStack(ErrAuthorizationNeeded))
}

return nil
}

// validateConfig validates the configuration and returns an error if any required field is missing.
Expand Down Expand Up @@ -547,24 +562,30 @@ func (cp *credentialsProvider) tokenRefresherLoop(ctx context.Context) {

select {
case <-cp.refreshCh:
cp.logger.Info("client credentials changed: reset init condition and wait for re-authorization")

cp.mu.RLock()
serr := cp.secretError
cp.mu.RUnlock()

err := errors.WithStack(ErrAuthorizationNeeded)
if serr != nil {
err = errors.Wrap(err, serr.Error())
cp.logger.Info("client credentials changed: reset init condition and wait for re-authorization")

err := errors.Wrap(errors.WithStack(ErrAuthorizationNeeded), serr.Error())
cp.publishCredential(Credential{
Event: credential.RemoveEventType,
Err: err,
})

cp.setUnauthorizedStatus(err)

continue
}

cp.logger.Info("re-authorization: replacing token")
cp.publishCredential(Credential{
Event: credential.RemoveEventType,
Err: err,
Err: errors.WithStack(ErrAuthorizationNeeded),
})

cp.setUnauthorizedStatus(err)

continue
refreshTime = 0
case <-ctx.Done():
return
case <-time.After(refreshTime):
Expand Down Expand Up @@ -736,7 +757,7 @@ func (cp *credentialsProvider) cleanupExpiredStates() {

cp.statesMu.RLock()
for k, authState := range cp.authStates {
if authState.issuedAt.Before(time.Now().Add(-time.Second * 60)) {
if authState.issuedAt.Before(time.Now().Add(-cp.authStateTTL)) {
candidates = append(candidates, k)
}
}
Expand Down
63 changes: 63 additions & 0 deletions pkg/oauth2ac/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright (c) 2026 Riptides Labs, Inc.
// SPDX-License-Identifier: MIT

package oauth2ac

import (
"time"

"go.riptides.io/tokenex/pkg/option"
)

// CredentialsProviderOption is a function that modifies the credentialsProvider.
type (
CredentialsProviderOption interface {
Apply(*credentialsProvider)
}
credentialsProviderOption struct {
option.Option

f func(*credentialsProvider)
}
)

func (o *credentialsProviderOption) Apply(cp *credentialsProvider) {
o.f(cp)
}

func withCredentialsProviderOption(f func(*credentialsProvider)) option.Option {
return &credentialsProviderOption{option.OptionImpl{}, f}
}

func isCredentialsProviderOption(opt any) (CredentialsProviderOption, bool) {
if o, ok := opt.(*credentialsProviderOption); ok {
return o, ok
}

return nil, false
}

// WithReauthorizeIfAuthorized controls whether completing a new auth flow signals a token
// refresh when the provider is already authorized (i.e. has a valid token).
// Defaults to true.
func WithReauthorizeIfAuthorized(v bool) option.Option {
return withCredentialsProviderOption(func(cp *credentialsProvider) {
cp.reauthorizeIfAuthorized = v
})
}

// WithAuthStateTTL sets how long an auth state is kept before being considered expired.
// Defaults to 5 minutes.
func WithAuthStateTTL(ttl time.Duration) option.Option {
return withCredentialsProviderOption(func(cp *credentialsProvider) {
cp.authStateTTL = ttl
})
}

// WithAuthStateCleanupInterval sets how often expired auth states are swept.
// Defaults to 5 seconds.
func WithAuthStateCleanupInterval(d time.Duration) option.Option {
return withCredentialsProviderOption(func(cp *credentialsProvider) {
cp.authStateCleanupInterval = d
})
}
Loading