diff --git a/README.md b/README.md index bdb579e..a2b4287 100644 --- a/README.md +++ b/README.md @@ -1,290 +1,554 @@ # goauth -Pluggable authentication for Go. Build username/password or JWT-based auth with a simple strategy interface, a configurable JWT access-token + refresh-token issuer, and typed errors for clean, predictable error handling. +[![Go Reference](https://pkg.go.dev/badge/github.com/openframebox/goauth/v2.svg)](https://pkg.go.dev/github.com/openframebox/goauth/v2) + +Pluggable authentication for Go. Build username/password or JWT-based auth with a simple strategy interface, configurable JWT access/refresh token issuers with multi-session support, and typed errors for clean, predictable error handling. Works as an auth core you can drop into HTTP APIs, gRPC, or CLIs. +## Version + +**Current: v2.0.0** + +This is a major version with breaking changes from v1. See [Migration from v1](#migration-from-v1) for upgrade guide. + ## Installation +```bash +go get github.com/openframebox/goauth/v2 ``` + +For v1 (legacy): + +```bash go get github.com/openframebox/goauth ``` +## Features + +- **Multi-session support** - Users can have multiple active sessions (e.g., phone + laptop) +- **Token rotation** - Proper refresh token rotation with old token invalidation +- **Multiple signing algorithms** - HS256/384/512, RS256/384/512, ES256/384/512 +- **Event hooks** - `OnBeforeAuthenticate`, `OnAfterAuthenticate`, `OnTokenIssued`, `OnTokenRevoked` +- **Rate limiting** - Built-in interfaces for rate limiting strategies +- **Password validation** - Optional bcrypt/argon2 integration +- **Thread-safe** - Safe for concurrent use +- **Typed errors** - Categorized errors for consistent HTTP responses + ## Concepts -- Strategy: pluggable auth mechanism that returns a user (e.g., Local, JWT, OAuth, SSO). Implement `Name()` and `Authenticate()`. -- Authenticatable: minimal user interface (`GetID`, `GetUsername`, `GetEmail`, `GetExtra`). A default `User` is included. -- TokenIssuer: service that creates/verifies access tokens and manages refresh tokens. `DefaultTokenIssuer` uses HS256 JWT for access tokens and UUIDs for refresh tokens. -- Typed Errors: errors are categorized (`CredentialError`, `TokenError`, `ConfigError`, `NotFoundError`, `InternalError`) so callers can map responses consistently. +- **Strategy**: pluggable auth mechanism (Local, JWT, OAuth, SSO). Implement `Name()` and `Authenticate()`. +- **Authenticatable**: minimal user interface (`GetID`, `GetUsername`, `GetEmail`, `GetExtra`). +- **TokenIssuer**: creates/verifies access tokens and manages refresh tokens. + - `DefaultTokenIssuer`: basic HS256 JWT issuer + - `SessionTokenIssuer`: multi-session aware issuer with configurable signing +- **SessionInfo**: session metadata (ID, device, IP, expiry) for multi-session support +- **Typed Errors**: `CredentialError`, `TokenError`, `ConfigError`, `NotFoundError`, `InternalError`, `RateLimitError`, `ValidationError`, `SessionError` ## Quick Start -See `example/main.go` for a runnable demo of login → JWT auth → refresh: - ``` -go run ./example +go run ./example # Basic multi-session demo +go run ./example/http_server # HTTP server example ``` -## Initialization +## Basic Setup (DefaultTokenIssuer) -Create a token issuer and the orchestrator, then register strategies: +For simple use cases without multi-session support: ```go package main import ( "context" - goauth "github.com/openframebox/goauth" + goauth "github.com/openframebox/goauth/v2" ) func setup() *goauth.GoAuth { - // 1) Configure token issuer + // Configure token issuer ti := goauth.NewDefaultTokenIssuer("supersecret") ti.SetIssuer("api.example.com") ti.SetAudience([]string{"api.example.com"}) - // Required for refresh tokens: where/how to store them - ti.StoreRefreshTokenWith(func(ctx context.Context, a goauth.Authenticatable, tok *goauth.Token, refreshing bool) error { - // persist tok.Value with user a.GetID() and rotation behavior + // Required: refresh token storage + ti.StoreRefreshTokenWith(func(ctx context.Context, a goauth.Authenticatable, tok *goauth.Token, oldToken *string) error { + // oldToken is nil for initial login, non-nil for refresh (rotation) + if oldToken != nil { + // Invalidate the old token + } + // Store tok.Value with user a.GetID() return nil }) + ti.ValidateRefreshTokenWith(func(ctx context.Context, token string) (goauth.Authenticatable, error) { - // lookup token → user; return &goauth.TokenError on invalid + // Lookup token -> user; return error if invalid return &goauth.User{ID: "user-123"}, nil }) - // Optional: attach extra claims to access JWTs - ti.SetExtraClaimsWith(func(ctx context.Context, a goauth.Authenticatable) map[string]any { - return map[string]any{"role": "admin"} + ti.RevokeRefreshTokenWith(func(ctx context.Context, token string) error { + // Delete the token from storage + return nil }) - // Optional: convert JWT claims → full user (by default: ID/Username/Email) - // ti.ConvertAccessTokenClaimsWith(func(ctx context.Context, c *goauth.TokenClaims) (goauth.Authenticatable, error) { - // return &goauth.User{ID: c.Subject}, nil - // }) - - // 2) Build the orchestrator and register strategies + // Build orchestrator ga := goauth.New() ga.SetTokenIssuer(ti) - // Local username/password strategy - ga.RegisterStrategy(&goauth.LocalStrategy{LookupUserWith: func(ctx context.Context, p goauth.AuthParams) (goauth.Authenticatable, error) { - // validate p.UsernameOrEmail + p.Password - // return &goauth.CredentialError on invalid creds + // Register strategies using builder pattern + ga.RegisterStrategy(goauth.NewLocalStrategy(func(ctx context.Context, p goauth.AuthParams) (goauth.Authenticatable, error) { + // Validate credentials return &goauth.User{ID: "user-" + p.UsernameOrEmail, Username: p.UsernameOrEmail}, nil - }}) + })) + + ga.RegisterStrategy(goauth.NewJWTStrategy(ti).WithExpectedType(goauth.AccessToken)) + + return ga +} +``` + +## Multi-Session Setup (SessionTokenIssuer) + +For apps that need multi-device login, session management, and advanced signing: + +```go +package main + +import ( + "context" + "time" + goauth "github.com/openframebox/goauth/v2" +) + +func setup() *goauth.GoAuth { + // Create key provider (supports HS256/384/512, RS256/384/512, ES256/384/512) + keyProvider, _ := goauth.NewHMACKeyProvider([]byte("supersecret"), goauth.HS256) + + // Build session-aware token issuer + issuer, _ := goauth.NewSessionAwareTokenIssuer(). + WithKeyProvider(keyProvider). + WithIssuer("api.example.com"). + WithAudience([]string{"api.example.com"}). + WithAccessTokenTTL(15 * time.Minute). + WithRefreshTokenTTL(7 * 24 * time.Hour). + WithSessionStore( + storeSession, // Store session + token + validateSession, // Validate token -> user + session + revokeSession, // Revoke single session + revokeAllSessions, // Revoke all user sessions + ). + WithListSessions(listSessions). + WithGetSession(getSession). + WithSessionMetadataExtractor(func(ctx context.Context) map[string]any { + // Extract device info, IP, user agent from context + return map[string]any{"device": "browser", "ip": "127.0.0.1"} + }). + Build() - // JWT bearer token strategy - ga.RegisterStrategy(&goauth.JWTStrategy{TokenIssuer: ti}) + ga := goauth.New() + ga.SetTokenIssuer(issuer) - // Option A (overwrite allowed): - ga.RegisterSingleton() + // Register strategies + ga.RegisterStrategy(goauth.NewLocalStrategy(lookupUser)) + ga.RegisterStrategy(goauth.NewJWTStrategy(issuer).WithExpectedType(goauth.AccessToken)) - // Option B (set once): - // if err := ga.RegisterSingletonOnce(); err != nil { panic(err) } return ga } + +// Session store callbacks +func storeSession(ctx context.Context, auth goauth.Authenticatable, session *goauth.SessionInfo, token *goauth.Token, oldToken *string) error { + // If oldToken != nil, invalidate it (rotation) + // Store session with token + return nil +} + +func validateSession(ctx context.Context, token string) (goauth.Authenticatable, *goauth.SessionInfo, error) { + // Lookup token -> user + session + return user, session, nil +} + +func revokeSession(ctx context.Context, auth goauth.Authenticatable, sessionID string) error { + // Delete session by ID + return nil +} + +func revokeAllSessions(ctx context.Context, auth goauth.Authenticatable) error { + // Delete all sessions for user + return nil +} + +func listSessions(ctx context.Context, auth goauth.Authenticatable) ([]*goauth.SessionInfo, error) { + // Return all active sessions for user + return sessions, nil +} + +func getSession(ctx context.Context, token string) (*goauth.SessionInfo, error) { + // Get session info by token + return session, nil +} ``` +## Choosing a Token Issuer + +| Feature | DefaultTokenIssuer | SessionTokenIssuer | +| ----------------------- | -------------------- | ---------------------------------------------------- | +| **Signing algorithms** | HS256 only | HS256/384/512, RS256/384/512, ES256/384/512 | +| **Multi-device login** | No session isolation | Each device = unique session | +| **Session management** | None | `ListSessions`, `RevokeSession`, `RevokeAllSessions` | +| **JWT `sid` claim** | Not included | Session ID embedded in access token | +| **Session metadata** | None | Device, IP, user agent tracking | +| **Configuration style** | Setter methods | Builder pattern | +| **Storage callbacks** | Token-centric | Session-centric | + +**Use `DefaultTokenIssuer` when:** + +- Simple single-session apps +- You only need basic JWT with HS256 +- You manage token storage yourself without session semantics + +**Use `SessionTokenIssuer` when:** + +- Users log in from multiple devices (phone + laptop) +- You need "see all active sessions" or "logout all devices" features +- You want flexible signing algorithms (RSA, ECDSA) +- You need session metadata (device info, IP tracking) + ## Core Flows -### 1) Username/Password Login → Tokens +### 1) Login and Issue Tokens ```go +// Returns individual tokens res, access, refresh, err := ga.AuthenticateAndIssueTokens(ctx, "local", goauth.AuthParams{ UsernameOrEmail: "alice", Password: "s3cret", }) -// res.Authenticatable → user, access.Value → JWT, refresh.Value → UUID + +// Or returns TokenPair +res, pair, err := ga.AuthenticateAndIssueTokenPair(ctx, "local", params) +// pair.Access, pair.Refresh, pair.Access.SessionID ``` ### 2) Authenticate Requests with JWT ```go res, err := ga.Authenticate(ctx, "jwt", goauth.AuthParams{Token: bearer}) -// res.Authenticatable is your user; errors are typed (TokenError, etc.) +// res.Authenticatable is your user ``` -### 3) Refresh Tokens (Rotation) +### 3) Refresh Tokens (with rotation) ```go -access, refresh, err := ga.RefreshToken(ctx, refreshToken) -// ValidateRefreshTokenWith determines whether it's valid and which user it belongs to +// Old refresh token is passed to storage for invalidation +pair, err := ga.RefreshTokenPair(ctx, refreshToken) +// pair.Access (new), pair.Refresh (new, old is invalidated) ``` -## HTTP Integration +### 4) Revoke Tokens / Sessions + +```go +// Revoke single token +err := ga.RevokeToken(ctx, refreshToken) -Use the helpers to map typed errors to HTTP responses and error codes. +// Revoke specific session (requires SessionTokenIssuer) +err := ga.RevokeSession(ctx, user, sessionID) + +// Revoke all sessions (logout everywhere) +err := ga.RevokeAllTokens(ctx, user) +``` + +### 5) List Active Sessions ```go -func writeError(w http.ResponseWriter, err error) { - status := goauth.HTTPStatusForError(err) - code := goauth.ErrorCodeForError(err) - http.Error(w, code, status) +sessions, err := ga.ListSessions(ctx, user) +for _, s := range sessions { + fmt.Printf("Session %s: device=%s, expires=%s\n", + s.ID, s.Metadata["device"], s.ExpiresAt) } +``` -func loginHandler(w http.ResponseWriter, r *http.Request, ga *goauth.GoAuth) { - // parse JSON {"username":..., "password":...} - ctx := r.Context() - res, access, refresh, err := ga.AuthenticateAndIssueTokens(ctx, "local", goauth.AuthParams{ - UsernameOrEmail: r.FormValue("username"), - Password: r.FormValue("password"), - }) - if err != nil { - writeError(w, err) - return - } - // marshal JSON response with user and tokens - _ = res; _ = access; _ = refresh +## Event Hooks + +Add logging, audit trails, or custom logic: + +```go +type MyHooks struct { + goauth.NoOpEventHooks // Embed to only override what you need } -func meHandler(w http.ResponseWriter, r *http.Request, ga *goauth.GoAuth) { - bearer := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") - res, err := ga.Authenticate(r.Context(), "jwt", goauth.AuthParams{Token: bearer}) - if err != nil { - writeError(w, err) - return - } - // return res.Authenticatable as JSON - _ = res +func (h *MyHooks) OnBeforeAuthenticate(ctx context.Context, strategy string, params goauth.AuthParams) error { + // Rate limiting, logging, etc. + // Return error to block authentication + return nil } -func refreshHandler(w http.ResponseWriter, r *http.Request, ga *goauth.GoAuth) { - // parse JSON {"refresh_token":...} - access, refresh, err := ga.RefreshToken(r.Context(), r.FormValue("refresh_token")) +func (h *MyHooks) OnAfterAuthenticate(ctx context.Context, strategy string, result *goauth.AuthResult, err error) { if err != nil { - writeError(w, err) - return + log.Printf("Auth failed for strategy %s: %v", strategy, err) + } else { + log.Printf("User %s authenticated via %s", result.Authenticatable.GetID(), strategy) } - // return new tokens as JSON - _ = access; _ = refresh } -``` -## Configuration & Customization +func (h *MyHooks) OnTokenIssued(ctx context.Context, auth goauth.Authenticatable, tokens *goauth.TokenPair) { + log.Printf("Tokens issued for user %s, session %s", auth.GetID(), tokens.Access.SessionID) +} -### Token Issuer +func (h *MyHooks) OnTokenRevoked(ctx context.Context, auth goauth.Authenticatable, token string) { + log.Printf("Token revoked for user %s", auth.GetID()) +} -`DefaultTokenIssuer` provides: +// Register hooks +ga.SetEventHooks(&MyHooks{}) +``` -- `SetSecret(string)`: HS256 signing secret for access JWTs. -- `SetIssuer(string)`, `SetAudience([]string)`: standard JWT claims. -- `SetAccessTokenExpiresIn(time.Duration)`, `SetRefreshTokenExpiresIn(time.Duration)`. -- `StoreRefreshTokenWith(func)`: required; persist and rotate refresh tokens. -- `ValidateRefreshTokenWith(func)`: required; validate and resolve refresh tokens → user. -- `SetExtraClaimsWith(func)`: optional; add custom claims to access JWTs. -- `SetRegisteredClaimsWith(func)`: optional; override registered claims (exp/iss/aud/sub, etc.). -- `ConvertAccessTokenClaimsWith(func)`: optional; map claims back → user (defaults to ID/Username/Email + Extra). +## Strategy Enhancements -If you need asymmetric signing or non-JWT tokens, implement `TokenIssuer` yourself. +### LocalStrategy with Password Validation & Rate Limiting -### Strategies +```go +strategy := goauth.NewLocalStrategy(lookupUser). + WithName("local"). + WithPasswordValidator( + func(plain, hashed string) bool { + return bcrypt.CompareHashAndPassword([]byte(hashed), []byte(plain)) == nil + }, + func(user goauth.Authenticatable) string { + return user.(*MyUser).HashedPassword + }, + ). + WithRateLimiter( + func(ctx context.Context, identifier string) error { + // Return goauth.ErrRateLimitExceeded if blocked + return nil + }, + func(ctx context.Context, identifier string, success bool) { + // Record attempt for rate limiting + }, + ). + WithUsernameNormalizer(func(username string) string { + return strings.ToLower(strings.TrimSpace(username)) + }) +``` -Two built-in strategies: +### JWTStrategy with Token Type & Revocation Check -- `LocalStrategy`: takes a `LookupUserWith(Context, AuthParams) (Authenticatable, error)` function. Return `CredentialError` for bad creds, `InternalError` for DB failures, etc. -- `JWTStrategy`: takes a `TokenIssuer` and authenticates a bearer token. +```go +strategy := goauth.NewJWTStrategy(issuer). + WithName("jwt"). + WithExpectedType(goauth.AccessToken). // Reject refresh tokens + WithRevocationCheck(func(ctx context.Context, token string) bool { + // Return true if token is revoked + return isRevoked(token) + }) +``` -Custom strategies can implement `Strategy` and be registered on `GoAuth`: +## Signing Algorithms ```go -type OAuthStrategy struct{} -func (s *OAuthStrategy) Name() string { return "oauth" } -func (s *OAuthStrategy) Authenticate(ctx context.Context, params goauth.AuthParams) (goauth.Authenticatable, error) { - // exchange code → user; return typed errors as appropriate - return &goauth.User{ID: "..."}, nil -} -ga.RegisterStrategy(&OAuthStrategy{}) +// HMAC (symmetric) +kp, _ := goauth.NewHMACKeyProvider([]byte("secret"), goauth.HS256) +kp, _ := goauth.NewHMACKeyProvider([]byte("secret"), goauth.HS384) +kp, _ := goauth.NewHMACKeyProvider([]byte("secret"), goauth.HS512) + +// RSA (asymmetric) +kp, _ := goauth.NewRSAKeyProvider(privateKey, publicKey, goauth.RS256) + +// ECDSA (asymmetric) +kp, _ := goauth.NewECDSAKeyProvider(privateKey, publicKey, goauth.ES256) ``` -#### Passing extra parameters +## HTTP Integration -`AuthParams` includes an `Extra map[string]any` field so you can pass provider- or flow-specific values without changing the interface. Examples: +```go +func loginHandler(w http.ResponseWriter, r *http.Request) { + var req LoginRequest + json.NewDecoder(r.Body).Decode(&req) -- OAuth/OIDC: `{ "provider": "google", "code_verifier": "...", "redirect_uri": "...", "state": "..." }` -- SSO/SAML: `{ "relay_state": "..." }` -- Any custom metadata you want strategies to see. + _, pair, err := ga.AuthenticateAndIssueTokenPair(r.Context(), "local", goauth.AuthParams{ + UsernameOrEmail: req.Username, + Password: req.Password, + }) + if err != nil { + resp := goauth.ErrorResponseForError(err) + w.WriteHeader(resp.Status) + json.NewEncoder(w).Encode(resp) + return + } -Access it in your strategy or lookup function via `params.Extra["key"]`. Prefer checking presence and type-asserting to avoid panics. + json.NewEncoder(w).Encode(map[string]any{ + "access_token": pair.Access.Value, + "refresh_token": pair.Refresh.Value, + "expires_in": int(pair.Access.ExpiresIn.Seconds()), + "session_id": pair.Access.SessionID, + }) +} -### Singleton Access +func authMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + result, err := ga.Authenticate(r.Context(), "jwt", goauth.AuthParams{Token: token}) + if err != nil { + resp := goauth.ErrorResponseForError(err) + w.WriteHeader(resp.Status) + json.NewEncoder(w).Encode(resp) + return + } + ctx := context.WithValue(r.Context(), "user", result.Authenticatable) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} +``` -Singleton registration exists to provide an easy way to access the GoAuth instance from middleware/handlers without threading it through parameters. Use it when DI isn't practical; otherwise prefer explicit dependency injection. +## Error Types & HTTP Mapping -If you prefer a global instance: +| Error Type | HTTP Status | Error Code | +| ----------------- | ----------- | ------------------------------------------------------------------------------------- | +| `CredentialError` | 401 | `invalid_credentials` | +| `TokenError` | 401 | `token_error` / `token_missing` / `token_invalid` / `token_expired` / `token_revoked` | +| `ValidationError` | 400 | `validation_error` | +| `RateLimitError` | 429 | `rate_limit_exceeded` | +| `NotFoundError` | 404 | `not_found` / `strategy_not_found` / `session_not_found` | +| `ConfigError` | 500 | `config_error` | +| `InternalError` | 500 | `internal_error` | +| `SessionError` | 401 | `session_error` | ```go -ga := setup() -ga.RegisterSingleton() // or: _ = ga.RegisterSingletonOnce() -// Later in handlers/services: -ga = goauth.GetInstance() +// Get structured error response +resp := goauth.ErrorResponseForError(err) +// resp.Status, resp.Code, resp.Message, resp.Fields (for validation), resp.RetryAfter (for rate limit) + +// Or individual helpers +status := goauth.HTTPStatusForError(err) +code := goauth.ErrorCodeForError(err) +retryAfter := goauth.RetryAfterForError(err) +``` + +## Thread Safety + +`GoAuth` is safe for concurrent use: + +```go +// Strategy registration is mutex-protected +ga.RegisterStrategy(strategy) +ga.UnregisterStrategy("oauth") +ga.HasStrategy("local") +ga.ListStrategies() ``` -Testing support: +## Singleton Access + +For convenience when DI isn't practical: ```go +ga.RegisterSingleton() // Overwrite allowed +_ = ga.RegisterSingletonOnce() // Set once, error on second + +// Later +ga = goauth.GetInstance() + +// Testing restore := goauth.ReplaceSingletonForTest(mockGA) defer restore() ``` -### Typed Errors +## Examples -This package returns categorized errors so you can branch behavior and log appropriately: +```bash +# Multi-session demo +go run ./example -- `CredentialError`: bad or missing credentials. -- `TokenError`: invalid/missing/expired token, or refresh token rejected. -- `ConfigError`: misconfiguration, missing token issuer or hooks. -- `NotFoundError`: strategy not found, etc. -- `InternalError`: unexpected failure (IO/DB/crypto). +# HTTP server with login, refresh, logout, sessions endpoints +go run ./example/http_server +``` -Helpers: +The HTTP server example provides: -- `HTTPStatusForError(error) int` and `ErrorCodeForError(error) string`. -- Sentinels for common cases: `ErrMissingToken`, `ErrInvalidToken`, `ErrExpiredToken`, `ErrStrategyNotFound`, `ErrTokenIssuerUnset`. +- `POST /login` - Authenticate and get tokens +- `POST /refresh` - Refresh tokens +- `POST /logout` - Revoke current session +- `POST /logout-all` - Revoke all sessions +- `GET /me` - Get current user (protected) +- `GET /sessions` - List active sessions (protected) -### Refresh Token Storage (example) +## Migration from v1 -```go -// Redis-like pseudo code -ti.StoreRefreshTokenWith(func(ctx context.Context, a goauth.Authenticatable, t *goauth.Token, refreshing bool) error { - key := "rt:" + a.GetID() - if refreshing { - // rotate: overwrite previous token for this user - } - // SET key t.Value EX t.ExpiresIn - return nil -}) +### Breaking Changes -ti.ValidateRefreshTokenWith(func(ctx context.Context, token string) (goauth.Authenticatable, error) { - // GET userID by reverse index or scan keys if you store per-user - // If not found → return &goauth.TokenError{Msg: "refresh token not found"} - return &goauth.User{ID: "user-123"}, nil -}) -``` +1. **Module path changed**: Import path is now `github.com/openframebox/goauth/v2` -## Security Notes + ```go + // v1 + import goauth "github.com/openframebox/goauth" + // v2 + import goauth "github.com/openframebox/goauth/v2" + ``` -- Use a strong, rotated HS256 secret. Keep it out of source control. -- Set correct `issuer` and `audience` and validate them on consumers. -- Keep access tokens short-lived; rely on refresh tokens and rotation. -- Revoke refresh tokens on logout and on suspicious activity. -- Consider binding refresh tokens to device/session identifiers. +2. **TokenIssuer interface**: `CreateRefreshToken` signature changed -## Testing + ```go + // v1 + CreateRefreshToken(ctx, auth, refreshing bool) (*Token, error) + // v2 + CreateRefreshToken(ctx, auth, oldToken *string) (*Token, error) + ``` -- Use `errors.As(err, *TokenError)` etc. to assert failure categories. -- Stub `StoreRefreshTokenWith`/`ValidateRefreshTokenWith` to simulate rotation and revocation. -- For JWTs, set short expirations in tests and validate expiry paths. +3. **StoreRefreshTokenFunc**: signature changed -## Example + ```go + // v1 + func(ctx, auth, token, refreshing bool) error + // v2 + func(ctx, auth, token, oldToken *string) error + ``` -There is a complete runnable example under `example/`: +4. **Strategy constructors**: use builder pattern -``` -go run ./example -``` + ```go + // v1 + &goauth.LocalStrategy{LookupUserWith: fn} + // v2 + goauth.NewLocalStrategy(fn) + + // v1 + &goauth.JWTStrategy{TokenIssuer: ti} + // v2 + goauth.NewJWTStrategy(ti) + ``` + +5. **Token struct**: new fields added + + - `Type` (TokenType) - "access" or "refresh" + - `IssuedAt` (time.Time) + - `SessionID` (string) + +6. **New required method on TokenIssuer**: `RevokeRefreshToken(ctx, token string) error` + +7. **GoAuth methods**: New `TokenPair` returning methods added + - `IssueTokenPair()` alongside `IssueTokens()` + - `RefreshTokenPair()` alongside `RefreshToken()` + - `AuthenticateAndIssueTokenPair()` alongside `AuthenticateAndIssueTokens()` + +### New Features in v2 + +- **Multi-session support** with `SessionTokenIssuer` +- **Multiple signing algorithms** (HS256/384/512, RS256/384/512, ES256/384/512) +- **Event hooks** (`AuthEventHooks` interface) +- **Rate limiting support** in strategies +- **Password validation** in `LocalStrategy` +- **Token type validation** in `JWTStrategy` +- **Thread-safe** strategy registration with `sync.RWMutex` +- **New error types**: `RateLimitError`, `ValidationError`, `SessionError` +- **Session management**: `ListSessions`, `RevokeSession`, `RevokeAllTokens` + +## Security Notes -It demonstrates: local login, JWT request authentication, refresh rotation. +- Use strong, rotated secrets. Keep them out of source control. +- Set correct `issuer` and `audience` claims. +- Keep access tokens short-lived (5-15 min). +- Implement proper refresh token rotation. +- Revoke tokens on logout and suspicious activity. +- Use rate limiting on authentication endpoints. +- Hash passwords with bcrypt/argon2. ## License diff --git a/default_token_issuer.go b/default_token_issuer.go new file mode 100644 index 0000000..5f088d4 --- /dev/null +++ b/default_token_issuer.go @@ -0,0 +1,246 @@ +package goauth + +import ( + "context" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +// Callback function types for DefaultTokenIssuer + +// StoreRefreshTokenFunc stores a refresh token +// oldToken is the previous refresh token being rotated (nil for initial login) +type StoreRefreshTokenFunc func(ctx context.Context, authenticatable Authenticatable, token *Token, oldToken *string) error + +// SetExtraClaimsFunc returns extra claims to include in the access token +type SetExtraClaimsFunc func(ctx context.Context, authenticatable Authenticatable) map[string]any + +// SetRegisteredClaimsFunc returns custom registered claims for the access token +type SetRegisteredClaimsFunc func(ctx context.Context, authenticatable Authenticatable) jwt.RegisteredClaims + +// ConvertAccessTokenClaimsFunc converts token claims to an Authenticatable entity +type ConvertAccessTokenClaimsFunc func(ctx context.Context, claims *TokenClaims) (Authenticatable, error) + +// ValidateRefreshTokenFunc validates a refresh token and returns the associated user +type ValidateRefreshTokenFunc func(ctx context.Context, token string) (Authenticatable, error) + +// RevokeRefreshTokenFunc revokes a refresh token +type RevokeRefreshTokenFunc func(ctx context.Context, token string) error + +// DefaultTokenIssuer is a basic implementation of TokenIssuer +// For multi-session support, use SessionAwareTokenIssuer instead +type DefaultTokenIssuer struct { + secret string + issuer string + audience []string + accessTokenExpiresIn time.Duration + refreshTokenExpiresIn time.Duration + storeRefreshTokenWith StoreRefreshTokenFunc + setExtraClaimsWith SetExtraClaimsFunc + setRegisteredClaimsWith SetRegisteredClaimsFunc + convertAccessTokenClaimsWith ConvertAccessTokenClaimsFunc + validateRefreshTokenWith ValidateRefreshTokenFunc + revokeRefreshTokenWith RevokeRefreshTokenFunc +} + +// NewDefaultTokenIssuer creates a new DefaultTokenIssuer with sensible defaults +func NewDefaultTokenIssuer(secret string) *DefaultTokenIssuer { + ti := &DefaultTokenIssuer{ + secret: secret, + issuer: "goauth", + audience: []string{"goauth"}, + accessTokenExpiresIn: 300 * time.Second, // default 5 minutes + refreshTokenExpiresIn: 3600 * time.Second, // default 1 hour + } + + return ti +} + +// SetSecret sets the JWT signing secret +func (ti *DefaultTokenIssuer) SetSecret(secret string) { + ti.secret = secret +} + +// SetIssuer sets the JWT issuer claim +func (ti *DefaultTokenIssuer) SetIssuer(issuer string) { + ti.issuer = issuer +} + +// SetAudience sets the JWT audience claim +func (ti *DefaultTokenIssuer) SetAudience(audience []string) { + ti.audience = audience +} + +// SetAccessTokenExpiresIn sets the access token expiration duration +func (ti *DefaultTokenIssuer) SetAccessTokenExpiresIn(expiresIn time.Duration) { + ti.accessTokenExpiresIn = expiresIn +} + +// SetRefreshTokenExpiresIn sets the refresh token expiration duration +func (ti *DefaultTokenIssuer) SetRefreshTokenExpiresIn(expiresIn time.Duration) { + ti.refreshTokenExpiresIn = expiresIn +} + +// StoreRefreshTokenWith sets the callback for storing refresh tokens +func (ti *DefaultTokenIssuer) StoreRefreshTokenWith(storeRefreshTokenWith StoreRefreshTokenFunc) { + ti.storeRefreshTokenWith = storeRefreshTokenWith +} + +// SetExtraClaimsWith sets the callback for adding extra claims to access tokens +func (ti *DefaultTokenIssuer) SetExtraClaimsWith(setExtraClaimsWith SetExtraClaimsFunc) { + ti.setExtraClaimsWith = setExtraClaimsWith +} + +// SetRegisteredClaimsWith sets the callback for customizing registered claims +func (ti *DefaultTokenIssuer) SetRegisteredClaimsWith(setRegisteredClaimsWith SetRegisteredClaimsFunc) { + ti.setRegisteredClaimsWith = setRegisteredClaimsWith +} + +// ConvertAccessTokenClaimsWith sets the callback for converting claims to Authenticatable +func (ti *DefaultTokenIssuer) ConvertAccessTokenClaimsWith(convertAccessTokenClaimsWith ConvertAccessTokenClaimsFunc) { + ti.convertAccessTokenClaimsWith = convertAccessTokenClaimsWith +} + +// ValidateRefreshTokenWith sets the callback for validating refresh tokens +func (ti *DefaultTokenIssuer) ValidateRefreshTokenWith(validateRefreshTokenWith ValidateRefreshTokenFunc) { + ti.validateRefreshTokenWith = validateRefreshTokenWith +} + +// RevokeRefreshTokenWith sets the callback for revoking refresh tokens +func (ti *DefaultTokenIssuer) RevokeRefreshTokenWith(revokeRefreshTokenWith RevokeRefreshTokenFunc) { + ti.revokeRefreshTokenWith = revokeRefreshTokenWith +} + +// CreateAccessToken creates a new JWT access token +func (ti *DefaultTokenIssuer) CreateAccessToken(ctx context.Context, authenticatable Authenticatable) (*Token, error) { + extraClaims := make(map[string]any) + if ti.setExtraClaimsWith != nil { + extraClaims = ti.setExtraClaimsWith(ctx, authenticatable) + } + + now := time.Now() + var registeredClaims jwt.RegisteredClaims + if ti.setRegisteredClaimsWith != nil { + registeredClaims = ti.setRegisteredClaimsWith(ctx, authenticatable) + } else { + registeredClaims = jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(ti.accessTokenExpiresIn)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + Subject: authenticatable.GetID(), + Issuer: ti.issuer, + Audience: ti.audience, + } + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, TokenClaims{ + RegisteredClaims: registeredClaims, + Username: authenticatable.GetUsername(), + Email: authenticatable.GetEmail(), + TokenType: AccessToken, + ExtraClaims: extraClaims, + }) + + tokenString, err := token.SignedString([]byte(ti.secret)) + if err != nil { + return nil, err + } + + return &Token{ + Value: tokenString, + Type: AccessToken, + ExpiresIn: ti.accessTokenExpiresIn, + IssuedAt: now, + }, nil +} + +// CreateRefreshToken creates a new refresh token +// oldToken is the previous refresh token being rotated (nil for initial login) +func (ti *DefaultTokenIssuer) CreateRefreshToken(ctx context.Context, authenticatable Authenticatable, oldToken *string) (*Token, error) { + if ti.storeRefreshTokenWith == nil { + return nil, &ConfigError{Msg: "StoreRefreshTokenWith is not set"} + } + + now := time.Now() + tokenString := uuid.New().String() + token := &Token{ + Value: tokenString, + Type: RefreshToken, + ExpiresIn: ti.refreshTokenExpiresIn, + IssuedAt: now, + } + + err := ti.storeRefreshTokenWith(ctx, authenticatable, token, oldToken) + if err != nil { + return nil, &InternalError{Msg: "failed to store refresh token", Err: err} + } + + return token, nil +} + +// DecodeAccessToken parses and validates a JWT access token +func (ti *DefaultTokenIssuer) DecodeAccessToken(ctx context.Context, token string) (*TokenClaims, error) { + parsedToken, err := jwt.ParseWithClaims(token, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) { + return []byte(ti.secret), nil + }) + + if err != nil { + // jwt lib returns various errors (validation/signature/expired). Classify as token error. + return nil, &TokenError{Msg: "failed to parse or validate access token", Err: err} + } + + claims, ok := parsedToken.Claims.(*TokenClaims) + if !ok { + return nil, &TokenError{Msg: "invalid token claims"} + } + + return claims, nil +} + +// ConvertAccessTokenClaims converts token claims to an Authenticatable entity +func (ti *DefaultTokenIssuer) ConvertAccessTokenClaims(ctx context.Context, claims *TokenClaims) (Authenticatable, error) { + if ti.convertAccessTokenClaimsWith != nil { + a, err := ti.convertAccessTokenClaimsWith(ctx, claims) + if err != nil { + return nil, &TokenError{Msg: "failed to convert access token claims", Err: err} + } + return a, nil + } + + return &User{ + ID: claims.Subject, + Username: claims.Username, + Email: claims.Email, + Extra: claims.ExtraClaims, + }, nil +} + +// ValidateRefreshToken validates a refresh token and returns the associated user +func (ti *DefaultTokenIssuer) ValidateRefreshToken(ctx context.Context, token string) (Authenticatable, error) { + if ti.validateRefreshTokenWith == nil { + return nil, &ConfigError{Msg: "ValidateRefreshTokenWith is not set"} + } + + authenticatable, err := ti.validateRefreshTokenWith(ctx, token) + if err != nil { + return nil, &TokenError{Msg: "invalid or rejected refresh token", Err: err} + } + + return authenticatable, nil +} + +// RevokeRefreshToken revokes a refresh token +func (ti *DefaultTokenIssuer) RevokeRefreshToken(ctx context.Context, token string) error { + if ti.revokeRefreshTokenWith == nil { + return &ConfigError{Msg: "RevokeRefreshTokenWith is not set"} + } + + err := ti.revokeRefreshTokenWith(ctx, token) + if err != nil { + return &InternalError{Msg: "failed to revoke refresh token", Err: err} + } + + return nil +} diff --git a/entity.go b/entity.go index c2e229a..2424459 100644 --- a/entity.go +++ b/entity.go @@ -2,6 +2,15 @@ package goauth import "time" +// TokenType represents the type of token (access or refresh) +type TokenType string + +const ( + AccessToken TokenType = "access" + RefreshToken TokenType = "refresh" +) + +// AuthParams contains authentication parameters passed to strategies type AuthParams struct { UsernameOrEmail string Password string @@ -9,16 +18,126 @@ type AuthParams struct { Extra map[string]any } +// GetExtra returns the value for a key from Extra map +func (ap *AuthParams) GetExtra(key string) (any, bool) { + if ap.Extra == nil { + return nil, false + } + v, ok := ap.Extra[key] + return v, ok +} + +// GetExtraString returns a string value from Extra map +func (ap *AuthParams) GetExtraString(key string) (string, bool) { + v, ok := ap.GetExtra(key) + if !ok { + return "", false + } + s, ok := v.(string) + return s, ok +} + +// GetExtraInt returns an int value from Extra map +func (ap *AuthParams) GetExtraInt(key string) (int, bool) { + v, ok := ap.GetExtra(key) + if !ok { + return 0, false + } + switch i := v.(type) { + case int: + return i, true + case int64: + return int(i), true + case float64: + return int(i), true + default: + return 0, false + } +} + +// GetExtraBool returns a bool value from Extra map +func (ap *AuthParams) GetExtraBool(key string) (bool, bool) { + v, ok := ap.GetExtra(key) + if !ok { + return false, false + } + b, ok := v.(bool) + return b, ok +} + +// Validate checks if the AuthParams has valid data for authentication +func (ap *AuthParams) Validate() error { + // At minimum, either username/email+password or token must be provided + hasCredentials := ap.UsernameOrEmail != "" && ap.Password != "" + hasToken := ap.Token != "" + + if !hasCredentials && !hasToken { + return &ValidationError{ + Msg: "authentication parameters required", + Fields: map[string]string{ + "credentials": "username/email and password or token required", + }, + } + } + return nil +} + +// AuthResult contains the result of a successful authentication type AuthResult struct { Authenticatable Authenticatable Strategy string + Metadata map[string]any // NEW: additional context from authentication } +// Token represents an authentication token (access or refresh) type Token struct { Value string + Type TokenType ExpiresIn time.Duration + IssuedAt time.Time + SessionID string // For multi-session support +} + +// TokenPair contains both access and refresh tokens +type TokenPair struct { + Access *Token + Refresh *Token +} + +// SessionInfo contains session metadata for multi-session support +type SessionInfo struct { + ID string + UserID string + CreatedAt time.Time + ExpiresAt time.Time + Metadata map[string]any // device, IP, user agent, location, etc. +} + +// GetMetadata returns a value from session metadata +func (s *SessionInfo) GetMetadata(key string) (any, bool) { + if s.Metadata == nil { + return nil, false + } + v, ok := s.Metadata[key] + return v, ok +} + +// GetMetadataString returns a string value from session metadata +func (s *SessionInfo) GetMetadataString(key string) (string, bool) { + v, ok := s.GetMetadata(key) + if !ok { + return "", false + } + str, ok := v.(string) + return str, ok +} + +// IsExpired checks if the session has expired +func (s *SessionInfo) IsExpired() bool { + return time.Now().After(s.ExpiresAt) } +// User is a default implementation of Authenticatable type User struct { ID string Username string diff --git a/errors.go b/errors.go index 5ab866c..cf083b8 100644 --- a/errors.go +++ b/errors.go @@ -1,6 +1,9 @@ package goauth -import "fmt" +import ( + "fmt" + "time" +) // CredentialError indicates a problem with user-provided credentials // such as invalid username/password. @@ -93,6 +96,72 @@ func (e *InternalError) Error() string { func (e *InternalError) Unwrap() error { return e.Err } +// RateLimitError indicates that rate limit has been exceeded. +type RateLimitError struct { + Msg string + RetryAfter time.Duration + Err error +} + +func (e *RateLimitError) Error() string { + if e.Msg != "" { + return e.Msg + } + if e.Err != nil { + return e.Err.Error() + } + return "rate limit exceeded" +} + +func (e *RateLimitError) Unwrap() error { return e.Err } + +// ValidationError indicates validation failure on input parameters. +type ValidationError struct { + Msg string + Fields map[string]string // field name -> error message + Err error +} + +func (e *ValidationError) Error() string { + if e.Msg != "" { + return e.Msg + } + if e.Err != nil { + return e.Err.Error() + } + return "validation error" +} + +func (e *ValidationError) Unwrap() error { return e.Err } + +// GetFieldError returns the error message for a specific field +func (e *ValidationError) GetFieldError(field string) (string, bool) { + if e.Fields == nil { + return "", false + } + msg, ok := e.Fields[field] + return msg, ok +} + +// SessionError indicates session-related problems. +type SessionError struct { + Msg string + SessionID string + Err error +} + +func (e *SessionError) Error() string { + if e.Msg != "" { + return e.Msg + } + if e.Err != nil { + return e.Err.Error() + } + return "session error" +} + +func (e *SessionError) Unwrap() error { return e.Err } + // Helper constructors func NewCredentialError(msg string, err error) error { return &CredentialError{Msg: msg, Err: err} } func NewTokenError(msg string, err error) error { return &TokenError{Msg: msg, Err: err} } @@ -100,18 +169,47 @@ func NewConfigError(msg string, err error) error { return &ConfigError{Msg: func NewNotFoundError(msg string, err error) error { return &NotFoundError{Msg: msg, Err: err} } func NewInternalError(msg string, err error) error { return &InternalError{Msg: msg, Err: err} } +func NewRateLimitError(msg string, retryAfter time.Duration) error { + return &RateLimitError{Msg: msg, RetryAfter: retryAfter} +} + +func NewValidationError(msg string, fields map[string]string) error { + return &ValidationError{Msg: msg, Fields: fields} +} + +func NewSessionError(msg string, sessionID string, err error) error { + return &SessionError{Msg: msg, SessionID: sessionID, Err: err} +} + // Convenience sentinels for common cases (use errors.As to match by type). var ( + // Credential errors ErrInvalidCredentials = &CredentialError{Msg: "invalid credentials"} - ErrMissingToken = &TokenError{Msg: "token is required"} - ErrInvalidToken = &TokenError{Msg: "invalid token"} - ErrExpiredToken = &TokenError{Msg: "expired token"} - ErrStrategyNotFound = &NotFoundError{Msg: "strategy not found"} - ErrTokenIssuerUnset = &ConfigError{Msg: "token issuer is not set"} + ErrUserNotFound = &CredentialError{Msg: "user not found"} + + // Token errors + ErrMissingToken = &TokenError{Msg: "token is required"} + ErrInvalidToken = &TokenError{Msg: "invalid token"} + ErrExpiredToken = &TokenError{Msg: "expired token"} + ErrTokenRevoked = &TokenError{Msg: "token has been revoked"} + ErrTokenTypeMismatch = &TokenError{Msg: "unexpected token type"} + + // Config errors + ErrTokenIssuerUnset = &ConfigError{Msg: "token issuer is not set"} + ErrKeyProviderUnset = &ConfigError{Msg: "key provider is not set"} + ErrSessionStoreUnset = &ConfigError{Msg: "session store is not set"} + + // Not found errors + ErrStrategyNotFound = &NotFoundError{Msg: "strategy not found"} + ErrSessionNotFound = &NotFoundError{Msg: "session not found"} + + // Rate limit errors + ErrRateLimitExceeded = &RateLimitError{Msg: "rate limit exceeded"} ) -// Formatting helpers to attach context without losing type info. -func withContext(err error, format string, args ...any) error { +// WithContext attaches context to an error without losing type info. +// Exported for use by strategies and other packages. +func WithContext(err error, format string, args ...any) error { if err == nil { return nil } @@ -126,8 +224,19 @@ func withContext(err error, format string, args ...any) error { return &NotFoundError{Msg: fmt.Sprintf(format, args...), Err: e} case *InternalError: return &InternalError{Msg: fmt.Sprintf(format, args...), Err: e} + case *RateLimitError: + return &RateLimitError{Msg: fmt.Sprintf(format, args...), RetryAfter: e.RetryAfter, Err: e} + case *ValidationError: + return &ValidationError{Msg: fmt.Sprintf(format, args...), Fields: e.Fields, Err: e} + case *SessionError: + return &SessionError{Msg: fmt.Sprintf(format, args...), SessionID: e.SessionID, Err: e} default: // Unknown error type -> wrap as InternalError return &InternalError{Msg: fmt.Sprintf(format, args...), Err: err} } } + +// withContext is kept for backward compatibility (unexported version) +func withContext(err error, format string, args ...any) error { + return WithContext(err, format, args...) +} diff --git a/example/http_server/main.go b/example/http_server/main.go new file mode 100644 index 0000000..c80804b --- /dev/null +++ b/example/http_server/main.go @@ -0,0 +1,460 @@ +package main + +import ( + "context" + "encoding/json" + "log" + "net/http" + "strings" + "sync" + "time" + + goauth "github.com/openframebox/goauth/v2" +) + +var ( + ga *goauth.GoAuth + store *inMemorySessionStore +) + +// inMemorySessionStore - same as main example +type inMemorySessionStore struct { + mu sync.RWMutex + sessions map[string]*sessionData + tokens map[string]string + users map[string]*goauth.User +} + +type sessionData struct { + session *goauth.SessionInfo + token string + userID string +} + +func newInMemorySessionStore() *inMemorySessionStore { + return &inMemorySessionStore{ + sessions: make(map[string]*sessionData), + tokens: make(map[string]string), + users: make(map[string]*goauth.User), + } +} + +func (s *inMemorySessionStore) store(ctx context.Context, auth goauth.Authenticatable, session *goauth.SessionInfo, token *goauth.Token, oldToken *string) error { + s.mu.Lock() + defer s.mu.Unlock() + + userID := auth.GetID() + + if oldToken != nil { + if oldSessionID, ok := s.tokens[*oldToken]; ok { + if oldData, exists := s.sessions[oldSessionID]; exists { + delete(s.tokens, oldData.token) + } + delete(s.sessions, oldSessionID) + } + } + + s.sessions[session.ID] = &sessionData{ + session: session, + token: token.Value, + userID: userID, + } + s.tokens[token.Value] = session.ID + + if u, ok := auth.(*goauth.User); ok { + s.users[userID] = u + } else { + s.users[userID] = &goauth.User{ID: userID, Username: auth.GetUsername(), Email: auth.GetEmail()} + } + + return nil +} + +func (s *inMemorySessionStore) validate(ctx context.Context, token string) (goauth.Authenticatable, *goauth.SessionInfo, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + sessionID, ok := s.tokens[token] + if !ok { + return nil, nil, goauth.ErrTokenRevoked + } + + data, ok := s.sessions[sessionID] + if !ok { + return nil, nil, goauth.ErrSessionNotFound + } + + if data.session.IsExpired() { + return nil, nil, goauth.ErrExpiredToken + } + + user, ok := s.users[data.userID] + if !ok { + return nil, nil, goauth.ErrUserNotFound + } + + return user, data.session, nil +} + +func (s *inMemorySessionStore) revoke(ctx context.Context, auth goauth.Authenticatable, sessionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + + data, ok := s.sessions[sessionID] + if !ok { + return goauth.ErrSessionNotFound + } + + if data.userID != auth.GetID() { + return goauth.ErrInvalidCredentials + } + + delete(s.tokens, data.token) + delete(s.sessions, sessionID) + return nil +} + +func (s *inMemorySessionStore) revokeAll(ctx context.Context, auth goauth.Authenticatable) error { + s.mu.Lock() + defer s.mu.Unlock() + + userID := auth.GetID() + for sessionID, data := range s.sessions { + if data.userID == userID { + delete(s.tokens, data.token) + delete(s.sessions, sessionID) + } + } + return nil +} + +func (s *inMemorySessionStore) list(ctx context.Context, auth goauth.Authenticatable) ([]*goauth.SessionInfo, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + userID := auth.GetID() + var sessions []*goauth.SessionInfo + for _, data := range s.sessions { + if data.userID == userID && !data.session.IsExpired() { + sessions = append(sessions, data.session) + } + } + return sessions, nil +} + +func (s *inMemorySessionStore) getSession(ctx context.Context, token string) (*goauth.SessionInfo, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + sessionID, ok := s.tokens[token] + if !ok { + return nil, goauth.ErrTokenRevoked + } + + data, ok := s.sessions[sessionID] + if !ok { + return nil, goauth.ErrSessionNotFound + } + + return data.session, nil +} + +// Request/Response types +type LoginRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +type TokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + TokenType string `json:"token_type"` + SessionID string `json:"session_id,omitempty"` +} + +type RefreshRequest struct { + RefreshToken string `json:"refresh_token"` +} + +type SessionResponse struct { + ID string `json:"id"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type UserResponse struct { + ID string `json:"id"` + Username string `json:"username"` + Email string `json:"email"` +} + +func main() { + // Initialize store and auth + store = newInMemorySessionStore() + + keyProvider, err := goauth.NewHMACKeyProvider([]byte("supersecret-change-me"), goauth.HS256) + if err != nil { + log.Fatalf("failed to create key provider: %v", err) + } + + issuer, err := goauth.NewSessionAwareTokenIssuer(). + WithKeyProvider(keyProvider). + WithIssuer("api.example.local"). + WithAudience([]string{"api.example.local"}). + WithAccessTokenTTL(15 * time.Minute). + WithRefreshTokenTTL(7 * 24 * time.Hour). + WithSessionStore(store.store, store.validate, store.revoke, store.revokeAll). + WithListSessions(store.list). + WithGetSession(store.getSession). + WithSessionMetadataExtractor(extractSessionMetadata). + Build() + + if err != nil { + log.Fatalf("failed to build token issuer: %v", err) + } + + ga = goauth.New() + ga.SetTokenIssuer(issuer) + + // Register strategies + localStrategy := goauth.NewLocalStrategy(lookupUser) + ga.RegisterStrategy(localStrategy) + + jwtStrategy := goauth.NewJWTStrategy(issuer).WithExpectedType(goauth.AccessToken) + ga.RegisterStrategy(jwtStrategy) + + // HTTP routes + http.HandleFunc("/login", handleLogin) + http.HandleFunc("/refresh", handleRefresh) + http.HandleFunc("/logout", authMiddleware(handleLogout)) + http.HandleFunc("/logout-all", authMiddleware(handleLogoutAll)) + http.HandleFunc("/me", authMiddleware(handleMe)) + http.HandleFunc("/sessions", authMiddleware(handleSessions)) + + log.Println("Server starting on :8080") + log.Println("Endpoints:") + log.Println(" POST /login - Login with username/password") + log.Println(" POST /refresh - Refresh access token") + log.Println(" POST /logout - Logout current session") + log.Println(" POST /logout-all - Logout all sessions") + log.Println(" GET /me - Get current user (protected)") + log.Println(" GET /sessions - List active sessions (protected)") + + log.Fatal(http.ListenAndServe(":8080", nil)) +} + +func lookupUser(ctx context.Context, params goauth.AuthParams) (goauth.Authenticatable, error) { + // Demo: accept specific users + users := map[string]string{ + "alice": "password123", + "bob": "secret456", + } + + password, exists := users[params.UsernameOrEmail] + if !exists || password != params.Password { + return nil, goauth.ErrInvalidCredentials + } + + return &goauth.User{ + ID: "user-" + params.UsernameOrEmail, + Username: params.UsernameOrEmail, + Email: params.UsernameOrEmail + "@example.com", + }, nil +} + +func extractSessionMetadata(ctx context.Context) map[string]any { + // In real app, extract from request context + return map[string]any{ + "created_at": time.Now().Format(time.RFC3339), + } +} + +// Handlers + +func handleLogin(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method_not_allowed", "POST only") + return + } + + var req LoginRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "Invalid JSON") + return + } + + result, pair, err := ga.AuthenticateAndIssueTokenPair(r.Context(), "local", goauth.AuthParams{ + UsernameOrEmail: req.Username, + Password: req.Password, + }) + if err != nil { + resp := goauth.ErrorResponseForError(err) + writeError(w, resp.Status, resp.Code, resp.Message) + return + } + + writeJSON(w, http.StatusOK, TokenResponse{ + AccessToken: pair.Access.Value, + RefreshToken: pair.Refresh.Value, + ExpiresIn: int(pair.Access.ExpiresIn.Seconds()), + TokenType: "Bearer", + SessionID: pair.Access.SessionID, + }) + + log.Printf("User %s logged in, session: %s", result.Authenticatable.GetUsername(), pair.Access.SessionID) +} + +func handleRefresh(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method_not_allowed", "POST only") + return + } + + var req RefreshRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "Invalid JSON") + return + } + + pair, err := ga.RefreshTokenPair(r.Context(), req.RefreshToken) + if err != nil { + resp := goauth.ErrorResponseForError(err) + writeError(w, resp.Status, resp.Code, resp.Message) + return + } + + writeJSON(w, http.StatusOK, TokenResponse{ + AccessToken: pair.Access.Value, + RefreshToken: pair.Refresh.Value, + ExpiresIn: int(pair.Access.ExpiresIn.Seconds()), + TokenType: "Bearer", + SessionID: pair.Access.SessionID, + }) +} + +func handleLogout(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method_not_allowed", "POST only") + return + } + + var req RefreshRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "Invalid JSON") + return + } + + if err := ga.RevokeToken(r.Context(), req.RefreshToken); err != nil { + resp := goauth.ErrorResponseForError(err) + writeError(w, resp.Status, resp.Code, resp.Message) + return + } + + writeJSON(w, http.StatusOK, map[string]string{"status": "logged_out"}) +} + +func handleLogoutAll(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method_not_allowed", "POST only") + return + } + + user := r.Context().Value("user").(goauth.Authenticatable) + + if err := ga.RevokeAllTokens(r.Context(), user); err != nil { + resp := goauth.ErrorResponseForError(err) + writeError(w, resp.Status, resp.Code, resp.Message) + return + } + + writeJSON(w, http.StatusOK, map[string]string{"status": "all_sessions_revoked"}) +} + +func handleMe(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method_not_allowed", "GET only") + return + } + + user := r.Context().Value("user").(goauth.Authenticatable) + + writeJSON(w, http.StatusOK, UserResponse{ + ID: user.GetID(), + Username: user.GetUsername(), + Email: user.GetEmail(), + }) +} + +func handleSessions(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method_not_allowed", "GET only") + return + } + + user := r.Context().Value("user").(goauth.Authenticatable) + + sessions, err := ga.ListSessions(r.Context(), user) + if err != nil { + resp := goauth.ErrorResponseForError(err) + writeError(w, resp.Status, resp.Code, resp.Message) + return + } + + var response []SessionResponse + for _, s := range sessions { + response = append(response, SessionResponse{ + ID: s.ID, + CreatedAt: s.CreatedAt, + ExpiresAt: s.ExpiresAt, + Metadata: s.Metadata, + }) + } + + writeJSON(w, http.StatusOK, response) +} + +// Middleware + +func authMiddleware(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + writeError(w, http.StatusUnauthorized, "token_missing", "Authorization header required") + return + } + + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { + writeError(w, http.StatusUnauthorized, "token_invalid", "Bearer token required") + return + } + + token := parts[1] + result, err := ga.Authenticate(r.Context(), "jwt", goauth.AuthParams{Token: token}) + if err != nil { + resp := goauth.ErrorResponseForError(err) + writeError(w, resp.Status, resp.Code, resp.Message) + return + } + + ctx := context.WithValue(r.Context(), "user", result.Authenticatable) + next(w, r.WithContext(ctx)) + } +} + +// Helpers + +func writeJSON(w http.ResponseWriter, status int, data any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(data) +} + +func writeError(w http.ResponseWriter, status int, code, message string) { + writeJSON(w, status, map[string]any{ + "error": code, + "message": message, + }) +} diff --git a/example/main.go b/example/main.go index ea069f6..3c29fd0 100644 --- a/example/main.go +++ b/example/main.go @@ -7,144 +7,326 @@ import ( "sync" "time" - goauth "github.com/openframebox/goauth" + goauth "github.com/openframebox/goauth/v2" ) -// inMemoryRefreshStore is a simple in-memory store for refresh tokens. -// It supports basic rotation by removing the previous token for a user when refreshing. -type inMemoryRefreshStore struct { - mu sync.Mutex - tokenToUserID map[string]string - userIDToToken map[string]string - usersByID map[string]*goauth.User +// inMemorySessionStore is a session-aware in-memory store for refresh tokens +// It supports multiple sessions per user for multi-device login +type inMemorySessionStore struct { + mu sync.RWMutex + sessions map[string]*sessionData // sessionID -> session data + tokens map[string]string // token -> sessionID + users map[string]*goauth.User // userID -> user data } -func newInMemoryRefreshStore() *inMemoryRefreshStore { - return &inMemoryRefreshStore{ - tokenToUserID: make(map[string]string), - userIDToToken: make(map[string]string), - usersByID: make(map[string]*goauth.User), +type sessionData struct { + session *goauth.SessionInfo + token string + userID string +} + +func newInMemorySessionStore() *inMemorySessionStore { + return &inMemorySessionStore{ + sessions: make(map[string]*sessionData), + tokens: make(map[string]string), + users: make(map[string]*goauth.User), } } -func (s *inMemoryRefreshStore) store(ctx context.Context, authenticatable goauth.Authenticatable, token *goauth.Token, refreshing bool) error { +// StoreSession stores a new session with its refresh token +func (s *inMemorySessionStore) store(ctx context.Context, auth goauth.Authenticatable, session *goauth.SessionInfo, token *goauth.Token, oldToken *string) error { s.mu.Lock() defer s.mu.Unlock() - userID := authenticatable.GetID() + userID := auth.GetID() - // On rotation, delete old token mapping for this user if exists - if refreshing { - if oldTok, ok := s.userIDToToken[userID]; ok { - delete(s.tokenToUserID, oldTok) + // If rotating, invalidate the old token + if oldToken != nil { + if oldSessionID, ok := s.tokens[*oldToken]; ok { + // Get the old session data before deleting + if oldData, exists := s.sessions[oldSessionID]; exists { + // Delete old token mapping + delete(s.tokens, oldData.token) + } + // Remove old session if it's being replaced + delete(s.sessions, oldSessionID) } } - s.tokenToUserID[token.Value] = userID - s.userIDToToken[userID] = token.Value + // Store new session + s.sessions[session.ID] = &sessionData{ + session: session, + token: token.Value, + userID: userID, + } + s.tokens[token.Value] = session.ID - // Keep a reference user for demo lookup (username/email optional) - if u, ok := authenticatable.(*goauth.User); ok { - s.usersByID[userID] = u + // Store user for lookup + if u, ok := auth.(*goauth.User); ok { + s.users[userID] = u } else { - s.usersByID[userID] = &goauth.User{ID: userID} + s.users[userID] = &goauth.User{ID: userID, Username: auth.GetUsername(), Email: auth.GetEmail()} + } + + return nil +} + +// ValidateSession validates a refresh token and returns the user and session +func (s *inMemorySessionStore) validate(ctx context.Context, token string) (goauth.Authenticatable, *goauth.SessionInfo, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + sessionID, ok := s.tokens[token] + if !ok { + return nil, nil, fmt.Errorf("refresh token not found or revoked") + } + + data, ok := s.sessions[sessionID] + if !ok { + return nil, nil, fmt.Errorf("session not found") + } + + // Check expiry + if data.session.IsExpired() { + return nil, nil, fmt.Errorf("session expired") + } + + user, ok := s.users[data.userID] + if !ok { + return nil, nil, fmt.Errorf("user not found") + } + + return user, data.session, nil +} + +// RevokeSession revokes a specific session +func (s *inMemorySessionStore) revoke(ctx context.Context, auth goauth.Authenticatable, sessionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + + data, ok := s.sessions[sessionID] + if !ok { + return fmt.Errorf("session not found") + } + + // Verify user owns this session + if data.userID != auth.GetID() { + return fmt.Errorf("session does not belong to user") } + + delete(s.tokens, data.token) + delete(s.sessions, sessionID) return nil } -func (s *inMemoryRefreshStore) validate(ctx context.Context, token string) (goauth.Authenticatable, error) { +// RevokeAllSessions revokes all sessions for a user +func (s *inMemorySessionStore) revokeAll(ctx context.Context, auth goauth.Authenticatable) error { s.mu.Lock() defer s.mu.Unlock() - userID, ok := s.tokenToUserID[token] + userID := auth.GetID() + for sessionID, data := range s.sessions { + if data.userID == userID { + delete(s.tokens, data.token) + delete(s.sessions, sessionID) + } + } + return nil +} + +// ListSessions lists all active sessions for a user +func (s *inMemorySessionStore) list(ctx context.Context, auth goauth.Authenticatable) ([]*goauth.SessionInfo, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + userID := auth.GetID() + var sessions []*goauth.SessionInfo + for _, data := range s.sessions { + if data.userID == userID && !data.session.IsExpired() { + sessions = append(sessions, data.session) + } + } + return sessions, nil +} + +// GetSession returns session info by token +func (s *inMemorySessionStore) getSession(ctx context.Context, token string) (*goauth.SessionInfo, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + sessionID, ok := s.tokens[token] if !ok { - return nil, fmt.Errorf("refresh token not found or revoked") + return nil, fmt.Errorf("token not found") } - if u, ok := s.usersByID[userID]; ok { - return u, nil + + data, ok := s.sessions[sessionID] + if !ok { + return nil, fmt.Errorf("session not found") } - return &goauth.User{ID: userID}, nil + + return data.session, nil } func main() { ctx := context.Background() - // Configure token issuer - issuer := goauth.NewDefaultTokenIssuer("supersecret-change-me") - issuer.SetIssuer("api.example.local") - issuer.SetAudience([]string{"api.example.local"}) - issuer.SetAccessTokenExpiresIn(2 * time.Minute) - issuer.SetRefreshTokenExpiresIn(15 * time.Minute) - - // In-memory refresh token storage (for demo) - store := newInMemoryRefreshStore() - issuer.StoreRefreshTokenWith(store.store) - issuer.ValidateRefreshTokenWith(store.validate) - - // Add extra claims into access token and be able to reconstruct the user from claims - issuer.SetExtraClaimsWith(func(ctx context.Context, a goauth.Authenticatable) map[string]any { - return map[string]any{ - "username": a.GetUsername(), - "email": a.GetEmail(), - "role": "admin", // demo custom claim - } - }) - // Rely on default ConvertAccessTokenClaims which maps subject/username/email and passes through extra claims. + // Create session store + store := newInMemorySessionStore() + + // Configure session-aware token issuer using builder pattern + keyProvider, err := goauth.NewHMACKeyProvider([]byte("supersecret-change-me-in-production"), goauth.HS256) + if err != nil { + log.Fatalf("failed to create key provider: %v", err) + } + + issuer, err := goauth.NewSessionAwareTokenIssuer(). + WithKeyProvider(keyProvider). + WithIssuer("api.example.local"). + WithAudience([]string{"api.example.local"}). + WithAccessTokenTTL(5 * time.Minute). + WithRefreshTokenTTL(7 * 24 * time.Hour). + WithSessionStore(store.store, store.validate, store.revoke, store.revokeAll). + WithListSessions(store.list). + WithGetSession(store.getSession). + WithExtraClaims(func(ctx context.Context, a goauth.Authenticatable) map[string]any { + return map[string]any{ + "role": "admin", // demo custom claim + } + }). + WithSessionMetadataExtractor(func(ctx context.Context) map[string]any { + // In real app, extract device info, IP, user agent from context + return map[string]any{ + "device": "browser", + "ip": "127.0.0.1", + "user_agent": "Mozilla/5.0", + } + }). + Build() + + if err != nil { + log.Fatalf("failed to build token issuer: %v", err) + } - // Orchestrator + // Create GoAuth orchestrator ga := goauth.New() ga.SetTokenIssuer(issuer) - // Local strategy using a trivial credential check for demo purposes - ga.RegisterStrategy(&goauth.LocalStrategy{LookupUserWith: func(ctx context.Context, params goauth.AuthParams) (goauth.Authenticatable, error) { - // Demo: accept any non-empty username/password, construct an example user + // Register local strategy with builder pattern + localStrategy := goauth.NewLocalStrategy(func(ctx context.Context, params goauth.AuthParams) (goauth.Authenticatable, error) { + // Demo: accept any non-empty username/password if params.UsernameOrEmail == "" || params.Password == "" { - return nil, fmt.Errorf("missing credentials") + return nil, goauth.ErrInvalidCredentials } - u := &goauth.User{ + return &goauth.User{ ID: "user-" + params.UsernameOrEmail, Username: params.UsernameOrEmail, Email: params.UsernameOrEmail + "@example.local", - } - return u, nil - }}) + }, nil + }) + ga.RegisterStrategy(localStrategy) + + // Register JWT strategy with builder pattern + jwtStrategy := goauth.NewJWTStrategy(issuer). + WithExpectedType(goauth.AccessToken) + ga.RegisterStrategy(jwtStrategy) + + fmt.Println("=== Demo: Multi-Session Token Issuer ===") + fmt.Println() + + // --- Login from first device --- + fmt.Println("== 1. Login from Device 1 (Browser) ==") + authRes, pair, err := ga.AuthenticateAndIssueTokenPair(ctx, "local", goauth.AuthParams{ + UsernameOrEmail: "alice", + Password: "s3cret", + }) + if err != nil { + log.Fatalf("login failed: %v", err) + } - // JWT strategy for authenticating incoming requests by bearer token - ga.RegisterStrategy(&goauth.JWTStrategy{TokenIssuer: issuer}) + fmt.Printf("Authenticated: id=%s user=%s\n", authRes.Authenticatable.GetID(), authRes.Authenticatable.GetUsername()) + fmt.Printf("Session ID: %s\n", pair.Access.SessionID) + fmt.Printf("Access Token (exp %s): %s...\n", formatExpiry(pair.Access.ExpiresIn), truncate(pair.Access.Value, 50)) + fmt.Printf("Refresh Token: %s\n\n", pair.Refresh.Value) - fmt.Println("== Demo: Local login -> issue tokens ==") - authRes, accessTok, refreshTok, err := ga.AuthenticateAndIssueTokens(ctx, "local", goauth.AuthParams{ + device1RefreshToken := pair.Refresh.Value + + // --- Login from second device --- + fmt.Println("== 2. Login from Device 2 (Mobile App) ==") + _, pair2, err := ga.AuthenticateAndIssueTokenPair(ctx, "local", goauth.AuthParams{ UsernameOrEmail: "alice", Password: "s3cret", - Extra: map[string]any{"source": "example"}, // optional metadata }) if err != nil { log.Fatalf("login failed: %v", err) } - fmt.Printf("Authenticated as: id=%s user=%s email=%s via=%s\n", - authRes.Authenticatable.GetID(), authRes.Authenticatable.GetUsername(), authRes.Authenticatable.GetEmail(), authRes.Strategy) - fmt.Printf("Access Token (exp %s): %s\n", issuerExpiry(accessTok.ExpiresIn), accessTok.Value) - fmt.Printf("Refresh Token (exp %s): %s\n\n", issuerExpiry(refreshTok.ExpiresIn), refreshTok.Value) - fmt.Println("== Demo: Authenticate request using JWT strategy ==") - jwtRes, err := ga.Authenticate(ctx, "jwt", goauth.AuthParams{Token: accessTok.Value}) + fmt.Printf("Session ID: %s\n", pair2.Access.SessionID) + fmt.Printf("Refresh Token: %s\n\n", pair2.Refresh.Value) + + device2RefreshToken := pair2.Refresh.Value + + // --- List all sessions --- + fmt.Println("== 3. List Active Sessions ==") + sessions, err := ga.ListSessions(ctx, authRes.Authenticatable) + if err != nil { + log.Fatalf("list sessions failed: %v", err) + } + fmt.Printf("Active sessions for user %s: %d\n", authRes.Authenticatable.GetID(), len(sessions)) + for i, sess := range sessions { + device, _ := sess.GetMetadataString("device") + fmt.Printf(" %d. Session %s (device: %s, expires: %s)\n", + i+1, sess.ID, device, sess.ExpiresAt.Format(time.RFC3339)) + } + fmt.Println() + + // --- Authenticate with JWT --- + fmt.Println("== 4. Authenticate Request with JWT ==") + jwtRes, err := ga.Authenticate(ctx, "jwt", goauth.AuthParams{Token: pair.Access.Value}) if err != nil { log.Fatalf("jwt auth failed: %v", err) } - fmt.Printf("JWT resolved user: id=%s user=%s email=%s via=%s\n\n", - jwtRes.Authenticatable.GetID(), jwtRes.Authenticatable.GetUsername(), jwtRes.Authenticatable.GetEmail(), jwtRes.Strategy) + fmt.Printf("JWT resolved: id=%s user=%s\n\n", jwtRes.Authenticatable.GetID(), jwtRes.Authenticatable.GetUsername()) - fmt.Println("== Demo: Refresh tokens (rotation) ==") - newAccess, newRefresh, err := ga.RefreshToken(ctx, refreshTok.Value) + // --- Refresh token from device 1 --- + fmt.Println("== 5. Refresh Token (Device 1) ==") + newPair, err := ga.RefreshTokenPair(ctx, device1RefreshToken) if err != nil { log.Fatalf("refresh failed: %v", err) } - fmt.Printf("New Access Token (exp %s): %s\n", issuerExpiry(newAccess.ExpiresIn), newAccess.Value) - fmt.Printf("New Refresh Token (exp %s): %s\n\n", issuerExpiry(newRefresh.ExpiresIn), newRefresh.Value) + fmt.Printf("New Session ID: %s (should be same as before)\n", newPair.Access.SessionID) + fmt.Printf("New Refresh Token: %s\n", newPair.Refresh.Value) + fmt.Printf("Old token invalidated: %v\n\n", device1RefreshToken != newPair.Refresh.Value) - fmt.Println("Done.") + // --- Revoke device 2 session --- + fmt.Println("== 6. Revoke Device 2 Session ==") + session2, _ := store.getSession(ctx, device2RefreshToken) + err = ga.RevokeSession(ctx, authRes.Authenticatable, session2.ID) + if err != nil { + log.Fatalf("revoke session failed: %v", err) + } + fmt.Printf("Revoked session: %s\n", session2.ID) + + // Try to use revoked token + _, err = ga.RefreshTokenPair(ctx, device2RefreshToken) + if err != nil { + fmt.Printf("Device 2 refresh correctly failed: %v\n\n", err) + } + + // --- Final session count --- + fmt.Println("== 7. Final Session Count ==") + sessions, _ = ga.ListSessions(ctx, authRes.Authenticatable) + fmt.Printf("Remaining active sessions: %d\n\n", len(sessions)) + + fmt.Println("=== Demo Complete ===") } -func issuerExpiry(d time.Duration) string { +func formatExpiry(d time.Duration) string { return time.Now().Add(d).Format(time.RFC3339) } + +func truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] +} diff --git a/go.mod b/go.mod index c525b16..d667d9d 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/openframebox/goauth +module github.com/openframebox/goauth/v2 go 1.25.0 diff --git a/goauth.go b/goauth.go index be9c3a0..d7287a5 100644 --- a/goauth.go +++ b/goauth.go @@ -3,111 +3,323 @@ package goauth import ( "context" "fmt" + "sync" ) +// GoAuth is the main orchestrator for authentication and token management type GoAuth struct { tokenIssuer TokenIssuer strategies map[string]Strategy + hooks AuthEventHooks + mu sync.RWMutex } +// New creates a new GoAuth instance func New() *GoAuth { - ga := &GoAuth{} - ga.strategies = make(map[string]Strategy) - + ga := &GoAuth{ + strategies: make(map[string]Strategy), + } return ga } +// RegisterStrategy registers an authentication strategy +// If a strategy with the same name already exists, it will not be replaced func (ga *GoAuth) RegisterStrategy(strategy Strategy) { + ga.mu.Lock() + defer ga.mu.Unlock() + if _, ok := ga.strategies[strategy.Name()]; !ok { ga.strategies[strategy.Name()] = strategy } } +// UnregisterStrategy removes a registered strategy by name +func (ga *GoAuth) UnregisterStrategy(name string) error { + ga.mu.Lock() + defer ga.mu.Unlock() + + if _, ok := ga.strategies[name]; !ok { + return &NotFoundError{Msg: fmt.Sprintf("strategy %s not found", name)} + } + + delete(ga.strategies, name) + return nil +} + +// HasStrategy checks if a strategy is registered +func (ga *GoAuth) HasStrategy(name string) bool { + ga.mu.RLock() + defer ga.mu.RUnlock() + + _, ok := ga.strategies[name] + return ok +} + +// ListStrategies returns the names of all registered strategies +func (ga *GoAuth) ListStrategies() []string { + ga.mu.RLock() + defer ga.mu.RUnlock() + + names := make([]string, 0, len(ga.strategies)) + for name := range ga.strategies { + names = append(names, name) + } + return names +} + +// SetTokenIssuer sets the token issuer for the GoAuth instance func (ga *GoAuth) SetTokenIssuer(tokenIssuer TokenIssuer) { ga.tokenIssuer = tokenIssuer } +// GetTokenIssuer returns the current token issuer +func (ga *GoAuth) GetTokenIssuer() TokenIssuer { + return ga.tokenIssuer +} + +// SetEventHooks sets the event hooks for authentication events +func (ga *GoAuth) SetEventHooks(hooks AuthEventHooks) { + ga.hooks = hooks +} + +// Authenticate authenticates using the specified strategy func (ga *GoAuth) Authenticate(ctx context.Context, strategy string, params AuthParams) (*AuthResult, error) { - s, err := ga.lookupStrategy(strategy) + // Call before hook + if ga.hooks != nil { + if err := ga.hooks.OnBeforeAuthenticate(ctx, strategy, params); err != nil { + return nil, err + } + } + s, err := ga.lookupStrategy(strategy) if err != nil { + if ga.hooks != nil { + ga.hooks.OnAfterAuthenticate(ctx, strategy, nil, err) + } return nil, err } user, err := s.Authenticate(ctx, params) if err != nil { + if ga.hooks != nil { + ga.hooks.OnAfterAuthenticate(ctx, strategy, nil, err) + } return nil, err } - return &AuthResult{ + result := &AuthResult{ Authenticatable: user, Strategy: s.Name(), - }, nil + } + + // Call after hook + if ga.hooks != nil { + ga.hooks.OnAfterAuthenticate(ctx, strategy, result, nil) + } + + return result, nil } +// IssueTokens creates access and refresh tokens for an authenticated entity +// Returns individual tokens (for backward compatibility) func (ga *GoAuth) IssueTokens(ctx context.Context, authenticatable Authenticatable) (accessToken *Token, refreshToken *Token, err error) { + pair, err := ga.IssueTokenPair(ctx, authenticatable) + if err != nil { + return nil, nil, err + } + return pair.Access, pair.Refresh, nil +} + +// IssueTokenPair creates access and refresh tokens as a TokenPair +func (ga *GoAuth) IssueTokenPair(ctx context.Context, authenticatable Authenticatable) (*TokenPair, error) { if ga.tokenIssuer == nil { - return nil, nil, ErrTokenIssuerUnset + return nil, ErrTokenIssuerUnset } - accessToken, err = ga.tokenIssuer.CreateAccessToken(ctx, authenticatable) + accessToken, err := ga.tokenIssuer.CreateAccessToken(ctx, authenticatable) if err != nil { - return nil, nil, err + return nil, err } - refreshToken, err = ga.tokenIssuer.CreateRefreshToken(ctx, authenticatable, false) + refreshToken, err := ga.tokenIssuer.CreateRefreshToken(ctx, authenticatable, nil) if err != nil { - return nil, nil, err + return nil, err + } + + pair := &TokenPair{ + Access: accessToken, + Refresh: refreshToken, + } + + // Call token issued hook + if ga.hooks != nil { + ga.hooks.OnTokenIssued(ctx, authenticatable, pair) } - return accessToken, refreshToken, nil + return pair, nil } +// AuthenticateAndIssueTokens authenticates and issues tokens in one call +// Returns individual tokens (for backward compatibility) func (ga *GoAuth) AuthenticateAndIssueTokens(ctx context.Context, strategy string, params AuthParams) (authResult *AuthResult, accessToken *Token, refreshToken *Token, err error) { - result, err := ga.Authenticate(ctx, strategy, params) + result, pair, err := ga.AuthenticateAndIssueTokenPair(ctx, strategy, params) if err != nil { return nil, nil, nil, err } + return result, pair.Access, pair.Refresh, nil +} - accessToken, refreshToken, err = ga.IssueTokens(ctx, result.Authenticatable) +// AuthenticateAndIssueTokenPair authenticates and issues tokens as a TokenPair +func (ga *GoAuth) AuthenticateAndIssueTokenPair(ctx context.Context, strategy string, params AuthParams) (*AuthResult, *TokenPair, error) { + result, err := ga.Authenticate(ctx, strategy, params) if err != nil { - return nil, nil, nil, err + return nil, nil, err + } + + pair, err := ga.IssueTokenPair(ctx, result.Authenticatable) + if err != nil { + return nil, nil, err } - return result, accessToken, refreshToken, nil + return result, pair, nil } +// RefreshToken validates the old refresh token and issues new tokens +// Returns individual tokens (for backward compatibility) func (ga *GoAuth) RefreshToken(ctx context.Context, token string) (accessToken *Token, refreshToken *Token, err error) { + pair, err := ga.RefreshTokenPair(ctx, token) + if err != nil { + return nil, nil, err + } + return pair.Access, pair.Refresh, nil +} + +// RefreshTokenPair validates the old refresh token and issues new tokens as a TokenPair +func (ga *GoAuth) RefreshTokenPair(ctx context.Context, token string) (*TokenPair, error) { if ga.tokenIssuer == nil { - return nil, nil, ErrTokenIssuerUnset + return nil, ErrTokenIssuerUnset } user, err := ga.tokenIssuer.ValidateRefreshToken(ctx, token) if err != nil { - return nil, nil, err + return nil, err } if user == nil { - return nil, nil, &TokenError{Msg: "invalid refresh token"} + return nil, &TokenError{Msg: "invalid refresh token"} } - accessToken, err = ga.tokenIssuer.CreateAccessToken(ctx, user) + accessToken, err := ga.tokenIssuer.CreateAccessToken(ctx, user) if err != nil { - return nil, nil, err + return nil, err } - refreshToken, err = ga.tokenIssuer.CreateRefreshToken(ctx, user, true) + // Pass the old token for proper rotation + refreshToken, err := ga.tokenIssuer.CreateRefreshToken(ctx, user, &token) if err != nil { - return nil, nil, err + return nil, err + } + + pair := &TokenPair{ + Access: accessToken, + Refresh: refreshToken, + } + + // Call token issued hook + if ga.hooks != nil { + ga.hooks.OnTokenIssued(ctx, user, pair) + } + + return pair, nil +} + +// RevokeToken revokes a refresh token +func (ga *GoAuth) RevokeToken(ctx context.Context, token string) error { + if ga.tokenIssuer == nil { + return ErrTokenIssuerUnset } - return accessToken, refreshToken, nil + // Get user before revoking for the hook + var user Authenticatable + if ga.hooks != nil { + user, _ = ga.tokenIssuer.ValidateRefreshToken(ctx, token) + } + + err := ga.tokenIssuer.RevokeRefreshToken(ctx, token) + if err != nil { + return err + } + + // Call token revoked hook + if ga.hooks != nil && user != nil { + ga.hooks.OnTokenRevoked(ctx, user, token) + } + + return nil } +// RevokeAllTokens revokes all sessions for an authenticated entity +// Only works if the token issuer implements SessionAwareTokenIssuer +func (ga *GoAuth) RevokeAllTokens(ctx context.Context, authenticatable Authenticatable) error { + if ga.tokenIssuer == nil { + return ErrTokenIssuerUnset + } + + sessionIssuer, ok := ga.tokenIssuer.(SessionAwareTokenIssuer) + if !ok { + return &ConfigError{Msg: "token issuer does not support session management"} + } + + return sessionIssuer.RevokeAllSessions(ctx, authenticatable) +} + +// ListSessions lists all active sessions for an authenticated entity +// Only works if the token issuer implements SessionAwareTokenIssuer +func (ga *GoAuth) ListSessions(ctx context.Context, authenticatable Authenticatable) ([]*SessionInfo, error) { + if ga.tokenIssuer == nil { + return nil, ErrTokenIssuerUnset + } + + sessionIssuer, ok := ga.tokenIssuer.(SessionAwareTokenIssuer) + if !ok { + return nil, &ConfigError{Msg: "token issuer does not support session management"} + } + + return sessionIssuer.ListSessions(ctx, authenticatable) +} + +// RevokeSession revokes a specific session by ID +// Only works if the token issuer implements SessionAwareTokenIssuer +func (ga *GoAuth) RevokeSession(ctx context.Context, authenticatable Authenticatable, sessionID string) error { + if ga.tokenIssuer == nil { + return ErrTokenIssuerUnset + } + + sessionIssuer, ok := ga.tokenIssuer.(SessionAwareTokenIssuer) + if !ok { + return &ConfigError{Msg: "token issuer does not support session management"} + } + + err := sessionIssuer.RevokeSession(ctx, authenticatable, sessionID) + if err != nil { + return err + } + + // Call session revoked hook + if ga.hooks != nil { + ga.hooks.OnSessionRevoked(ctx, authenticatable, &SessionInfo{ID: sessionID}) + } + + return nil +} + +// lookupStrategy finds a strategy by name (thread-safe) func (ga *GoAuth) lookupStrategy(name string) (Strategy, error) { + ga.mu.RLock() + defer ga.mu.RUnlock() + if strategy, ok := ga.strategies[name]; ok { return strategy, nil - } else { - return nil, &NotFoundError{Msg: fmt.Sprintf("strategy %s not found", name)} } + return nil, &NotFoundError{Msg: fmt.Sprintf("strategy %s not found", name)} } diff --git a/goauth_test.go b/goauth_test.go index c892569..192b67e 100644 --- a/goauth_test.go +++ b/goauth_test.go @@ -25,14 +25,16 @@ func TestNew(t *testing.T) { func TestRegisterStrategy(t *testing.T) { t.Run("should register a new strategy", func(t *testing.T) { goauth := New() - goauth.RegisterStrategy(&LocalStrategy{}) - goauth.RegisterStrategy(&JWTStrategy{}) + goauth.RegisterStrategy(NewLocalStrategy(func(ctx context.Context, params AuthParams) (Authenticatable, error) { + return nil, nil + })) + goauth.RegisterStrategy(NewJWTStrategy(NewDefaultTokenIssuer("test"))) - if _, ok := goauth.strategies["local"]; !ok { + if !goauth.HasStrategy("local") { t.Errorf("goauth.Strategies should contain a local strategy") } - if _, ok := goauth.strategies["jwt"]; !ok { + if !goauth.HasStrategy("jwt") { t.Errorf("goauth.Strategies should contain a jwt strategy") } }) @@ -42,15 +44,13 @@ func TestAuthenticate(t *testing.T) { goauth := New() t.Run("should authenticate a user with local strategy", func(t *testing.T) { - goauth.RegisterStrategy(&LocalStrategy{ - LookupUserWith: func(ctx context.Context, params AuthParams) (Authenticatable, error) { - return &User{ - ID: "test", - Username: "test", - Email: "test@test.com", - }, nil - }, - }) + goauth.RegisterStrategy(NewLocalStrategy(func(ctx context.Context, params AuthParams) (Authenticatable, error) { + return &User{ + ID: "test", + Username: "test", + Email: "test@test.com", + }, nil + })) result, err := goauth.Authenticate(context.TODO(), "local", AuthParams{ UsernameOrEmail: "test", @@ -101,14 +101,12 @@ func TestAuthenticate(t *testing.T) { }) tokenIssuer.SetIssuer("api.example.com") tokenIssuer.SetAudience([]string{"api.example.com", "auth.example.com"}) - tokenIssuer.StoreRefreshTokenWith(func(ctx context.Context, authenticatable Authenticatable, token *Token, refreshing bool) error { - // use refreshing to determine if the token is being refreshed or not + tokenIssuer.StoreRefreshTokenWith(func(ctx context.Context, authenticatable Authenticatable, token *Token, oldToken *string) error { + // oldToken is nil for initial login, non-nil for token refresh return nil }) - goauth.RegisterStrategy(&JWTStrategy{ - TokenIssuer: tokenIssuer, - }) + goauth.RegisterStrategy(NewJWTStrategy(tokenIssuer)) goauth.SetTokenIssuer(tokenIssuer) user := &User{ @@ -171,8 +169,8 @@ func TestIssueTokens(t *testing.T) { t.Run("should issue access and refresh tokens", func(t *testing.T) { var storedRefreshToken string tokenIssuer := NewDefaultTokenIssuer("testsecret") - tokenIssuer.StoreRefreshTokenWith(func(ctx context.Context, authenticatable Authenticatable, token *Token, refreshing bool) error { - // use refreshing to determine if the token is being refreshed or not + tokenIssuer.StoreRefreshTokenWith(func(ctx context.Context, authenticatable Authenticatable, token *Token, oldToken *string) error { + // oldToken is nil for initial login, non-nil for token refresh storedRefreshToken = token.Value return nil }) @@ -187,15 +185,13 @@ func TestIssueTokens(t *testing.T) { goauth := New() goauth.SetTokenIssuer(tokenIssuer) - goauth.RegisterStrategy(&LocalStrategy{ - LookupUserWith: func(ctx context.Context, params AuthParams) (Authenticatable, error) { - return &User{ - ID: "test", - Username: "test", - Email: "test@test.com", - }, nil - }, - }) + goauth.RegisterStrategy(NewLocalStrategy(func(ctx context.Context, params AuthParams) (Authenticatable, error) { + return &User{ + ID: "test", + Username: "test", + Email: "test@test.com", + }, nil + })) result, accessToken, refreshToken, err := goauth.AuthenticateAndIssueTokens(context.TODO(), "local", AuthParams{ UsernameOrEmail: "test", diff --git a/http_errors.go b/http_errors.go index f5753af..209575a 100644 --- a/http_errors.go +++ b/http_errors.go @@ -3,6 +3,7 @@ package goauth import ( "errors" "net/http" + "time" ) // HTTPStatusForError maps typed goauth errors to an HTTP status code. @@ -16,10 +17,18 @@ func HTTPStatusForError(err error) int { switch { case errors.Is(err, ErrMissingToken), errors.Is(err, ErrInvalidToken), errors.Is(err, ErrExpiredToken): return http.StatusUnauthorized + case errors.Is(err, ErrTokenRevoked): + return http.StatusUnauthorized + case errors.Is(err, ErrTokenTypeMismatch): + return http.StatusUnauthorized case errors.Is(err, ErrStrategyNotFound): return http.StatusNotFound - case errors.Is(err, ErrTokenIssuerUnset): + case errors.Is(err, ErrSessionNotFound): + return http.StatusNotFound + case errors.Is(err, ErrTokenIssuerUnset), errors.Is(err, ErrKeyProviderUnset), errors.Is(err, ErrSessionStoreUnset): return http.StatusInternalServerError + case errors.Is(err, ErrRateLimitExceeded): + return http.StatusTooManyRequests } // Category types next @@ -29,12 +38,21 @@ func HTTPStatusForError(err error) int { cfg *ConfigError nf *NotFoundError inr *InternalError + rl *RateLimitError + val *ValidationError + sess *SessionError ) switch { + case errors.As(err, &val): + return http.StatusBadRequest + case errors.As(err, &rl): + return http.StatusTooManyRequests case errors.As(err, &cred): return http.StatusUnauthorized case errors.As(err, &tok): return http.StatusUnauthorized + case errors.As(err, &sess): + return http.StatusUnauthorized case errors.As(err, &cfg): return http.StatusInternalServerError case errors.As(err, &nf): @@ -61,10 +79,24 @@ func ErrorCodeForError(err error) string { return "token_invalid" case errors.Is(err, ErrExpiredToken): return "token_expired" + case errors.Is(err, ErrTokenRevoked): + return "token_revoked" + case errors.Is(err, ErrTokenTypeMismatch): + return "token_type_mismatch" case errors.Is(err, ErrStrategyNotFound): return "strategy_not_found" + case errors.Is(err, ErrSessionNotFound): + return "session_not_found" case errors.Is(err, ErrTokenIssuerUnset): return "config_token_issuer_unset" + case errors.Is(err, ErrKeyProviderUnset): + return "config_key_provider_unset" + case errors.Is(err, ErrSessionStoreUnset): + return "config_session_store_unset" + case errors.Is(err, ErrRateLimitExceeded): + return "rate_limit_exceeded" + case errors.Is(err, ErrUserNotFound): + return "user_not_found" } // Category types next @@ -74,12 +106,21 @@ func ErrorCodeForError(err error) string { cfg *ConfigError nf *NotFoundError inr *InternalError + rl *RateLimitError + val *ValidationError + sess *SessionError ) switch { + case errors.As(err, &val): + return "validation_error" + case errors.As(err, &rl): + return "rate_limit_error" case errors.As(err, &cred): return "invalid_credentials" case errors.As(err, &tok): return "token_error" + case errors.As(err, &sess): + return "session_error" case errors.As(err, &cfg): return "config_error" case errors.As(err, &nf): @@ -90,3 +131,44 @@ func ErrorCodeForError(err error) string { return "internal_error" } } + +// ErrorResponse represents a structured HTTP error response +type ErrorResponse struct { + Status int `json:"status"` + Code string `json:"code"` + Message string `json:"message"` + Fields map[string]string `json:"fields,omitempty"` // For validation errors + RetryAfter int `json:"retry_after,omitempty"` // For rate limit errors (seconds) +} + +// ErrorResponseForError creates a structured error response from an error +func ErrorResponseForError(err error) ErrorResponse { + resp := ErrorResponse{ + Status: HTTPStatusForError(err), + Code: ErrorCodeForError(err), + Message: err.Error(), + } + + // Add extra fields for specific error types + var val *ValidationError + if errors.As(err, &val) && val.Fields != nil { + resp.Fields = val.Fields + } + + var rl *RateLimitError + if errors.As(err, &rl) && rl.RetryAfter > 0 { + resp.RetryAfter = int(rl.RetryAfter / time.Second) + } + + return resp +} + +// RetryAfterForError returns the Retry-After header value for rate limit errors +// Returns 0 if the error is not a rate limit error +func RetryAfterForError(err error) time.Duration { + var rl *RateLimitError + if errors.As(err, &rl) { + return rl.RetryAfter + } + return 0 +} diff --git a/interface.go b/interface.go index 94775c0..667af03 100644 --- a/interface.go +++ b/interface.go @@ -2,21 +2,15 @@ package goauth import ( "context" - - "github.com/golang-jwt/jwt/v5" ) -type StoreRefreshTokenFunc func(ctx context.Context, authenticatable Authenticatable, token *Token, refreshing bool) error -type SetExtraClaimsFunc func(ctx context.Context, authenticatable Authenticatable) map[string]any -type SetRegisteredClaimsFunc func(ctx context.Context, authenticatable Authenticatable) jwt.RegisteredClaims -type ConvertAccessTokenClaimsFunc func(ctx context.Context, claims *TokenClaims) (Authenticatable, error) -type ValidateRefreshTokenFunc func(ctx context.Context, token string) (Authenticatable, error) - +// Strategy defines the authentication strategy interface type Strategy interface { Name() string Authenticate(ctx context.Context, params AuthParams) (Authenticatable, error) } +// Authenticatable represents an authenticated entity (user, service, etc.) type Authenticatable interface { GetID() string GetUsername() string @@ -24,10 +18,119 @@ type Authenticatable interface { GetExtra() map[string]any } +// TokenIssuer defines the contract for token creation and validation type TokenIssuer interface { + // CreateAccessToken generates a new access token for the authenticated entity CreateAccessToken(ctx context.Context, authenticatable Authenticatable) (*Token, error) - CreateRefreshToken(ctx context.Context, authenticatable Authenticatable, refreshing bool) (*Token, error) + + // CreateRefreshToken generates a new refresh token + // oldToken is the previous refresh token being rotated (nil for initial login) + CreateRefreshToken(ctx context.Context, authenticatable Authenticatable, oldToken *string) (*Token, error) + + // DecodeAccessToken parses and validates an access token, returning its claims DecodeAccessToken(ctx context.Context, token string) (*TokenClaims, error) + + // ConvertAccessTokenClaims converts token claims back to an Authenticatable entity ConvertAccessTokenClaims(ctx context.Context, claims *TokenClaims) (Authenticatable, error) + + // ValidateRefreshToken validates a refresh token and returns the associated entity ValidateRefreshToken(ctx context.Context, token string) (Authenticatable, error) + + // RevokeRefreshToken invalidates a refresh token + RevokeRefreshToken(ctx context.Context, token string) error +} + +// SessionAwareTokenIssuer extends TokenIssuer with session management capabilities +type SessionAwareTokenIssuer interface { + TokenIssuer + + // GetSession returns session information for a refresh token + GetSession(ctx context.Context, token string) (*SessionInfo, error) + + // RevokeSession revokes a specific session by ID + RevokeSession(ctx context.Context, authenticatable Authenticatable, sessionID string) error + + // RevokeAllSessions revokes all sessions for an authenticated entity + RevokeAllSessions(ctx context.Context, authenticatable Authenticatable) error + + // ListSessions returns all active sessions for an authenticated entity + ListSessions(ctx context.Context, authenticatable Authenticatable) ([]*SessionInfo, error) +} + +// AuthEventHooks provides hooks for authentication events +// Implement this interface to add custom logic (logging, audit, rate limiting, etc.) +type AuthEventHooks interface { + // OnBeforeAuthenticate is called before authentication + // Return an error to prevent authentication (e.g., rate limiting) + OnBeforeAuthenticate(ctx context.Context, strategy string, params AuthParams) error + + // OnAfterAuthenticate is called after authentication (success or failure) + OnAfterAuthenticate(ctx context.Context, strategy string, result *AuthResult, err error) + + // OnTokenIssued is called when tokens are issued + OnTokenIssued(ctx context.Context, authenticatable Authenticatable, tokens *TokenPair) + + // OnTokenRevoked is called when a token is revoked + OnTokenRevoked(ctx context.Context, authenticatable Authenticatable, token string) + + // OnSessionCreated is called when a new session is created + OnSessionCreated(ctx context.Context, authenticatable Authenticatable, session *SessionInfo) + + // OnSessionRevoked is called when a session is revoked + OnSessionRevoked(ctx context.Context, authenticatable Authenticatable, session *SessionInfo) +} + +// NoOpEventHooks is a default implementation of AuthEventHooks that does nothing +// Embed this in your custom hooks to only override the methods you need +type NoOpEventHooks struct{} + +func (h *NoOpEventHooks) OnBeforeAuthenticate(ctx context.Context, strategy string, params AuthParams) error { + return nil +} + +func (h *NoOpEventHooks) OnAfterAuthenticate(ctx context.Context, strategy string, result *AuthResult, err error) { +} + +func (h *NoOpEventHooks) OnTokenIssued(ctx context.Context, authenticatable Authenticatable, tokens *TokenPair) { +} + +func (h *NoOpEventHooks) OnTokenRevoked(ctx context.Context, authenticatable Authenticatable, token string) { +} + +func (h *NoOpEventHooks) OnSessionCreated(ctx context.Context, authenticatable Authenticatable, session *SessionInfo) { +} + +func (h *NoOpEventHooks) OnSessionRevoked(ctx context.Context, authenticatable Authenticatable, session *SessionInfo) { +} + +// PasswordValidator defines the contract for password validation +type PasswordValidator interface { + // ValidatePassword checks if the plain password matches the hashed password + ValidatePassword(plain, hashed string) bool +} + +// PasswordValidatorFunc is a function adapter for PasswordValidator +type PasswordValidatorFunc func(plain, hashed string) bool + +func (f PasswordValidatorFunc) ValidatePassword(plain, hashed string) bool { + return f(plain, hashed) +} + +// RateLimiter defines the contract for rate limiting authentication attempts +type RateLimiter interface { + // CheckRateLimit checks if the authentication attempt is allowed + // Returns nil if allowed, RateLimitError if exceeded + CheckRateLimit(ctx context.Context, identifier string) error + + // RecordAttempt records an authentication attempt (success or failure) + RecordAttempt(ctx context.Context, identifier string, success bool) +} + +// TokenRevoker defines the contract for checking token revocation +type TokenRevoker interface { + // IsRevoked checks if a token has been revoked + IsRevoked(ctx context.Context, token string) bool + + // Revoke marks a token as revoked + Revoke(ctx context.Context, token string) error } diff --git a/jwt_strategy.go b/jwt_strategy.go index 97a2bef..178117a 100644 --- a/jwt_strategy.go +++ b/jwt_strategy.go @@ -4,28 +4,109 @@ import ( "context" ) +// CheckRevokedFunc checks if a token has been revoked +type CheckRevokedFunc func(ctx context.Context, token string) bool + +// ConvertClaimsFunc converts token claims to an Authenticatable +type ConvertClaimsFunc func(ctx context.Context, claims *TokenClaims) (Authenticatable, error) + +// JWTStrategy implements JWT token-based authentication type JWTStrategy struct { - TokenIssuer TokenIssuer + name string + tokenIssuer TokenIssuer + expectedType TokenType // Optional: validate token type + checkRevoked CheckRevokedFunc + convertClaims ConvertClaimsFunc +} + +// NewJWTStrategy creates a new JWTStrategy with the given token issuer +func NewJWTStrategy(tokenIssuer TokenIssuer) *JWTStrategy { + return &JWTStrategy{ + name: "jwt", + tokenIssuer: tokenIssuer, + } +} + +// WithName sets a custom name for the strategy +func (js *JWTStrategy) WithName(name string) *JWTStrategy { + js.name = name + return js +} + +// WithExpectedType sets the expected token type +// When set, the strategy will reject tokens that don't match +func (js *JWTStrategy) WithExpectedType(tokenType TokenType) *JWTStrategy { + js.expectedType = tokenType + return js +} + +// WithRevocationCheck sets the revocation check function +func (js *JWTStrategy) WithRevocationCheck(check CheckRevokedFunc) *JWTStrategy { + js.checkRevoked = check + return js +} + +// WithClaimsConverter sets a custom claims to Authenticatable converter +// This overrides the TokenIssuer's ConvertAccessTokenClaims +func (js *JWTStrategy) WithClaimsConverter(convert ConvertClaimsFunc) *JWTStrategy { + js.convertClaims = convert + return js } -func (ls *JWTStrategy) Name() string { - return "jwt" +// Name returns the strategy name +func (js *JWTStrategy) Name() string { + return js.name } -func (ls *JWTStrategy) Authenticate(ctx context.Context, params AuthParams) (Authenticatable, error) { +// Authenticate authenticates using a JWT token +func (js *JWTStrategy) Authenticate(ctx context.Context, params AuthParams) (Authenticatable, error) { token := params.Token if token == "" { - return nil, &TokenError{Msg: "token is required"} + return nil, ErrMissingToken + } + + // Check if token is revoked + if js.checkRevoked != nil && js.checkRevoked(ctx, token) { + return nil, ErrTokenRevoked } - claims, err := ls.TokenIssuer.DecodeAccessToken(ctx, token) + // Decode and validate token + claims, err := js.tokenIssuer.DecodeAccessToken(ctx, token) if err != nil { - return nil, withContext(&TokenError{Err: err}, "failed to decode access token") + return nil, WithContext(err, "failed to decode access token") + } + + // Validate token type if expected type is set + if js.expectedType != "" && claims.TokenType != js.expectedType { + // Special case: empty token type is treated as access token (backward compatibility) + if !(js.expectedType == AccessToken && claims.TokenType == "") { + return nil, ErrTokenTypeMismatch + } + } + + // Convert claims to user + var user Authenticatable + if js.convertClaims != nil { + user, err = js.convertClaims(ctx, claims) + } else { + user, err = js.tokenIssuer.ConvertAccessTokenClaims(ctx, claims) } - user, err := ls.TokenIssuer.ConvertAccessTokenClaims(ctx, claims) + if err != nil { - return nil, withContext(&TokenError{Err: err}, "failed to convert token claims") + return nil, WithContext(err, "failed to convert token claims") } return user, nil } + +// GetTokenIssuer returns the underlying token issuer +// Deprecated: Access TokenIssuer directly +func (js *JWTStrategy) GetTokenIssuer() TokenIssuer { + return js.tokenIssuer +} + +// SetTokenIssuer sets the token issuer (for backward compatibility) +// Deprecated: Use NewJWTStrategy instead +func (js *JWTStrategy) SetTokenIssuer(ti TokenIssuer) { + js.tokenIssuer = ti +} diff --git a/local_strategy.go b/local_strategy.go index df13082..a99749b 100644 --- a/local_strategy.go +++ b/local_strategy.go @@ -3,45 +3,171 @@ package goauth import ( "context" "errors" + "strings" ) +// LookupUserFunc looks up a user by credentials and returns an Authenticatable +// The returned user should have a hashed password available if password validation is used type LookupUserFunc func(ctx context.Context, params AuthParams) (Authenticatable, error) +// ValidatePasswordFunc validates a plain password against a hashed password +type ValidatePasswordFunc func(plain, hashed string) bool + +// RateLimitCheckFunc checks if an authentication attempt should be rate limited +// Returns nil if allowed, RateLimitError if exceeded +type RateLimitCheckFunc func(ctx context.Context, identifier string) error + +// RecordAttemptFunc records an authentication attempt for rate limiting +type RecordAttemptFunc func(ctx context.Context, identifier string, success bool) + +// NormalizeUsernameFunc normalizes a username (e.g., trim whitespace, lowercase) +type NormalizeUsernameFunc func(username string) string + +// GetHashedPasswordFunc retrieves the hashed password from an Authenticatable +// Used when password validation is enabled +type GetHashedPasswordFunc func(user Authenticatable) string + +// LocalStrategy implements username/password authentication type LocalStrategy struct { - LookupUserWith LookupUserFunc + name string + lookupUser LookupUserFunc + validatePassword ValidatePasswordFunc + getHashedPassword GetHashedPasswordFunc + checkRateLimit RateLimitCheckFunc + recordAttempt RecordAttemptFunc + normalizeUsername NormalizeUsernameFunc } +// NewLocalStrategy creates a new LocalStrategy with the given lookup function +func NewLocalStrategy(lookupUser LookupUserFunc) *LocalStrategy { + return &LocalStrategy{ + name: "local", + lookupUser: lookupUser, + normalizeUsername: func(username string) string { + return strings.TrimSpace(username) + }, + } +} + +// WithName sets a custom name for the strategy +func (ls *LocalStrategy) WithName(name string) *LocalStrategy { + ls.name = name + return ls +} + +// WithPasswordValidator sets the password validation function +// When set, the strategy will validate the password from AuthParams against +// the hashed password retrieved via GetHashedPassword +func (ls *LocalStrategy) WithPasswordValidator(validate ValidatePasswordFunc, getHashed GetHashedPasswordFunc) *LocalStrategy { + ls.validatePassword = validate + ls.getHashedPassword = getHashed + return ls +} + +// WithRateLimiter sets the rate limiting functions +func (ls *LocalStrategy) WithRateLimiter(check RateLimitCheckFunc, record RecordAttemptFunc) *LocalStrategy { + ls.checkRateLimit = check + ls.recordAttempt = record + return ls +} + +// WithUsernameNormalizer sets a custom username normalization function +func (ls *LocalStrategy) WithUsernameNormalizer(normalize NormalizeUsernameFunc) *LocalStrategy { + ls.normalizeUsername = normalize + return ls +} + +// Name returns the strategy name func (ls *LocalStrategy) Name() string { - return "local" + return ls.name } +// Authenticate authenticates a user with username/email and password func (ls *LocalStrategy) Authenticate(ctx context.Context, params AuthParams) (Authenticatable, error) { - user, err := ls.LookupUserWith(ctx, params) + // Normalize username + if ls.normalizeUsername != nil { + params.UsernameOrEmail = ls.normalizeUsername(params.UsernameOrEmail) + } + + // Check rate limit + if ls.checkRateLimit != nil { + if err := ls.checkRateLimit(ctx, params.UsernameOrEmail); err != nil { + return nil, err + } + } + + // Lookup user + user, err := ls.lookupUser(ctx, params) if err != nil { - // If the Lookup returns a known typed error, forward it. - var ( - credErr *CredentialError - tokErr *TokenError - cfgErr *ConfigError - nfErr *NotFoundError - intErr *InternalError - ) - switch { - case errors.As(err, &credErr): - return nil, credErr - case errors.As(err, &tokErr): - return nil, tokErr - case errors.As(err, &cfgErr): - return nil, cfgErr - case errors.As(err, &nfErr): - return nil, nfErr - case errors.As(err, &intErr): - return nil, intErr - default: - // Unknown error -> treat as internal failure - return nil, &InternalError{Msg: "lookup user failed", Err: err} + // Record failed attempt + if ls.recordAttempt != nil { + ls.recordAttempt(ctx, params.UsernameOrEmail, false) } + return nil, forwardTypedError(err) + } + + // Validate password if enabled + if ls.validatePassword != nil && ls.getHashedPassword != nil { + hashedPassword := ls.getHashedPassword(user) + if !ls.validatePassword(params.Password, hashedPassword) { + // Record failed attempt + if ls.recordAttempt != nil { + ls.recordAttempt(ctx, params.UsernameOrEmail, false) + } + return nil, ErrInvalidCredentials + } + } + + // Record successful attempt + if ls.recordAttempt != nil { + ls.recordAttempt(ctx, params.UsernameOrEmail, true) } return user, nil } + +// LookupUserWith is kept for backward compatibility +// Deprecated: Use NewLocalStrategy instead +func (ls *LocalStrategy) SetLookupUser(fn LookupUserFunc) { + ls.lookupUser = fn +} + +// forwardTypedError forwards known typed errors, wrapping unknown errors as InternalError +func forwardTypedError(err error) error { + if err == nil { + return nil + } + + var ( + credErr *CredentialError + tokErr *TokenError + cfgErr *ConfigError + nfErr *NotFoundError + intErr *InternalError + rlErr *RateLimitError + valErr *ValidationError + sessErr *SessionError + ) + + switch { + case errors.As(err, &credErr): + return credErr + case errors.As(err, &tokErr): + return tokErr + case errors.As(err, &cfgErr): + return cfgErr + case errors.As(err, &nfErr): + return nfErr + case errors.As(err, &intErr): + return intErr + case errors.As(err, &rlErr): + return rlErr + case errors.As(err, &valErr): + return valErr + case errors.As(err, &sessErr): + return sessErr + default: + // Unknown error -> treat as internal failure + return &InternalError{Msg: "lookup user failed", Err: err} + } +} diff --git a/session_aware_token_issuer.go b/session_aware_token_issuer.go new file mode 100644 index 0000000..2ab910d --- /dev/null +++ b/session_aware_token_issuer.go @@ -0,0 +1,467 @@ +package goauth + +import ( + "context" + "crypto/rsa" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +// Session storage callback function types + +// StoreSessionFunc stores a session and its refresh token +// oldToken is the previous refresh token being rotated (nil for initial login) +type StoreSessionFunc func(ctx context.Context, auth Authenticatable, session *SessionInfo, token *Token, oldToken *string) error + +// ValidateSessionFunc validates a refresh token and returns the user and session +type ValidateSessionFunc func(ctx context.Context, token string) (Authenticatable, *SessionInfo, error) + +// RevokeSessionFunc revokes a specific session +type RevokeSessionFunc func(ctx context.Context, auth Authenticatable, sessionID string) error + +// RevokeAllSessionsFunc revokes all sessions for a user +type RevokeAllSessionsFunc func(ctx context.Context, auth Authenticatable) error + +// ListSessionsFunc lists all active sessions for a user +type ListSessionsFunc func(ctx context.Context, auth Authenticatable) ([]*SessionInfo, error) + +// GetSessionFunc gets session info by refresh token +type GetSessionFunc func(ctx context.Context, token string) (*SessionInfo, error) + +// GenerateSessionIDFunc generates a unique session ID +type GenerateSessionIDFunc func(ctx context.Context) string + +// ExtractSessionMetadataFunc extracts session metadata from context (device, IP, etc.) +type ExtractSessionMetadataFunc func(ctx context.Context) map[string]any + +// SessionTokenIssuer implements TokenIssuer and SessionAwareTokenIssuer interfaces +// with full multi-session support and configurable signing methods +type SessionTokenIssuer struct { + keyProvider KeyProvider + issuer string + audience []string + accessTokenTTL time.Duration + refreshTokenTTL time.Duration + + // Session storage callbacks + storeSession StoreSessionFunc + validateSession ValidateSessionFunc + revokeSession RevokeSessionFunc + revokeAllSessions RevokeAllSessionsFunc + listSessions ListSessionsFunc + getSession GetSessionFunc + + // Optional customization callbacks + setExtraClaims SetExtraClaimsFunc + setRegisteredClaims SetRegisteredClaimsFunc + convertClaims ConvertAccessTokenClaimsFunc + generateSessionID GenerateSessionIDFunc + extractSessionMeta ExtractSessionMetadataFunc +} + +// SessionTokenIssuerBuilder provides a fluent API for building SessionTokenIssuer +type SessionTokenIssuerBuilder struct { + issuer *SessionTokenIssuer + errors []error +} + +// NewSessionAwareTokenIssuer creates a new builder for SessionTokenIssuer +func NewSessionAwareTokenIssuer() *SessionTokenIssuerBuilder { + return &SessionTokenIssuerBuilder{ + issuer: &SessionTokenIssuer{ + issuer: "goauth", + audience: []string{"goauth"}, + accessTokenTTL: 5 * time.Minute, + refreshTokenTTL: 7 * 24 * time.Hour, // 7 days default for sessions + generateSessionID: func(ctx context.Context) string { + return uuid.New().String() + }, + }, + } +} + +// WithHMACSecret configures HMAC signing with the given secret +func (b *SessionTokenIssuerBuilder) WithHMACSecret(secret []byte, method SigningMethod) *SessionTokenIssuerBuilder { + kp, err := NewHMACKeyProvider(secret, method) + if err != nil { + b.errors = append(b.errors, err) + return b + } + b.issuer.keyProvider = kp + return b +} + +// WithRSAKeys configures RSA signing with the given keys +func (b *SessionTokenIssuerBuilder) WithRSAKeys(privateKey *rsa.PrivateKey, publicKey *rsa.PublicKey, method SigningMethod) *SessionTokenIssuerBuilder { + kp, err := NewRSAKeyProvider(privateKey, publicKey, method) + if err != nil { + b.errors = append(b.errors, err) + return b + } + b.issuer.keyProvider = kp + return b +} + +// WithKeyProvider sets a custom key provider +func (b *SessionTokenIssuerBuilder) WithKeyProvider(kp KeyProvider) *SessionTokenIssuerBuilder { + b.issuer.keyProvider = kp + return b +} + +// WithIssuer sets the JWT issuer claim +func (b *SessionTokenIssuerBuilder) WithIssuer(issuer string) *SessionTokenIssuerBuilder { + b.issuer.issuer = issuer + return b +} + +// WithAudience sets the JWT audience claim +func (b *SessionTokenIssuerBuilder) WithAudience(audience []string) *SessionTokenIssuerBuilder { + b.issuer.audience = audience + return b +} + +// WithAccessTokenTTL sets the access token time-to-live +func (b *SessionTokenIssuerBuilder) WithAccessTokenTTL(ttl time.Duration) *SessionTokenIssuerBuilder { + b.issuer.accessTokenTTL = ttl + return b +} + +// WithRefreshTokenTTL sets the refresh token time-to-live +func (b *SessionTokenIssuerBuilder) WithRefreshTokenTTL(ttl time.Duration) *SessionTokenIssuerBuilder { + b.issuer.refreshTokenTTL = ttl + return b +} + +// WithSessionStore sets the session storage callbacks +func (b *SessionTokenIssuerBuilder) WithSessionStore( + store StoreSessionFunc, + validate ValidateSessionFunc, + revoke RevokeSessionFunc, + revokeAll RevokeAllSessionsFunc, +) *SessionTokenIssuerBuilder { + b.issuer.storeSession = store + b.issuer.validateSession = validate + b.issuer.revokeSession = revoke + b.issuer.revokeAllSessions = revokeAll + return b +} + +// WithListSessions sets the list sessions callback +func (b *SessionTokenIssuerBuilder) WithListSessions(fn ListSessionsFunc) *SessionTokenIssuerBuilder { + b.issuer.listSessions = fn + return b +} + +// WithGetSession sets the get session callback +func (b *SessionTokenIssuerBuilder) WithGetSession(fn GetSessionFunc) *SessionTokenIssuerBuilder { + b.issuer.getSession = fn + return b +} + +// WithExtraClaims sets the extra claims callback +func (b *SessionTokenIssuerBuilder) WithExtraClaims(fn SetExtraClaimsFunc) *SessionTokenIssuerBuilder { + b.issuer.setExtraClaims = fn + return b +} + +// WithRegisteredClaims sets the registered claims callback +func (b *SessionTokenIssuerBuilder) WithRegisteredClaims(fn SetRegisteredClaimsFunc) *SessionTokenIssuerBuilder { + b.issuer.setRegisteredClaims = fn + return b +} + +// WithClaimsConverter sets the claims to Authenticatable converter +func (b *SessionTokenIssuerBuilder) WithClaimsConverter(fn ConvertAccessTokenClaimsFunc) *SessionTokenIssuerBuilder { + b.issuer.convertClaims = fn + return b +} + +// WithSessionIDGenerator sets a custom session ID generator +func (b *SessionTokenIssuerBuilder) WithSessionIDGenerator(fn GenerateSessionIDFunc) *SessionTokenIssuerBuilder { + b.issuer.generateSessionID = fn + return b +} + +// WithSessionMetadataExtractor sets a custom session metadata extractor +func (b *SessionTokenIssuerBuilder) WithSessionMetadataExtractor(fn ExtractSessionMetadataFunc) *SessionTokenIssuerBuilder { + b.issuer.extractSessionMeta = fn + return b +} + +// Build creates the SessionTokenIssuer, returning any configuration errors +func (b *SessionTokenIssuerBuilder) Build() (*SessionTokenIssuer, error) { + if len(b.errors) > 0 { + return nil, b.errors[0] + } + + if b.issuer.keyProvider == nil { + return nil, ErrKeyProviderUnset + } + + if b.issuer.storeSession == nil || b.issuer.validateSession == nil { + return nil, ErrSessionStoreUnset + } + + return b.issuer, nil +} + +// CreateAccessToken creates a new JWT access token with session ID +func (ti *SessionTokenIssuer) CreateAccessToken(ctx context.Context, auth Authenticatable) (*Token, error) { + return ti.CreateAccessTokenWithSession(ctx, auth, "") +} + +// CreateAccessTokenWithSession creates a new JWT access token with a specific session ID +func (ti *SessionTokenIssuer) CreateAccessTokenWithSession(ctx context.Context, auth Authenticatable, sessionID string) (*Token, error) { + if ti.keyProvider == nil { + return nil, ErrKeyProviderUnset + } + + extraClaims := make(map[string]any) + if ti.setExtraClaims != nil { + extraClaims = ti.setExtraClaims(ctx, auth) + } + + now := time.Now() + var registeredClaims jwt.RegisteredClaims + if ti.setRegisteredClaims != nil { + registeredClaims = ti.setRegisteredClaims(ctx, auth) + } else { + registeredClaims = jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(ti.accessTokenTTL)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + Subject: auth.GetID(), + Issuer: ti.issuer, + Audience: ti.audience, + } + } + + claims := TokenClaims{ + RegisteredClaims: registeredClaims, + Username: auth.GetUsername(), + Email: auth.GetEmail(), + TokenType: AccessToken, + SessionID: sessionID, + ExtraClaims: extraClaims, + } + + token := jwt.NewWithClaims(ti.keyProvider.Method(), claims) + tokenString, err := token.SignedString(ti.keyProvider.SignKey()) + if err != nil { + return nil, &InternalError{Msg: "failed to sign access token", Err: err} + } + + return &Token{ + Value: tokenString, + Type: AccessToken, + ExpiresIn: ti.accessTokenTTL, + IssuedAt: now, + SessionID: sessionID, + }, nil +} + +// CreateRefreshToken creates a new refresh token with session +func (ti *SessionTokenIssuer) CreateRefreshToken(ctx context.Context, auth Authenticatable, oldToken *string) (*Token, error) { + if ti.storeSession == nil { + return nil, ErrSessionStoreUnset + } + + now := time.Now() + + // Generate session ID - reuse from old token if rotating, otherwise generate new + var sessionID string + if oldToken != nil && ti.getSession != nil { + oldSession, err := ti.getSession(ctx, *oldToken) + if err == nil && oldSession != nil { + sessionID = oldSession.ID + } + } + if sessionID == "" { + sessionID = ti.generateSessionID(ctx) + } + + // Extract session metadata from context + var metadata map[string]any + if ti.extractSessionMeta != nil { + metadata = ti.extractSessionMeta(ctx) + } + + session := &SessionInfo{ + ID: sessionID, + UserID: auth.GetID(), + CreatedAt: now, + ExpiresAt: now.Add(ti.refreshTokenTTL), + Metadata: metadata, + } + + tokenString := uuid.New().String() + token := &Token{ + Value: tokenString, + Type: RefreshToken, + ExpiresIn: ti.refreshTokenTTL, + IssuedAt: now, + SessionID: sessionID, + } + + err := ti.storeSession(ctx, auth, session, token, oldToken) + if err != nil { + return nil, &InternalError{Msg: "failed to store session", Err: err} + } + + return token, nil +} + +// DecodeAccessToken parses and validates a JWT access token +func (ti *SessionTokenIssuer) DecodeAccessToken(ctx context.Context, tokenStr string) (*TokenClaims, error) { + if ti.keyProvider == nil { + return nil, ErrKeyProviderUnset + } + + parsedToken, err := jwt.ParseWithClaims(tokenStr, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) { + // Verify the signing method matches + if token.Method.Alg() != ti.keyProvider.Method().Alg() { + return nil, &TokenError{Msg: "unexpected signing method"} + } + return ti.keyProvider.VerifyKey(), nil + }) + + if err != nil { + return nil, &TokenError{Msg: "failed to parse or validate access token", Err: err} + } + + claims, ok := parsedToken.Claims.(*TokenClaims) + if !ok { + return nil, &TokenError{Msg: "invalid token claims"} + } + + return claims, nil +} + +// ConvertAccessTokenClaims converts token claims to an Authenticatable entity +func (ti *SessionTokenIssuer) ConvertAccessTokenClaims(ctx context.Context, claims *TokenClaims) (Authenticatable, error) { + if ti.convertClaims != nil { + a, err := ti.convertClaims(ctx, claims) + if err != nil { + return nil, &TokenError{Msg: "failed to convert access token claims", Err: err} + } + return a, nil + } + + return &User{ + ID: claims.Subject, + Username: claims.Username, + Email: claims.Email, + Extra: claims.ExtraClaims, + }, nil +} + +// ValidateRefreshToken validates a refresh token and returns the associated user +func (ti *SessionTokenIssuer) ValidateRefreshToken(ctx context.Context, token string) (Authenticatable, error) { + if ti.validateSession == nil { + return nil, ErrSessionStoreUnset + } + + auth, _, err := ti.validateSession(ctx, token) + if err != nil { + return nil, &TokenError{Msg: "invalid or rejected refresh token", Err: err} + } + + return auth, nil +} + +// RevokeRefreshToken revokes a refresh token by revoking its session +func (ti *SessionTokenIssuer) RevokeRefreshToken(ctx context.Context, token string) error { + if ti.getSession == nil || ti.revokeSession == nil { + return ErrSessionStoreUnset + } + + session, err := ti.getSession(ctx, token) + if err != nil { + return &TokenError{Msg: "failed to get session for token", Err: err} + } + + // We need to get the user to revoke the session + auth, _, err := ti.validateSession(ctx, token) + if err != nil { + return &TokenError{Msg: "failed to validate token for revocation", Err: err} + } + + return ti.revokeSession(ctx, auth, session.ID) +} + +// GetSession returns session information for a refresh token +func (ti *SessionTokenIssuer) GetSession(ctx context.Context, token string) (*SessionInfo, error) { + if ti.getSession == nil { + return nil, ErrSessionStoreUnset + } + + session, err := ti.getSession(ctx, token) + if err != nil { + return nil, &TokenError{Msg: "failed to get session", Err: err} + } + + return session, nil +} + +// RevokeSession revokes a specific session by ID +func (ti *SessionTokenIssuer) RevokeSession(ctx context.Context, auth Authenticatable, sessionID string) error { + if ti.revokeSession == nil { + return ErrSessionStoreUnset + } + + err := ti.revokeSession(ctx, auth, sessionID) + if err != nil { + return &SessionError{Msg: "failed to revoke session", SessionID: sessionID, Err: err} + } + + return nil +} + +// RevokeAllSessions revokes all sessions for an authenticated entity +func (ti *SessionTokenIssuer) RevokeAllSessions(ctx context.Context, auth Authenticatable) error { + if ti.revokeAllSessions == nil { + return ErrSessionStoreUnset + } + + err := ti.revokeAllSessions(ctx, auth) + if err != nil { + return &InternalError{Msg: "failed to revoke all sessions", Err: err} + } + + return nil +} + +// ListSessions returns all active sessions for an authenticated entity +func (ti *SessionTokenIssuer) ListSessions(ctx context.Context, auth Authenticatable) ([]*SessionInfo, error) { + if ti.listSessions == nil { + return nil, ErrSessionStoreUnset + } + + sessions, err := ti.listSessions(ctx, auth) + if err != nil { + return nil, &InternalError{Msg: "failed to list sessions", Err: err} + } + + return sessions, nil +} + +// IssueTokenPair creates both access and refresh tokens in one call +// This is a convenience method that ensures tokens share the same session ID +func (ti *SessionTokenIssuer) IssueTokenPair(ctx context.Context, auth Authenticatable, oldRefreshToken *string) (*TokenPair, error) { + refreshToken, err := ti.CreateRefreshToken(ctx, auth, oldRefreshToken) + if err != nil { + return nil, err + } + + accessToken, err := ti.CreateAccessTokenWithSession(ctx, auth, refreshToken.SessionID) + if err != nil { + return nil, err + } + + return &TokenPair{ + Access: accessToken, + Refresh: refreshToken, + }, nil +} diff --git a/signing.go b/signing.go new file mode 100644 index 0000000..4146721 --- /dev/null +++ b/signing.go @@ -0,0 +1,204 @@ +package goauth + +import ( + "crypto/ecdsa" + "crypto/rsa" + "errors" + "fmt" + + "github.com/golang-jwt/jwt/v5" +) + +// SigningMethod represents supported JWT signing algorithms +type SigningMethod string + +const ( + // HMAC signing methods (symmetric) + HS256 SigningMethod = "HS256" + HS384 SigningMethod = "HS384" + HS512 SigningMethod = "HS512" + + // RSA signing methods (asymmetric) + RS256 SigningMethod = "RS256" + RS384 SigningMethod = "RS384" + RS512 SigningMethod = "RS512" + + // ECDSA signing methods (asymmetric) + ES256 SigningMethod = "ES256" + ES384 SigningMethod = "ES384" + ES512 SigningMethod = "ES512" +) + +// KeyProvider abstracts the signing key management for JWT tokens +type KeyProvider interface { + // Method returns the JWT signing method + Method() jwt.SigningMethod + // SignKey returns the key used for signing tokens + SignKey() any + // VerifyKey returns the key used for verifying tokens + VerifyKey() any + // Algorithm returns the signing method name + Algorithm() SigningMethod +} + +// HMACKeyProvider implements KeyProvider for HMAC-based signing (HS256, HS384, HS512) +type HMACKeyProvider struct { + secret []byte + method SigningMethod +} + +// NewHMACKeyProvider creates a new HMAC key provider +func NewHMACKeyProvider(secret []byte, method SigningMethod) (*HMACKeyProvider, error) { + if len(secret) == 0 { + return nil, errors.New("secret cannot be empty") + } + + switch method { + case HS256, HS384, HS512: + // valid + default: + return nil, fmt.Errorf("invalid HMAC signing method: %s", method) + } + + return &HMACKeyProvider{ + secret: secret, + method: method, + }, nil +} + +func (p *HMACKeyProvider) Method() jwt.SigningMethod { + switch p.method { + case HS384: + return jwt.SigningMethodHS384 + case HS512: + return jwt.SigningMethodHS512 + default: + return jwt.SigningMethodHS256 + } +} + +func (p *HMACKeyProvider) SignKey() any { + return p.secret +} + +func (p *HMACKeyProvider) VerifyKey() any { + return p.secret +} + +func (p *HMACKeyProvider) Algorithm() SigningMethod { + return p.method +} + +// RSAKeyProvider implements KeyProvider for RSA-based signing (RS256, RS384, RS512) +type RSAKeyProvider struct { + privateKey *rsa.PrivateKey + publicKey *rsa.PublicKey + method SigningMethod +} + +// NewRSAKeyProvider creates a new RSA key provider +// privateKey is required for signing, publicKey is required for verification +// If only verification is needed, privateKey can be nil +func NewRSAKeyProvider(privateKey *rsa.PrivateKey, publicKey *rsa.PublicKey, method SigningMethod) (*RSAKeyProvider, error) { + if privateKey == nil && publicKey == nil { + return nil, errors.New("at least one key (private or public) must be provided") + } + + switch method { + case RS256, RS384, RS512: + // valid + default: + return nil, fmt.Errorf("invalid RSA signing method: %s", method) + } + + // If private key is provided but public key is not, derive public from private + if privateKey != nil && publicKey == nil { + publicKey = &privateKey.PublicKey + } + + return &RSAKeyProvider{ + privateKey: privateKey, + publicKey: publicKey, + method: method, + }, nil +} + +func (p *RSAKeyProvider) Method() jwt.SigningMethod { + switch p.method { + case RS384: + return jwt.SigningMethodRS384 + case RS512: + return jwt.SigningMethodRS512 + default: + return jwt.SigningMethodRS256 + } +} + +func (p *RSAKeyProvider) SignKey() any { + return p.privateKey +} + +func (p *RSAKeyProvider) VerifyKey() any { + return p.publicKey +} + +func (p *RSAKeyProvider) Algorithm() SigningMethod { + return p.method +} + +// ECDSAKeyProvider implements KeyProvider for ECDSA-based signing (ES256, ES384, ES512) +type ECDSAKeyProvider struct { + privateKey *ecdsa.PrivateKey + publicKey *ecdsa.PublicKey + method SigningMethod +} + +// NewECDSAKeyProvider creates a new ECDSA key provider +// privateKey is required for signing, publicKey is required for verification +// If only verification is needed, privateKey can be nil +func NewECDSAKeyProvider(privateKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey, method SigningMethod) (*ECDSAKeyProvider, error) { + if privateKey == nil && publicKey == nil { + return nil, errors.New("at least one key (private or public) must be provided") + } + + switch method { + case ES256, ES384, ES512: + // valid + default: + return nil, fmt.Errorf("invalid ECDSA signing method: %s", method) + } + + // If private key is provided but public key is not, derive public from private + if privateKey != nil && publicKey == nil { + publicKey = &privateKey.PublicKey + } + + return &ECDSAKeyProvider{ + privateKey: privateKey, + publicKey: publicKey, + method: method, + }, nil +} + +func (p *ECDSAKeyProvider) Method() jwt.SigningMethod { + switch p.method { + case ES384: + return jwt.SigningMethodES384 + case ES512: + return jwt.SigningMethodES512 + default: + return jwt.SigningMethodES256 + } +} + +func (p *ECDSAKeyProvider) SignKey() any { + return p.privateKey +} + +func (p *ECDSAKeyProvider) VerifyKey() any { + return p.publicKey +} + +func (p *ECDSAKeyProvider) Algorithm() SigningMethod { + return p.method +} diff --git a/token.go b/token.go index b3bf294..93a461f 100644 --- a/token.go +++ b/token.go @@ -1,186 +1,54 @@ package goauth import ( - "context" - "time" - "github.com/golang-jwt/jwt/v5" - "github.com/google/uuid" ) +// TokenClaims represents the JWT claims for access tokens type TokenClaims struct { jwt.RegisteredClaims Username string `json:"username,omitempty"` Email string `json:"email,omitempty"` + TokenType TokenType `json:"typ,omitempty"` // "access" or "refresh" + SessionID string `json:"sid,omitempty"` // Session identifier for multi-session support ExtraClaims map[string]any `json:"ext,omitempty"` } -type DefaultTokenIssuer struct { - secret string - issuer string - audience []string - accessTokenExpiresIn time.Duration - refreshTokenExpiresIn time.Duration - storeRefreshTokenWith StoreRefreshTokenFunc - setExtraClaimsWith SetExtraClaimsFunc - setRegisteredClaimsWith SetRegisteredClaimsFunc - convertAccessTokenClaimsWith ConvertAccessTokenClaimsFunc - validateRefreshTokenWith ValidateRefreshTokenFunc -} - -func NewDefaultTokenIssuer(secret string) *DefaultTokenIssuer { - ti := &DefaultTokenIssuer{ - secret: secret, - issuer: "goauth", - audience: []string{"goauth"}, - accessTokenExpiresIn: 300 * time.Second, // default 5 minutes - refreshTokenExpiresIn: 3600 * time.Second, // default 1 hour +// GetExtraClaim returns a value from extra claims +func (tc *TokenClaims) GetExtraClaim(key string) (any, bool) { + if tc.ExtraClaims == nil { + return nil, false } - - return ti + v, ok := tc.ExtraClaims[key] + return v, ok } -func (ti *DefaultTokenIssuer) SetSecret(secret string) { - ti.secret = secret -} - -func (ti *DefaultTokenIssuer) SetIssuer(issuer string) { - ti.issuer = issuer -} - -func (ti *DefaultTokenIssuer) SetAudience(audience []string) { - ti.audience = audience -} - -func (ti *DefaultTokenIssuer) SetAccessTokenExpiresIn(expiresIn time.Duration) { - ti.accessTokenExpiresIn = expiresIn -} - -func (ti *DefaultTokenIssuer) SetRefreshTokenExpiresIn(expiresIn time.Duration) { - ti.refreshTokenExpiresIn = expiresIn -} - -func (ti *DefaultTokenIssuer) StoreRefreshTokenWith(storeRefreshTokenWith StoreRefreshTokenFunc) { - ti.storeRefreshTokenWith = storeRefreshTokenWith -} - -func (ti *DefaultTokenIssuer) SetExtraClaimsWith(setExtraClaimsWith SetExtraClaimsFunc) { - ti.setExtraClaimsWith = setExtraClaimsWith -} - -func (ti *DefaultTokenIssuer) SetRegisteredClaimsWith(setRegisteredClaimsWith SetRegisteredClaimsFunc) { - ti.setRegisteredClaimsWith = setRegisteredClaimsWith -} - -func (ti *DefaultTokenIssuer) ConvertAccessTokenClaimsWith(convertAccessTokenClaimsWith ConvertAccessTokenClaimsFunc) { - ti.convertAccessTokenClaimsWith = convertAccessTokenClaimsWith -} - -func (ti *DefaultTokenIssuer) ValidateRefreshTokenWith(validateRefreshTokenWith ValidateRefreshTokenFunc) { - ti.validateRefreshTokenWith = validateRefreshTokenWith -} - -func (ti *DefaultTokenIssuer) CreateAccessToken(ctx context.Context, authenticatable Authenticatable) (*Token, error) { - extraClaims := make(map[string]any) - if ti.setExtraClaimsWith != nil { - extraClaims = ti.setExtraClaimsWith(ctx, authenticatable) - } - - var registeredClaims jwt.RegisteredClaims - if ti.setRegisteredClaimsWith != nil { - registeredClaims = ti.setRegisteredClaimsWith(ctx, authenticatable) - } else { - registeredClaims = jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(ti.accessTokenExpiresIn)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - Subject: authenticatable.GetID(), - Issuer: ti.issuer, - Audience: ti.audience, - } - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, TokenClaims{ - RegisteredClaims: registeredClaims, - Username: authenticatable.GetUsername(), - Email: authenticatable.GetEmail(), - ExtraClaims: extraClaims, - }) - - tokenString, err := token.SignedString([]byte(ti.secret)) - if err != nil { - return nil, err - } - - return &Token{ - Value: tokenString, - ExpiresIn: ti.accessTokenExpiresIn, - }, nil -} - -func (ti *DefaultTokenIssuer) CreateRefreshToken(ctx context.Context, authenticatable Authenticatable, refreshing bool) (*Token, error) { - if ti.storeRefreshTokenWith == nil { - return nil, &ConfigError{Msg: "StoreRefreshTokenWith is not set"} - } - - tokenString := uuid.New().String() - token := &Token{ - Value: tokenString, - ExpiresIn: ti.refreshTokenExpiresIn, - } - - err := ti.storeRefreshTokenWith(ctx, authenticatable, token, refreshing) - if err != nil { - return nil, &InternalError{Msg: "failed to store refresh token", Err: err} +// GetExtraClaimString returns a string value from extra claims +func (tc *TokenClaims) GetExtraClaimString(key string) (string, bool) { + v, ok := tc.GetExtraClaim(key) + if !ok { + return "", false } - - return token, nil + s, ok := v.(string) + return s, ok } -func (ti *DefaultTokenIssuer) DecodeAccessToken(ctx context.Context, token string) (*TokenClaims, error) { - jwt, err := jwt.ParseWithClaims(token, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) { - return []byte(ti.secret), nil - }) - - if err != nil { - // jwt lib returns various errors (validation/signature/expired). Classify as token error. - return nil, &TokenError{Msg: "failed to parse or validate access token", Err: err} - } - - claims, ok := jwt.Claims.(*TokenClaims) +// GetExtraClaimBool returns a bool value from extra claims +func (tc *TokenClaims) GetExtraClaimBool(key string) (bool, bool) { + v, ok := tc.GetExtraClaim(key) if !ok { - return nil, &TokenError{Msg: "invalid token claims"} + return false, false } - - return claims, nil + b, ok := v.(bool) + return b, ok } -func (ti *DefaultTokenIssuer) ConvertAccessTokenClaims(ctx context.Context, claims *TokenClaims) (Authenticatable, error) { - if ti.convertAccessTokenClaimsWith != nil { - a, err := ti.convertAccessTokenClaimsWith(ctx, claims) - if err != nil { - return nil, &TokenError{Msg: "failed to convert access token claims", Err: err} - } - return a, nil - } - - return &User{ - ID: claims.Subject, - Username: claims.Username, - Email: claims.Email, - Extra: claims.ExtraClaims, - }, nil +// IsAccessToken returns true if this is an access token +func (tc *TokenClaims) IsAccessToken() bool { + return tc.TokenType == AccessToken || tc.TokenType == "" } -func (ti *DefaultTokenIssuer) ValidateRefreshToken(ctx context.Context, token string) (Authenticatable, error) { - if ti.validateRefreshTokenWith == nil { - return nil, &ConfigError{Msg: "ValidateRefreshTokenWith is not set"} - } - - authenticatable, err := ti.validateRefreshTokenWith(ctx, token) - if err != nil { - return nil, &TokenError{Msg: "invalid or rejected refresh token", Err: err} - } - - return authenticatable, nil +// IsRefreshToken returns true if this is a refresh token +func (tc *TokenClaims) IsRefreshToken() bool { + return tc.TokenType == RefreshToken }