diff --git a/adapters/gorm/adapter.go b/adapters/gorm/adapter.go index 9e94b63..7a4fde7 100644 --- a/adapters/gorm/adapter.go +++ b/adapters/gorm/adapter.go @@ -105,6 +105,13 @@ func (a *Adapter) FindMany(ctx context.Context, tableName limen.SchemaTableName, } func (a *Adapter) Update(ctx context.Context, tableName limen.SchemaTableName, conditions []limen.Where, updates map[string]any) error { + if len(updates) == 0 { + return nil + } + if len(conditions) == 0 { + return fmt.Errorf("%w: conditions required to prevent accidental table-wide update", limen.ErrMissingConditions) + } + db := a.getDB() query := db.WithContext(ctx).Table(string(tableName)) query = a.applyConditions(query, conditions) @@ -112,6 +119,10 @@ func (a *Adapter) Update(ctx context.Context, tableName limen.SchemaTableName, c } func (a *Adapter) Delete(ctx context.Context, tableName limen.SchemaTableName, conditions []limen.Where) error { + if len(conditions) == 0 { + return fmt.Errorf("%w: conditions required to prevent accidental table-wide delete", limen.ErrMissingConditions) + } + db := a.getDB() query := db.WithContext(ctx).Table(string(tableName)) query = a.applyConditions(query, conditions) diff --git a/adapters/gorm/adapter_test.go b/adapters/gorm/adapter_test.go index 9cc3ce4..a20a4e4 100644 --- a/adapters/gorm/adapter_test.go +++ b/adapters/gorm/adapter_test.go @@ -93,6 +93,24 @@ func TestGorm_Update(t *testing.T) { assert.Equal(t, "new@test.com", result["email"]) } +func TestGorm_Update_RequiresConditions(t *testing.T) { + adapter := setupTestGormDB(t) + ctx := context.Background() + + err := adapter.Update(ctx, "test_items", nil, map[string]any{"email": "new@test.com"}) + + assert.ErrorIs(t, err, limen.ErrMissingConditions) +} + +func TestGorm_Update_EmptyUpdatesNoop(t *testing.T) { + adapter := setupTestGormDB(t) + ctx := context.Background() + + err := adapter.Update(ctx, "test_items", nil, map[string]any{}) + + assert.NoError(t, err) +} + func TestGorm_Delete(t *testing.T) { adapter := setupTestGormDB(t) ctx := context.Background() @@ -108,6 +126,15 @@ func TestGorm_Delete(t *testing.T) { assert.ErrorIs(t, err, limen.ErrRecordNotFound) } +func TestGorm_Delete_RequiresConditions(t *testing.T) { + adapter := setupTestGormDB(t) + ctx := context.Background() + + err := adapter.Delete(ctx, "test_items", nil) + + assert.ErrorIs(t, err, limen.ErrMissingConditions) +} + func TestGorm_Exists(t *testing.T) { adapter := setupTestGormDB(t) ctx := context.Background() diff --git a/cmd/limen/migration.go b/cmd/limen/migration.go index 2fd8791..10f2d22 100644 --- a/cmd/limen/migration.go +++ b/cmd/limen/migration.go @@ -3,6 +3,7 @@ package main import ( "database/sql" "fmt" + "sort" "time" "github.com/thecodearcher/limen" @@ -16,7 +17,7 @@ type Migration struct { func generateMigrations(db *sql.DB, driver Driver, config *cliConfig) ([]Migration, error) { migrations := make([]Migration, 0, len(config.Schemas)) - timestamp := time.Now().Format("20060102150405") + timestamp := time.Now() introspector := newSchemaIntrospector(db, driver) tableNames := make([]string, 0, len(config.Schemas)) @@ -34,7 +35,13 @@ func generateMigrations(db *sql.DB, driver Driver, config *cliConfig) ([]Migrati return nil, fmt.Errorf("failed to create migration generator: %w", err) } - for schemaName, schemaDef := range config.Schemas { + schemaNames, err := orderSchemasByDependencies(config.Schemas) + if err != nil { + return nil, err + } + + for _, schemaName := range schemaNames { + schemaDef := config.Schemas[schemaName] var diff *schemaDiff if existingTables[string(schemaDef.GetTableName())] { @@ -59,7 +66,7 @@ func generateMigrations(db *sql.DB, driver Driver, config *cliConfig) ([]Migrati } migration := Migration{ - Version: fmt.Sprintf("%s_%s", timestamp, schemaName), + Version: migrationVersion(timestamp, len(migrations), schemaName), UpSQL: upSQL, DownSQL: downSQL, } @@ -70,6 +77,100 @@ func generateMigrations(db *sql.DB, driver Driver, config *cliConfig) ([]Migrati return migrations, nil } +func migrationVersion(base time.Time, sequence int, schemaName limen.SchemaName) string { + return fmt.Sprintf("%s_%s", base.Add(time.Duration(sequence)*time.Second).Format("20060102150405"), schemaName) +} + +func orderSchemasByDependencies(schemas limen.SchemaDefinitionMap) ([]limen.SchemaName, error) { + names := sortedSchemaNames(schemas) + tableToSchema := make(map[limen.SchemaName]limen.SchemaName, len(schemas)) + for _, name := range names { + schema := schemas[name] + tableToSchema[limen.SchemaName(schema.GetTableName())] = name + } + + indegree := make(map[limen.SchemaName]int, len(schemas)) + dependents := make(map[limen.SchemaName][]limen.SchemaName, len(schemas)) + seenEdges := make(map[limen.SchemaName]map[limen.SchemaName]bool, len(schemas)) + for _, name := range names { + indegree[name] = 0 + } + + for _, name := range names { + for _, fk := range schemas[name].ForeignKeys { + dependency, ok := resolveSchemaDependency(fk.ReferencedSchema, schemas, tableToSchema) + if !ok || dependency == name { + continue + } + if seenEdges[name] == nil { + seenEdges[name] = make(map[limen.SchemaName]bool) + } + if seenEdges[name][dependency] { + continue + } + + seenEdges[name][dependency] = true + indegree[name]++ + dependents[dependency] = append(dependents[dependency], name) + } + } + + ready := make([]limen.SchemaName, 0, len(schemas)) + for _, name := range names { + if indegree[name] == 0 { + ready = append(ready, name) + } + } + + ordered := make([]limen.SchemaName, 0, len(schemas)) + for len(ready) > 0 { + sortSchemaNames(ready) + name := ready[0] + ready = ready[1:] + ordered = append(ordered, name) + + for _, dependent := range dependents[name] { + indegree[dependent]-- + if indegree[dependent] == 0 { + ready = append(ready, dependent) + } + } + } + + if len(ordered) != len(schemas) { + return nil, fmt.Errorf("foreign key dependency cycle detected among schemas") + } + + return ordered, nil +} + +func resolveSchemaDependency( + referenced limen.SchemaName, + schemas limen.SchemaDefinitionMap, + tableToSchema map[limen.SchemaName]limen.SchemaName, +) (limen.SchemaName, bool) { + if _, ok := schemas[referenced]; ok { + return referenced, true + } + name, ok := tableToSchema[referenced] + return name, ok +} + +func sortedSchemaNames(schemas limen.SchemaDefinitionMap) []limen.SchemaName { + names := make([]limen.SchemaName, 0, len(schemas)) + for name := range schemas { + names = append(names, name) + } + sortSchemaNames(names) + return names +} + +func sortSchemaNames(names []limen.SchemaName) { + sort.Slice(names, func(i, j int) bool { + return string(names[i]) < string(names[j]) + }) +} + func generateDiffForTable(introspector *schemaIntrospector, schema *limen.SchemaDefinition) (*schemaDiff, error) { existingSchema, err := introspector.introspectTable(schema.GetTableName()) if err != nil { diff --git a/cmd/limen/migration_test.go b/cmd/limen/migration_test.go new file mode 100644 index 0000000..ae69b5d --- /dev/null +++ b/cmd/limen/migration_test.go @@ -0,0 +1,130 @@ +package main + +import ( + "strings" + "testing" + "time" + + "github.com/thecodearcher/limen" +) + +func TestMigrationVersionIncrementsTimestampPerMigration(t *testing.T) { + base := time.Date(2026, 6, 7, 3, 44, 8, 0, time.UTC) + + got := []string{ + migrationVersion(base, 0, "users"), + migrationVersion(base, 1, "accounts"), + migrationVersion(base, 2, "sessions"), + } + + want := []string{ + "20260607034408_users", + "20260607034409_accounts", + "20260607034410_sessions", + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("version %d = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestOrderSchemasByDependenciesPutsReferencedTablesFirst(t *testing.T) { + schemas := limen.SchemaDefinitionMap{ + "accounts": schemaDef("accounts", fk("app_users")), + "sessions": schemaDef("sessions", fk("app_users")), + "users": schemaDef("app_users"), + } + + ordered, err := orderSchemasByDependencies(schemas) + if err != nil { + t.Fatalf("orderSchemasByDependencies returned error: %v", err) + } + + assertBefore(t, ordered, "users", "accounts") + assertBefore(t, ordered, "users", "sessions") +} + +func TestOrderSchemasByDependenciesSortsIndependentSchemas(t *testing.T) { + schemas := limen.SchemaDefinitionMap{ + "users": schemaDef("users"), + "rate_limits": schemaDef("rate_limits"), + "verifications": schemaDef("verifications"), + } + + ordered, err := orderSchemasByDependencies(schemas) + if err != nil { + t.Fatalf("orderSchemasByDependencies returned error: %v", err) + } + + want := []limen.SchemaName{"rate_limits", "users", "verifications"} + for i := range want { + if ordered[i] != want[i] { + t.Fatalf("ordered[%d] = %q, want %q; full order: %v", i, ordered[i], want[i], ordered) + } + } +} + +func TestOrderSchemasByDependenciesDetectsCycles(t *testing.T) { + schemas := limen.SchemaDefinitionMap{ + "a": schemaDef("a", fk("b")), + "b": schemaDef("b", fk("a")), + } + + _, err := orderSchemasByDependencies(schemas) + if err == nil { + t.Fatal("expected cycle error") + } + if !strings.Contains(err.Error(), "dependency cycle") { + t.Fatalf("expected dependency cycle error, got %v", err) + } +} + +func schemaDef(tableName string, foreignKeys ...limen.ForeignKeyDefinition) limen.SchemaDefinition { + return limen.SchemaDefinition{ + TableName: limen.SchemaTableName(tableName), + Columns: []limen.ColumnDefinition{ + { + Name: "id", + LogicalField: limen.SchemaIDField, + Type: limen.ColumnTypeInt64, + IsPrimaryKey: true, + }, + }, + ForeignKeys: foreignKeys, + } +} + +func fk(referencedSchema limen.SchemaName) limen.ForeignKeyDefinition { + return limen.ForeignKeyDefinition{ + Name: "fk_test", + Column: "user_id", + ReferencedSchema: referencedSchema, + ReferencedField: limen.SchemaIDField, + } +} + +func assertBefore(t *testing.T, ordered []limen.SchemaName, before, after limen.SchemaName) { + t.Helper() + + beforeIndex := schemaIndex(ordered, before) + afterIndex := schemaIndex(ordered, after) + if beforeIndex == -1 { + t.Fatalf("%q not found in ordered schemas %v", before, ordered) + } + if afterIndex == -1 { + t.Fatalf("%q not found in ordered schemas %v", after, ordered) + } + if beforeIndex > afterIndex { + t.Fatalf("expected %q before %q, got %v", before, after, ordered) + } +} + +func schemaIndex(ordered []limen.SchemaName, target limen.SchemaName) int { + for i, name := range ordered { + if name == target { + return i + } + } + return -1 +} diff --git a/middlewares.go b/middlewares.go index c037b76..391a62a 100644 --- a/middlewares.go +++ b/middlewares.go @@ -55,16 +55,16 @@ func (httpCore *LimenHTTPCore) middlewareCheckOrigin() Middleware { return } - if len(httpCore.trustedOriginsPatterns) == 0 { - next.ServeHTTP(w, r) - return - } - origin := r.Header.Get("Origin") if origin == "" { origin = r.Header.Get("Referer") } + if origin == "" { + next.ServeHTTP(w, r) + return + } + if httpCore.IsTrustedOrigin(origin) { next.ServeHTTP(w, r) return diff --git a/middlewares_test.go b/middlewares_test.go index 7fd08b3..8e982b0 100644 --- a/middlewares_test.go +++ b/middlewares_test.go @@ -119,6 +119,59 @@ func TestMiddlewareCheckOrigin(t *testing.T) { } } +func TestMiddlewareCheckOrigin_DefaultTrustsOnlyBaseURL(t *testing.T) { + t.Parallel() + + l := newTestLimen(t) + httpCore := newTestHTTPCore(t, l) + baseURL := l.core.GetBaseURL() + + tests := []struct { + name string + headers map[string]string + wantStatus int + }{ + { + name: "base URL origin allowed", + headers: map[string]string{"Origin": baseURL, "Content-Type": "application/json"}, + wantStatus: http.StatusOK, + }, + { + name: "missing origin allowed", + headers: map[string]string{"Content-Type": "application/json"}, + wantStatus: http.StatusOK, + }, + { + name: "untrusted origin blocked", + headers: map[string]string{"Origin": "http://evil.com", "Content-Type": "application/json"}, + wantStatus: http.StatusForbidden, + }, + { + name: "untrusted referer blocked", + headers: map[string]string{"Referer": "http://evil.com/path", "Content-Type": "application/json"}, + wantStatus: http.StatusForbidden, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + handler := httpCore.middlewareCheckOrigin()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/auth/signin", http.NoBody) + for k, v := range tt.headers { + req.Header.Set(k, v) + } + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, tt.wantStatus, w.Code) + }) + } +} + func TestMiddlewareCSRFProtection(t *testing.T) { t.Parallel() diff --git a/plugins/oauth-apple/go.mod b/plugins/oauth-apple/go.mod index 8df61b6..715d8af 100644 --- a/plugins/oauth-apple/go.mod +++ b/plugins/oauth-apple/go.mod @@ -4,10 +4,12 @@ go 1.25.0 require ( github.com/thecodearcher/limen/plugins/oauth v0.1.0 - golang.org/x/oauth2 v0.35.0 + golang.org/x/oauth2 v0.36.0 ) require ( + github.com/coreos/go-oidc/v3 v3.18.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.4 // indirect github.com/thecodearcher/limen v0.1.1 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/sys v0.41.0 // indirect diff --git a/plugins/oauth-apple/go.sum b/plugins/oauth-apple/go.sum index 533dc93..f1c3de1 100644 --- a/plugins/oauth-apple/go.sum +++ b/plugins/oauth-apple/go.sum @@ -1,13 +1,17 @@ +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= -golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/oauth-apple/options.go b/plugins/oauth-apple/options.go index 84c5fec..6ca9840 100644 --- a/plugins/oauth-apple/options.go +++ b/plugins/oauth-apple/options.go @@ -1,14 +1,17 @@ package oauthapple +import "github.com/thecodearcher/limen/plugins/oauth" + // ConfigOption configures the Apple OAuth plugin. type ConfigOption func(*config) type config struct { - clientID string - clientSecret string - redirectURL string - scopes []string - options map[string]string + clientID string + clientSecret string + redirectURL string + scopes []string + options map[string]string + verifyIDToken oauth.IDTokenVerifier } // WithClientID sets the Apple Services ID (the identifier for your app). @@ -50,3 +53,9 @@ func WithOption(key, value string) ConfigOption { c.options[key] = value } } + +func WithIDTokenVerifier(verifier oauth.IDTokenVerifier) ConfigOption { + return func(c *config) { + c.verifyIDToken = verifier + } +} diff --git a/plugins/oauth-apple/provider.go b/plugins/oauth-apple/provider.go index dfa1dc3..d4c63e8 100644 --- a/plugins/oauth-apple/provider.go +++ b/plugins/oauth-apple/provider.go @@ -15,8 +15,9 @@ import ( ) var appleEndpoint = oauth2.Endpoint{ - AuthURL: "https://appleid.apple.com/auth/authorize", - TokenURL: "https://appleid.apple.com/auth/token", + AuthURL: "https://appleid.apple.com/auth/authorize", + TokenURL: "https://appleid.apple.com/auth/token", + AuthStyle: oauth2.AuthStyleInParams, } // New creates an Apple OAuth provider that implements oauth.Provider. @@ -38,6 +39,9 @@ type appleProvider struct { } func newAppleProvider(cfg *config) *appleProvider { + if cfg.verifyIDToken == nil { + cfg.verifyIDToken = oauth.NewIDTokenVerifier("https://appleid.apple.com", cfg.clientID) + } oauthCfg := &oauth2.Config{ ClientID: cfg.clientID, ClientSecret: cfg.clientSecret, @@ -60,6 +64,10 @@ func (a *appleProvider) OAuth2Config() (*oauth2.Config, []oauth2.AuthCodeOption) return a.oauthConfig, authOpts } +func (a *appleProvider) IDTokenNonceEnabled() bool { + return true +} + // ResponseMode returns form_post because Apple delivers the authorization // response (including the first-login user payload) as a POST body. func (a *appleProvider) ResponseMode() oauth.ResponseMode { @@ -70,10 +78,13 @@ func (a *appleProvider) GetUserInfo(ctx context.Context, token *oauth.TokenRespo if token.IDToken == "" { return nil, errors.New("apple: id_token required; include email scope") } - claims, err := oauth.DecodeIDTokenClaims(token.IDToken) + claims, err := a.config.verifyIDToken(ctx, token.IDToken) if err != nil { return nil, fmt.Errorf("apple: %w", err) } + if err := oauth.VerifyIDTokenNonce(claims, oauth.IDTokenNonce(ctx)); err != nil { + return nil, fmt.Errorf("apple: %w", err) + } sub, _ := claims["sub"].(string) if sub == "" { diff --git a/plugins/oauth-apple/provider_test.go b/plugins/oauth-apple/provider_test.go new file mode 100644 index 0000000..9bb4a10 --- /dev/null +++ b/plugins/oauth-apple/provider_test.go @@ -0,0 +1,96 @@ +package oauthapple + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "net/url" + "testing" + + "golang.org/x/oauth2" + + "github.com/thecodearcher/limen/plugins/oauth" +) + +func TestGetUserInfo_UsesIDTokenVerifier(t *testing.T) { + t.Parallel() + + called := false + provider := New( + WithClientID("client-id"), + WithIDTokenVerifier(func(_ context.Context, idToken string) (map[string]any, error) { + called = true + if idToken != "id-token" { + t.Fatalf("unexpected id token: %s", idToken) + } + return map[string]any{ + "sub": "apple-user-1", + "email": "user@example.com", + "email_verified": "true", + "nonce": "nonce-value", + }, nil + }), + ) + + ctx := oauth.ContextWithIDTokenNonce(context.Background(), "nonce-value") + info, err := provider.GetUserInfo(ctx, &oauth.TokenResponse{IDToken: "id-token"}) + if err != nil { + t.Fatalf("GetUserInfo: %v", err) + } + if !called { + t.Fatal("expected verifier to be called") + } + if info.ID != "apple-user-1" || info.Email != "user@example.com" || !info.EmailVerified { + t.Fatalf("unexpected user info: %#v", info) + } +} + +func TestOAuth2Config_UsesAppleClientSecretPost(t *testing.T) { + t.Parallel() + + provider := New() + cfg, _ := provider.OAuth2Config() + if cfg.Endpoint.AuthStyle != oauth2.AuthStyleInParams { + t.Fatalf("AuthStyle = %v", cfg.Endpoint.AuthStyle) + } +} + +func TestGetUserInfo_AcceptsSHA256NonceClaim(t *testing.T) { + t.Parallel() + + nonce := "nonce-value" + sum := sha256.Sum256([]byte(nonce)) + provider := New( + WithClientID("client-id"), + WithIDTokenVerifier(func(_ context.Context, _ string) (map[string]any, error) { + return map[string]any{ + "sub": "apple-user-1", + "email": "user@example.com", + "email_verified": true, + "nonce": hex.EncodeToString(sum[:]), + }, nil + }), + ) + + ctx := oauth.ContextWithIDTokenNonce(context.Background(), nonce) + info, err := provider.GetUserInfo(ctx, &oauth.TokenResponse{IDToken: "id-token"}) + if err != nil { + t.Fatalf("GetUserInfo: %v", err) + } + if !info.EmailVerified { + t.Fatalf("expected verified email: %#v", info) + } +} + +func TestExtractNameFromParams(t *testing.T) { + t.Parallel() + + name := extractNameFromParams(mapValues("user", `{"name":{"firstName":"Test","lastName":"User"}}`)) + if name != "Test User" { + t.Fatalf("name = %q", name) + } +} + +func mapValues(key, value string) url.Values { + return url.Values{key: []string{value}} +} diff --git a/plugins/oauth-discord/go.mod b/plugins/oauth-discord/go.mod index 5cca7df..2be4254 100644 --- a/plugins/oauth-discord/go.mod +++ b/plugins/oauth-discord/go.mod @@ -4,10 +4,12 @@ go 1.25.0 require ( github.com/thecodearcher/limen/plugins/oauth v0.1.0 - golang.org/x/oauth2 v0.35.0 + golang.org/x/oauth2 v0.36.0 ) require ( + github.com/coreos/go-oidc/v3 v3.18.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.4 // indirect github.com/thecodearcher/limen v0.1.1 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/sys v0.41.0 // indirect diff --git a/plugins/oauth-discord/go.sum b/plugins/oauth-discord/go.sum index 533dc93..f1c3de1 100644 --- a/plugins/oauth-discord/go.sum +++ b/plugins/oauth-discord/go.sum @@ -1,13 +1,17 @@ +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= -golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/oauth-discord/provider.go b/plugins/oauth-discord/provider.go index c63c2cc..51d2ab9 100644 --- a/plugins/oauth-discord/provider.go +++ b/plugins/oauth-discord/provider.go @@ -13,6 +13,7 @@ import ( "github.com/thecodearcher/limen/plugins/oauth" ) +//nolint:gosec // OAuth endpoint URL, not a credential. var discordEndpoint = oauth2.Endpoint{ AuthURL: "https://discord.com/oauth2/authorize", TokenURL: "https://discord.com/api/oauth2/token", @@ -73,13 +74,14 @@ func (d *discordProvider) GetUserInfo(ctx context.Context, token *oauth.TokenRes } username, _ := raw["username"].(string) email, _ := raw["email"].(string) + emailVerified, _ := raw["verified"].(bool) avatarURL := buildAvatarURL(id, raw) return &oauth.ProviderUserInfo{ ID: id, Email: email, - EmailVerified: email != "", + EmailVerified: emailVerified, Name: username, AvatarURL: avatarURL, Raw: raw, diff --git a/plugins/oauth-facebook/go.mod b/plugins/oauth-facebook/go.mod index 5a35618..a0793e7 100644 --- a/plugins/oauth-facebook/go.mod +++ b/plugins/oauth-facebook/go.mod @@ -4,10 +4,12 @@ go 1.25.0 require ( github.com/thecodearcher/limen/plugins/oauth v0.1.0 - golang.org/x/oauth2 v0.35.0 + golang.org/x/oauth2 v0.36.0 ) require ( + github.com/coreos/go-oidc/v3 v3.18.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.4 // indirect github.com/thecodearcher/limen v0.1.1 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/sys v0.41.0 // indirect diff --git a/plugins/oauth-facebook/go.sum b/plugins/oauth-facebook/go.sum index 533dc93..f1c3de1 100644 --- a/plugins/oauth-facebook/go.sum +++ b/plugins/oauth-facebook/go.sum @@ -1,13 +1,17 @@ +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= -golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/oauth-facebook/provider.go b/plugins/oauth-facebook/provider.go index 1f8c585..bcfc450 100644 --- a/plugins/oauth-facebook/provider.go +++ b/plugins/oauth-facebook/provider.go @@ -13,6 +13,7 @@ import ( "github.com/thecodearcher/limen/plugins/oauth" ) +//nolint:gosec // OAuth endpoint URL, not a credential. var facebookEndpoint = oauth2.Endpoint{ AuthURL: "https://www.facebook.com/v25.0/dialog/oauth", TokenURL: "https://graph.facebook.com/v25.0/oauth/access_token", @@ -76,13 +77,15 @@ func (f *facebookProvider) GetUserInfo(ctx context.Context, token *oauth.TokenRe } name, _ := raw["name"].(string) email, _ := raw["email"].(string) + // Meta Graph API does not normally include an email verification claim. + emailVerified, _ := raw["email_verified"].(bool) avatarURL := extractPictureURL(raw) return &oauth.ProviderUserInfo{ ID: id, Email: email, - EmailVerified: email != "", + EmailVerified: emailVerified, Name: name, AvatarURL: avatarURL, Raw: raw, diff --git a/plugins/oauth-facebook/provider_test.go b/plugins/oauth-facebook/provider_test.go new file mode 100644 index 0000000..e3b57e7 --- /dev/null +++ b/plugins/oauth-facebook/provider_test.go @@ -0,0 +1,49 @@ +package oauthfacebook + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + + "github.com/thecodearcher/limen/plugins/oauth" +) + +func TestGetUserInfo_EmailVerificationUnknown(t *testing.T) { + t.Parallel() + + provider := New().(*facebookProvider) + provider.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + if got := req.Header.Get("Authorization"); got != "Bearer access-token" { + t.Fatalf("Authorization = %q", got) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(`{ + "id": "facebook-user-1", + "name": "Test User", + "email": "user@example.com", + "picture": {"data": {"url": "https://example.com/avatar.png"}} + }`)), + }, nil + })} + + info, err := provider.GetUserInfo(context.Background(), &oauth.TokenResponse{AccessToken: "access-token"}) + if err != nil { + t.Fatalf("GetUserInfo: %v", err) + } + if info.ID != "facebook-user-1" || info.Email != "user@example.com" { + t.Fatalf("unexpected user info: %#v", info) + } + if info.EmailVerified { + t.Fatalf("facebook email should not be marked verified: %#v", info) + } +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} diff --git a/plugins/oauth-generic/discovery.go b/plugins/oauth-generic/discovery.go index 5907bf7..eee3b4c 100644 --- a/plugins/oauth-generic/discovery.go +++ b/plugins/oauth-generic/discovery.go @@ -1,6 +1,7 @@ package oauthgeneric import ( + "context" "encoding/json" "fmt" "net/http" @@ -12,12 +13,13 @@ type discoveryDocument struct { AuthorizationEndpoint string `json:"authorization_endpoint"` TokenEndpoint string `json:"token_endpoint"` UserinfoEndpoint string `json:"userinfo_endpoint"` + Issuer string `json:"issuer"` } // fetchDiscoveryDocument fetches and parses the OpenID Connect discovery document. func fetchDiscoveryDocument(discoveryURL string) (*discoveryDocument, error) { client := &http.Client{Timeout: 10 * time.Second} - req, err := http.NewRequest(http.MethodGet, discoveryURL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, discoveryURL, http.NoBody) if err != nil { return nil, fmt.Errorf("discovery request: %w", err) } diff --git a/plugins/oauth-generic/go.mod b/plugins/oauth-generic/go.mod index 3946c08..72e2609 100644 --- a/plugins/oauth-generic/go.mod +++ b/plugins/oauth-generic/go.mod @@ -4,10 +4,12 @@ go 1.25.0 require ( github.com/thecodearcher/limen/plugins/oauth v0.1.0 - golang.org/x/oauth2 v0.35.0 + golang.org/x/oauth2 v0.36.0 ) require ( + github.com/coreos/go-oidc/v3 v3.18.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.4 // indirect github.com/thecodearcher/limen v0.1.1 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/sys v0.41.0 // indirect diff --git a/plugins/oauth-generic/go.sum b/plugins/oauth-generic/go.sum index 533dc93..f1c3de1 100644 --- a/plugins/oauth-generic/go.sum +++ b/plugins/oauth-generic/go.sum @@ -1,13 +1,17 @@ +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= -golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/oauth-generic/options.go b/plugins/oauth-generic/options.go index 45f5fc0..ae5c1b0 100644 --- a/plugins/oauth-generic/options.go +++ b/plugins/oauth-generic/options.go @@ -9,6 +9,8 @@ import ( // ConfigOption configures the generic OAuth provider. type ConfigOption func(*config) +const defaultEmailScope = "email" + type config struct { name string clientID string @@ -16,6 +18,7 @@ type config struct { authorizationURL string tokenURL string userInfoURL string + issuer string discoveryURL string scopes []string redirectURL string @@ -25,6 +28,7 @@ type config struct { buildAuthorizationURL func(ctx context.Context, state, codeVerifier, callbackRedirectURI string) (string, error) exchangeTokens func(ctx context.Context, code, codeVerifier, redirectURI string) (*oauth.TokenResponse, error) refreshTokens func(ctx context.Context, refreshToken string) (*oauth.TokenResponse, error) + verifyIDToken oauth.IDTokenVerifier } func (c *config) resolveDiscovery() { @@ -44,6 +48,9 @@ func (c *config) resolveDiscovery() { if c.userInfoURL == "" { c.userInfoURL = doc.UserinfoEndpoint } + if c.issuer == "" { + c.issuer = doc.Issuer + } } func (c *config) validate() { @@ -76,7 +83,7 @@ func (c *config) validate() { func (c *config) resolveDefaults() { if len(c.scopes) == 0 { - c.scopes = []string{"openid", "email", "profile"} + c.scopes = []string{"openid", defaultEmailScope, "profile"} } } @@ -122,6 +129,12 @@ func WithUserInfoURL(url string) ConfigOption { } } +func WithIssuer(issuer string) ConfigOption { + return func(c *config) { + c.issuer = issuer + } +} + // WithDiscoveryURL sets the OpenID Connect discovery URL. When set, the provider fetches // the discovery document and populates authorization, token, and userinfo endpoints // from it (only for fields not already set explicitly). @@ -195,3 +208,9 @@ func WithRefreshTokens(fn func(ctx context.Context, refreshToken string) (*oauth c.refreshTokens = fn } } + +func WithIDTokenVerifier(verifier oauth.IDTokenVerifier) ConfigOption { + return func(c *config) { + c.verifyIDToken = verifier + } +} diff --git a/plugins/oauth-generic/provider.go b/plugins/oauth-generic/provider.go index 6de9620..feed4fe 100644 --- a/plugins/oauth-generic/provider.go +++ b/plugins/oauth-generic/provider.go @@ -25,6 +25,9 @@ func New(opts ...ConfigOption) oauth.Provider { cfg.resolveDiscovery() cfg.validate() cfg.resolveDefaults() + if cfg.verifyIDToken == nil && cfg.issuer != "" { + cfg.verifyIDToken = oauth.NewIDTokenVerifier(cfg.issuer, cfg.clientID) + } return &genericProvider{ config: cfg, @@ -67,15 +70,40 @@ func (g *genericProvider) GetUserInfo(ctx context.Context, token *oauth.TokenRes if g.config.getUserInfo != nil { return g.config.getUserInfo(ctx, token) } - if token.IDToken != "" { - return g.userInfoFromIDToken(token.IDToken) - } - return g.fetchUserInfoFromURL(ctx, token) + if g.config.userInfoURL != "" { + info, err := g.fetchUserInfoFromURL(ctx, token) + if err != nil { + return nil, err + } + if token.IDToken != "" && g.config.verifyIDToken != nil { + claims, err := g.config.verifyIDToken(ctx, token.IDToken) + if err != nil { + return nil, err + } + sub, _ := claims["sub"].(string) + if sub != "" && info.ID != "" && sub != info.ID { + return nil, fmt.Errorf("userinfo subject does not match id_token subject") + } + } + return info, nil + } + return g.userInfoFromIDToken(ctx, token.IDToken) } -// userInfoFromIDToken decodes the id_token JWT payload and passes the claims to mapUserInfo. -func (g *genericProvider) userInfoFromIDToken(idToken string) (*oauth.ProviderUserInfo, error) { - claims, err := oauth.DecodeIDTokenClaims(idToken) +func (g *genericProvider) userInfoFromIDToken(ctx context.Context, idToken string) (*oauth.ProviderUserInfo, error) { + if idToken == "" { + return nil, fmt.Errorf("id_token is required when no userinfo endpoint is configured") + } + + verifier := g.config.verifyIDToken + if verifier == nil { + if g.config.issuer == "" { + return nil, fmt.Errorf("issuer is required to verify id_token claims") + } + return nil, fmt.Errorf("id_token verifier is not configured") + } + + claims, err := verifier(ctx, idToken) if err != nil { return nil, err } @@ -88,7 +116,7 @@ func (g *genericProvider) userInfoFromIDToken(idToken string) (*oauth.ProviderUs } func (g *genericProvider) fetchUserInfoFromURL(ctx context.Context, token *oauth.TokenResponse) (*oauth.ProviderUserInfo, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, g.config.userInfoURL, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, g.config.userInfoURL, http.NoBody) if err != nil { return nil, err } diff --git a/plugins/oauth-generic/provider_test.go b/plugins/oauth-generic/provider_test.go new file mode 100644 index 0000000..ec8fefa --- /dev/null +++ b/plugins/oauth-generic/provider_test.go @@ -0,0 +1,111 @@ +package oauthgeneric + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/thecodearcher/limen/plugins/oauth" +) + +func TestGetUserInfo_MapsVerifiedIDTokenClaims(t *testing.T) { + t.Parallel() + + called := false + provider := New( + WithName("generic"), + WithClientID("client-id"), + WithClientSecret("client-secret"), + WithAuthorizationURL("https://provider.example.com/oauth/authorize"), + WithTokenURL("https://provider.example.com/oauth/token"), + WithIssuer("https://provider.example.com"), + WithIDTokenVerifier(func(_ context.Context, idToken string) (map[string]any, error) { + called = true + if idToken != "id-token" { + t.Fatalf("unexpected id token: %s", idToken) + } + return map[string]any{ + "sub": "generic-user-1", + "email": "user@example.com", + "email_verified": true, + }, nil + }), + WithMapUserInfo(func(raw map[string]any) (*oauth.ProviderUserInfo, error) { + id, _ := raw["sub"].(string) + email, _ := raw["email"].(string) + emailVerified, _ := raw["email_verified"].(bool) + return &oauth.ProviderUserInfo{ + ID: id, + Email: email, + EmailVerified: emailVerified, + }, nil + }), + ) + + info, err := provider.GetUserInfo(context.Background(), &oauth.TokenResponse{IDToken: "id-token"}) + if err != nil { + t.Fatalf("GetUserInfo: %v", err) + } + if !called { + t.Fatal("expected verifier to be called") + } + if info.ID != "generic-user-1" || info.Email != "user@example.com" || !info.EmailVerified { + t.Fatalf("unexpected user info: %#v", info) + } + if info.Raw == nil { + t.Fatal("expected raw claims to be set") + } +} + +func TestGetUserInfo_RejectsMismatchedUserInfoSubject(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "sub": "userinfo-user", + "email": "user@example.com", + "email_verified": true, + }); err != nil { + t.Fatalf("write userinfo: %v", err) + } + })) + t.Cleanup(server.Close) + + provider := New( + WithName("generic"), + WithClientID("client-id"), + WithClientSecret("client-secret"), + WithAuthorizationURL("https://provider.example.com/oauth/authorize"), + WithTokenURL("https://provider.example.com/oauth/token"), + WithUserInfoURL(server.URL), + WithIDTokenVerifier(func(_ context.Context, idToken string) (map[string]any, error) { + if idToken != "id-token" { + t.Fatalf("unexpected id token: %s", idToken) + } + return map[string]any{ + "sub": "id-token-user", + }, nil + }), + WithMapUserInfo(func(raw map[string]any) (*oauth.ProviderUserInfo, error) { + id, _ := raw["sub"].(string) + email, _ := raw["email"].(string) + emailVerified, _ := raw["email_verified"].(bool) + return &oauth.ProviderUserInfo{ + ID: id, + Email: email, + EmailVerified: emailVerified, + }, nil + }), + ) + + _, err := provider.GetUserInfo(context.Background(), &oauth.TokenResponse{ + AccessToken: "access-token", + IDToken: "id-token", + }) + if err == nil { + t.Fatal("expected mismatched subject to be rejected") + } +} diff --git a/plugins/oauth-github/go.mod b/plugins/oauth-github/go.mod index 5a0d97a..a004e93 100644 --- a/plugins/oauth-github/go.mod +++ b/plugins/oauth-github/go.mod @@ -4,10 +4,12 @@ go 1.25.0 require ( github.com/thecodearcher/limen/plugins/oauth v0.1.0 - golang.org/x/oauth2 v0.35.0 + golang.org/x/oauth2 v0.36.0 ) require ( + github.com/coreos/go-oidc/v3 v3.18.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.4 // indirect github.com/thecodearcher/limen v0.1.1 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/sys v0.41.0 // indirect diff --git a/plugins/oauth-github/go.sum b/plugins/oauth-github/go.sum index 533dc93..f1c3de1 100644 --- a/plugins/oauth-github/go.sum +++ b/plugins/oauth-github/go.sum @@ -1,13 +1,17 @@ +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= -golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/oauth-github/provider.go b/plugins/oauth-github/provider.go index 8b90d0e..7fc31f0 100644 --- a/plugins/oauth-github/provider.go +++ b/plugins/oauth-github/provider.go @@ -25,6 +25,7 @@ func New(opts ...ConfigOption) oauth.Provider { return newGitHubProvider(cfg) } +//nolint:gosec // OAuth endpoint URL, not a credential. var githubEndpoint = oauth2.Endpoint{ AuthURL: "https://github.com/login/oauth/authorize", TokenURL: "https://github.com/login/oauth/access_token", @@ -71,9 +72,13 @@ func (g *githubProvider) GetUserInfo(ctx context.Context, token *oauth.TokenResp return nil, err } - email := raw["email"] - if email == nil || email == "" { - email, _ = g.fetchPrimaryEmail(ctx, token.AccessToken) + email, _ := raw["email"].(string) + selectedEmail, emailVerified, err := g.fetchPrimaryEmail(ctx, token.AccessToken, email) + if err != nil { + return nil, err + } + if email == "" { + email = selectedEmail } id, _ := raw["id"].(float64) @@ -81,49 +86,70 @@ func (g *githubProvider) GetUserInfo(ctx context.Context, token *oauth.TokenResp avatarURL, _ := raw["avatar_url"].(string) return &oauth.ProviderUserInfo{ ID: fmt.Sprintf("%d", int64(id)), - Email: email.(string), - EmailVerified: email != "", + Email: email, + EmailVerified: emailVerified, Name: name, AvatarURL: avatarURL, Raw: raw, }, nil } -func (g *githubProvider) fetchPrimaryEmail(ctx context.Context, accessToken string) (string, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/user/emails", nil) +type githubEmail struct { + Email string `json:"email"` + Primary bool `json:"primary"` + Verified bool `json:"verified"` +} + +func (g *githubProvider) fetchPrimaryEmail(ctx context.Context, accessToken, preferredEmail string) (string, bool, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/user/emails", http.NoBody) if err != nil { - return "", err + return "", false, err } req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("Accept", "application/json") resp, err := g.httpClient.Do(req) if err != nil { - return "", err + return "", false, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return "", nil - } - var emails []struct { - Email string `json:"email"` - Primary bool `json:"primary"` - Verified bool `json:"verified"` + return preferredEmail, false, nil } + var emails []githubEmail if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil { - return "", err + return "", false, err + } + email, verified := selectGitHubEmail(preferredEmail, emails) + return email, verified, nil +} + +func selectGitHubEmail(preferredEmail string, emails []githubEmail) (string, bool) { + if preferredEmail != "" { + for _, e := range emails { + if e.Email == preferredEmail { + return preferredEmail, e.Verified + } + } + return preferredEmail, false } + for _, e := range emails { if e.Primary && e.Verified { - return e.Email, nil + return e.Email, true } } for _, e := range emails { if e.Verified { - return e.Email, nil + return e.Email, true + } + } + for _, e := range emails { + if e.Primary { + return e.Email, false } } if len(emails) > 0 { - return emails[0].Email, nil + return emails[0].Email, false } - return "", nil + return "", false } diff --git a/plugins/oauth-github/provider_test.go b/plugins/oauth-github/provider_test.go new file mode 100644 index 0000000..4969a4d --- /dev/null +++ b/plugins/oauth-github/provider_test.go @@ -0,0 +1,37 @@ +package oauthgithub + +import "testing" + +func TestSelectGitHubEmail_VerifiesPreferredEmail(t *testing.T) { + t.Parallel() + + email, verified := selectGitHubEmail("user@example.com", []githubEmail{ + {Email: "user@example.com", Primary: true, Verified: true}, + }) + if email != "user@example.com" || !verified { + t.Fatalf("email=%q verified=%v", email, verified) + } +} + +func TestSelectGitHubEmail_UnmatchedPreferredEmailIsUnverified(t *testing.T) { + t.Parallel() + + email, verified := selectGitHubEmail("public@example.com", []githubEmail{ + {Email: "primary@example.com", Primary: true, Verified: true}, + }) + if email != "public@example.com" || verified { + t.Fatalf("email=%q verified=%v", email, verified) + } +} + +func TestSelectGitHubEmail_PrefersVerifiedEmailWhenNoProfileEmail(t *testing.T) { + t.Parallel() + + email, verified := selectGitHubEmail("", []githubEmail{ + {Email: "primary@example.com", Primary: true, Verified: false}, + {Email: "verified@example.com", Verified: true}, + }) + if email != "verified@example.com" || !verified { + t.Fatalf("email=%q verified=%v", email, verified) + } +} diff --git a/plugins/oauth-google/go.mod b/plugins/oauth-google/go.mod index 95ffaca..50ca4d0 100644 --- a/plugins/oauth-google/go.mod +++ b/plugins/oauth-google/go.mod @@ -4,11 +4,13 @@ go 1.25.0 require ( github.com/thecodearcher/limen/plugins/oauth v0.1.0 - golang.org/x/oauth2 v0.35.0 + golang.org/x/oauth2 v0.36.0 ) require ( cloud.google.com/go/compute/metadata v0.3.0 // indirect + github.com/coreos/go-oidc/v3 v3.18.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.4 // indirect github.com/thecodearcher/limen v0.1.1 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/sys v0.41.0 // indirect diff --git a/plugins/oauth-google/go.sum b/plugins/oauth-google/go.sum index ecba3b4..5b2afa7 100644 --- a/plugins/oauth-google/go.sum +++ b/plugins/oauth-google/go.sum @@ -1,15 +1,19 @@ cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= -golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/oauth-google/options.go b/plugins/oauth-google/options.go index ed06940..116c9e0 100644 --- a/plugins/oauth-google/options.go +++ b/plugins/oauth-google/options.go @@ -1,14 +1,17 @@ package oauthgoogle +import "github.com/thecodearcher/limen/plugins/oauth" + // ConfigOption configures the Google OAuth plugin. type ConfigOption func(*config) type config struct { - clientID string - clientSecret string - redirectURL string - scopes []string - options map[string]string + clientID string + clientSecret string + redirectURL string + scopes []string + options map[string]string + verifyIDToken oauth.IDTokenVerifier } // WithClientID sets the Google OAuth2 client ID. @@ -50,3 +53,9 @@ func WithOption(key, value string) ConfigOption { c.options[key] = value } } + +func WithIDTokenVerifier(verifier oauth.IDTokenVerifier) ConfigOption { + return func(c *config) { + c.verifyIDToken = verifier + } +} diff --git a/plugins/oauth-google/provider.go b/plugins/oauth-google/provider.go index 80177cb..6a9d46f 100644 --- a/plugins/oauth-google/provider.go +++ b/plugins/oauth-google/provider.go @@ -3,19 +3,20 @@ package oauthgoogle import ( "context" - "encoding/base64" - "encoding/json" "errors" "fmt" "os" - "strings" "golang.org/x/oauth2" - "golang.org/x/oauth2/google" "github.com/thecodearcher/limen/plugins/oauth" ) +var googleEndpoint = oauth2.Endpoint{ + AuthURL: "https://accounts.google.com/o/oauth2/v2/auth", + TokenURL: "https://oauth2.googleapis.com/token", +} + // New creates a Google OAuth provider that implements oauth.Provider. func New(opts ...ConfigOption) oauth.Provider { cfg := &config{ @@ -39,12 +40,15 @@ func newGoogleProvider(cfg *config) *googleProvider { if len(scopes) == 0 { scopes = []string{"openid", "email", "profile"} } + if cfg.verifyIDToken == nil { + cfg.verifyIDToken = oauth.NewIDTokenVerifier("https://accounts.google.com", cfg.clientID) + } config := &oauth2.Config{ ClientID: cfg.clientID, ClientSecret: cfg.clientSecret, RedirectURL: cfg.redirectURL, Scopes: scopes, - Endpoint: google.Endpoint, + Endpoint: googleEndpoint, } return &googleProvider{oauthConfig: config, config: cfg} } @@ -62,14 +66,21 @@ func (g *googleProvider) OAuth2Config() (*oauth2.Config, []oauth2.AuthCodeOption return g.oauthConfig, authOpts } -func (g *googleProvider) GetUserInfo(_ context.Context, token *oauth.TokenResponse) (*oauth.ProviderUserInfo, error) { +func (g *googleProvider) IDTokenNonceEnabled() bool { + return true +} + +func (g *googleProvider) GetUserInfo(ctx context.Context, token *oauth.TokenResponse) (*oauth.ProviderUserInfo, error) { if token.IDToken == "" { return nil, errors.New("google: id_token required; include openid scope") } - claims, err := decodeIDTokenClaims(token.IDToken) + claims, err := g.config.verifyIDToken(ctx, token.IDToken) if err != nil { return nil, fmt.Errorf("google: %w", err) } + if err := oauth.VerifyIDTokenNonce(claims, oauth.IDTokenNonce(ctx)); err != nil { + return nil, fmt.Errorf("google: %w", err) + } sub, _ := claims["sub"].(string) if sub == "" { @@ -92,22 +103,3 @@ func (g *googleProvider) GetUserInfo(_ context.Context, token *oauth.TokenRespon Raw: claims, }, nil } - -// decodeIDTokenClaims decodes the payload segment of a JWT without verification. -// Safe here because the token was obtained directly from Google's token endpoint over TLS. -func decodeIDTokenClaims(idToken string) (map[string]any, error) { - parts := strings.SplitN(idToken, ".", 3) - if len(parts) != 3 { - return nil, errors.New("id token has invalid JWT format") - } - payload, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - return nil, fmt.Errorf("id token payload decode: %w", err) - } - var claims map[string]any - if err := json.Unmarshal(payload, &claims); err != nil { - return nil, fmt.Errorf("id token payload unmarshal: %w", err) - } - - return claims, nil -} diff --git a/plugins/oauth-google/provider_test.go b/plugins/oauth-google/provider_test.go new file mode 100644 index 0000000..f5caff7 --- /dev/null +++ b/plugins/oauth-google/provider_test.go @@ -0,0 +1,78 @@ +package oauthgoogle + +import ( + "context" + "testing" + + "github.com/thecodearcher/limen/plugins/oauth" +) + +func TestGetUserInfo_UsesIDTokenVerifier(t *testing.T) { + t.Parallel() + + called := false + provider := New( + WithClientID("client-id"), + WithIDTokenVerifier(func(_ context.Context, idToken string) (map[string]any, error) { + called = true + if idToken != "id-token" { + t.Fatalf("unexpected id token: %s", idToken) + } + return map[string]any{ + "sub": "google-user-1", + "email": "user@example.com", + "email_verified": true, + "name": "Test User", + "picture": "https://example.com/avatar.png", + "nonce": "nonce-value", + }, nil + }), + ) + + ctx := oauth.ContextWithIDTokenNonce(context.Background(), "nonce-value") + info, err := provider.GetUserInfo(ctx, &oauth.TokenResponse{IDToken: "id-token"}) + if err != nil { + t.Fatalf("GetUserInfo: %v", err) + } + if !called { + t.Fatal("expected verifier to be called") + } + if info.ID != "google-user-1" || info.Email != "user@example.com" || !info.EmailVerified { + t.Fatalf("unexpected user info: %#v", info) + } +} + +func TestOAuth2Config_UsesCurrentGoogleOIDCEndpoints(t *testing.T) { + t.Parallel() + + provider := New() + cfg, _ := provider.OAuth2Config() + if cfg.Endpoint.AuthURL != "https://accounts.google.com/o/oauth2/v2/auth" { + t.Fatalf("AuthURL = %q", cfg.Endpoint.AuthURL) + } + if cfg.Endpoint.TokenURL != "https://oauth2.googleapis.com/token" { + t.Fatalf("TokenURL = %q", cfg.Endpoint.TokenURL) + } +} + +func TestGetUserInfo_RejectsNonceMismatch(t *testing.T) { + t.Parallel() + + provider := New( + WithClientID("client-id"), + WithIDTokenVerifier(func(_ context.Context, _ string) (map[string]any, error) { + return map[string]any{ + "sub": "google-user-1", + "email": "user@example.com", + "email_verified": true, + "nonce": "other-nonce", + }, nil + }), + ) + + ctx := oauth.ContextWithIDTokenNonce(context.Background(), "nonce-value") + _, err := provider.GetUserInfo(ctx, &oauth.TokenResponse{IDToken: "id-token"}) + if err == nil { + t.Fatal("expected nonce mismatch error") + } +} diff --git a/plugins/oauth-linkedin/go.mod b/plugins/oauth-linkedin/go.mod index 7d1478a..8fde9e4 100644 --- a/plugins/oauth-linkedin/go.mod +++ b/plugins/oauth-linkedin/go.mod @@ -4,10 +4,12 @@ go 1.25.0 require ( github.com/thecodearcher/limen/plugins/oauth v0.1.0 - golang.org/x/oauth2 v0.35.0 + golang.org/x/oauth2 v0.36.0 ) require ( + github.com/coreos/go-oidc/v3 v3.18.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.4 // indirect github.com/thecodearcher/limen v0.1.1 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/sys v0.41.0 // indirect diff --git a/plugins/oauth-linkedin/go.sum b/plugins/oauth-linkedin/go.sum index 533dc93..f1c3de1 100644 --- a/plugins/oauth-linkedin/go.sum +++ b/plugins/oauth-linkedin/go.sum @@ -1,13 +1,17 @@ +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= -golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/oauth-linkedin/options.go b/plugins/oauth-linkedin/options.go index 9d4f0f4..cea9309 100644 --- a/plugins/oauth-linkedin/options.go +++ b/plugins/oauth-linkedin/options.go @@ -1,14 +1,17 @@ package oauthlinkedin +import "github.com/thecodearcher/limen/plugins/oauth" + // ConfigOption configures the LinkedIn OAuth plugin. type ConfigOption func(*config) type config struct { - clientID string - clientSecret string - redirectURL string - scopes []string - options map[string]string + clientID string + clientSecret string + redirectURL string + scopes []string + options map[string]string + verifyIDToken oauth.IDTokenVerifier } // WithClientID sets the LinkedIn OAuth2 client ID. @@ -51,3 +54,9 @@ func WithOption(key, value string) ConfigOption { c.options[key] = value } } + +func WithIDTokenVerifier(verifier oauth.IDTokenVerifier) ConfigOption { + return func(c *config) { + c.verifyIDToken = verifier + } +} diff --git a/plugins/oauth-linkedin/provider.go b/plugins/oauth-linkedin/provider.go index b4be03e..e397e6e 100644 --- a/plugins/oauth-linkedin/provider.go +++ b/plugins/oauth-linkedin/provider.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "os" + "strconv" "golang.org/x/oauth2" @@ -17,6 +18,8 @@ var linkedinEndpoint = oauth2.Endpoint{ TokenURL: "https://www.linkedin.com/oauth/v2/accessToken", } +const linkedinIssuer = "https://www.linkedin.com/oauth" + // New creates a LinkedIn OAuth provider that implements oauth.Provider. func New(opts ...ConfigOption) oauth.Provider { cfg := &config{ @@ -36,6 +39,9 @@ type linkedInProvider struct { } func newLinkedInProvider(cfg *config) *linkedInProvider { + if cfg.verifyIDToken == nil { + cfg.verifyIDToken = oauth.NewIDTokenVerifier(linkedinIssuer, cfg.clientID) + } config := &oauth2.Config{ ClientID: cfg.clientID, ClientSecret: cfg.clientSecret, @@ -64,11 +70,11 @@ func (l *linkedInProvider) PKCEEnabled() bool { return false } -func (l *linkedInProvider) GetUserInfo(_ context.Context, token *oauth.TokenResponse) (*oauth.ProviderUserInfo, error) { +func (l *linkedInProvider) GetUserInfo(ctx context.Context, token *oauth.TokenResponse) (*oauth.ProviderUserInfo, error) { if token.IDToken == "" { return nil, errors.New("linkedin: id_token required; include openid scope") } - claims, err := oauth.DecodeIDTokenClaims(token.IDToken) + claims, err := l.config.verifyIDToken(ctx, token.IDToken) if err != nil { return nil, fmt.Errorf("linkedin: %w", err) } @@ -85,17 +91,24 @@ func (l *linkedInProvider) GetUserInfo(_ context.Context, token *oauth.TokenResp name, _ := claims["name"].(string) picture, _ := claims["picture"].(string) - emailVerified := false - if verified, ok := claims["email_verified"].(string); ok { - emailVerified = verified == "true" - } - return &oauth.ProviderUserInfo{ ID: id, Email: email, - EmailVerified: emailVerified, + EmailVerified: boolClaim(claims, "email_verified"), Name: name, AvatarURL: picture, Raw: claims, }, nil } + +func boolClaim(raw map[string]any, key string) bool { + switch v := raw[key].(type) { + case bool: + return v + case string: + parsed, err := strconv.ParseBool(v) + return err == nil && parsed + default: + return false + } +} diff --git a/plugins/oauth-linkedin/provider_test.go b/plugins/oauth-linkedin/provider_test.go new file mode 100644 index 0000000..e3e40f2 --- /dev/null +++ b/plugins/oauth-linkedin/provider_test.go @@ -0,0 +1,71 @@ +package oauthlinkedin + +import ( + "context" + "testing" + + "github.com/thecodearcher/limen/plugins/oauth" +) + +func TestLinkedInIssuer(t *testing.T) { + t.Parallel() + + if linkedinIssuer != "https://www.linkedin.com/oauth" { + t.Fatalf("linkedinIssuer = %q", linkedinIssuer) + } +} + +func TestGetUserInfo_UsesIDTokenVerifier(t *testing.T) { + t.Parallel() + + called := false + provider := New( + WithClientID("client-id"), + WithIDTokenVerifier(func(_ context.Context, idToken string) (map[string]any, error) { + called = true + if idToken != "id-token" { + t.Fatalf("unexpected id token: %s", idToken) + } + return map[string]any{ + "sub": "linkedin-user-1", + "email": "user@example.com", + "email_verified": "true", + "name": "Test User", + }, nil + }), + ) + + info, err := provider.GetUserInfo(context.Background(), &oauth.TokenResponse{IDToken: "id-token"}) + if err != nil { + t.Fatalf("GetUserInfo: %v", err) + } + if !called { + t.Fatal("expected verifier to be called") + } + if info.ID != "linkedin-user-1" || info.Email != "user@example.com" || !info.EmailVerified { + t.Fatalf("unexpected user info: %#v", info) + } +} + +func TestGetUserInfo_MapsBooleanEmailVerifiedClaim(t *testing.T) { + t.Parallel() + + provider := New( + WithClientID("client-id"), + WithIDTokenVerifier(func(_ context.Context, _ string) (map[string]any, error) { + return map[string]any{ + "sub": "linkedin-user-1", + "email": "user@example.com", + "email_verified": true, + }, nil + }), + ) + + info, err := provider.GetUserInfo(context.Background(), &oauth.TokenResponse{IDToken: "id-token"}) + if err != nil { + t.Fatalf("GetUserInfo: %v", err) + } + if !info.EmailVerified { + t.Fatalf("expected verified email: %#v", info) + } +} diff --git a/plugins/oauth-microsoft/go.mod b/plugins/oauth-microsoft/go.mod index 3842425..3133227 100644 --- a/plugins/oauth-microsoft/go.mod +++ b/plugins/oauth-microsoft/go.mod @@ -3,11 +3,13 @@ module github.com/thecodearcher/limen/plugins/oauth-microsoft go 1.25.0 require ( + github.com/coreos/go-oidc/v3 v3.18.0 github.com/thecodearcher/limen/plugins/oauth v0.1.0 - golang.org/x/oauth2 v0.35.0 + golang.org/x/oauth2 v0.36.0 ) require ( + github.com/go-jose/go-jose/v4 v4.1.4 // indirect github.com/thecodearcher/limen v0.1.1 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/sys v0.41.0 // indirect diff --git a/plugins/oauth-microsoft/go.sum b/plugins/oauth-microsoft/go.sum index 533dc93..f1c3de1 100644 --- a/plugins/oauth-microsoft/go.sum +++ b/plugins/oauth-microsoft/go.sum @@ -1,13 +1,17 @@ +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= -golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/oauth-microsoft/options.go b/plugins/oauth-microsoft/options.go index dab82c8..ea0d470 100644 --- a/plugins/oauth-microsoft/options.go +++ b/plugins/oauth-microsoft/options.go @@ -1,16 +1,19 @@ package oauthmicrosoft +import "github.com/thecodearcher/limen/plugins/oauth" + // ConfigOption configures the Microsoft OAuth plugin. type ConfigOption func(*config) type config struct { - clientID string - clientSecret string - redirectURL string - scopes []string - tenant string - authorityURL string - options map[string]string + clientID string + clientSecret string + redirectURL string + scopes []string + tenant string + authorityURL string + options map[string]string + verifyIDToken oauth.IDTokenVerifier } // WithClientID sets the Microsoft OAuth2 client ID (Application ID). @@ -75,3 +78,9 @@ func WithOption(key, value string) ConfigOption { c.options[key] = value } } + +func WithIDTokenVerifier(verifier oauth.IDTokenVerifier) ConfigOption { + return func(c *config) { + c.verifyIDToken = verifier + } +} diff --git a/plugins/oauth-microsoft/provider.go b/plugins/oauth-microsoft/provider.go index 16e0c76..380b3fb 100644 --- a/plugins/oauth-microsoft/provider.go +++ b/plugins/oauth-microsoft/provider.go @@ -7,15 +7,20 @@ import ( "fmt" "os" "strings" + "sync" + "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" "github.com/thecodearcher/limen/plugins/oauth" ) const ( - defaultTenant = "common" - defaultAuthority = "https://login.microsoftonline.com" + defaultTenant = "common" + organizationsTenant = "organizations" + consumersTenant = "consumers" + defaultAuthority = "https://login.microsoftonline.com" + emailClaim = "email" ) func microsoftEndpoint(authority string) oauth2.Endpoint { @@ -30,7 +35,7 @@ func New(opts ...ConfigOption) oauth.Provider { cfg := &config{ clientID: os.Getenv("MICROSOFT_CLIENT_ID"), clientSecret: os.Getenv("MICROSOFT_CLIENT_SECRET"), - scopes: []string{"openid", "profile", "email"}, + scopes: []string{"openid", "profile", emailClaim}, tenant: defaultTenant, } for _, opt := range opts { @@ -45,14 +50,23 @@ type microsoftProvider struct { } func newMicrosoftProvider(cfg *config) *microsoftProvider { + tenant := cfg.tenant + if tenant == "" { + tenant = defaultTenant + } + authority := strings.TrimRight(cfg.authorityURL, "/") if authority == "" { - tenant := cfg.tenant - if tenant == "" { - tenant = defaultTenant - } authority = defaultAuthority + "/" + tenant } + issuer := authority + "/v2.0" + if cfg.verifyIDToken == nil { + if cfg.authorityURL == "" && isMicrosoftSharedTenant(tenant) { + cfg.verifyIDToken = newMicrosoftSharedTenantIDTokenVerifier(defaultAuthority, tenant, cfg.clientID) + } else { + cfg.verifyIDToken = oauth.NewIDTokenVerifier(issuer, cfg.clientID) + } + } config := &oauth2.Config{ ClientID: cfg.clientID, ClientSecret: cfg.clientSecret, @@ -63,6 +77,90 @@ func newMicrosoftProvider(cfg *config) *microsoftProvider { return µsoftProvider{oauthConfig: config, config: cfg} } +func isMicrosoftSharedTenant(tenant string) bool { + return tenant == defaultTenant || tenant == organizationsTenant || tenant == consumersTenant +} + +func newMicrosoftSharedTenantIDTokenVerifier(authorityBase, tenant, clientID string) oauth.IDTokenVerifier { + discoveryIssuer := strings.TrimRight(authorityBase, "/") + "/" + tenant + "/v2.0" + expectedDiscoveryIssuer := strings.TrimRight(authorityBase, "/") + "/{tenantid}/v2.0" + if tenant == consumersTenant { + expectedDiscoveryIssuer = strings.TrimRight(authorityBase, "/") + "/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0" + } + + var mu sync.Mutex + var verifier *oidc.IDTokenVerifier + + return func(ctx context.Context, idToken string) (map[string]any, error) { + mu.Lock() + if verifier == nil { + providerCtx := oidc.InsecureIssuerURLContext(ctx, expectedDiscoveryIssuer) + provider, err := oidc.NewProvider(providerCtx, discoveryIssuer) + if err != nil { + mu.Unlock() + return nil, fmt.Errorf("microsoft: id token provider discovery failed: %w", err) + } + verifier = provider.Verifier(&oidc.Config{ + ClientID: clientID, + SkipIssuerCheck: true, + }) + } + current := verifier + mu.Unlock() + + verified, err := current.Verify(ctx, idToken) + if err != nil { + return nil, fmt.Errorf("microsoft: id token verification failed: %w", err) + } + if !isMicrosoftTenantIssuer(authorityBase, tenant, verified.Issuer) { + return nil, fmt.Errorf("microsoft: id token issuer is not trusted") + } + + var claims map[string]any + if err := verified.Claims(&claims); err != nil { + return nil, fmt.Errorf("microsoft: id token claims decode failed: %w", err) + } + return claims, nil + } +} + +func isMicrosoftTenantIssuer(authorityBase, tenant, issuer string) bool { + base := strings.TrimRight(authorityBase, "/") + if tenant == consumersTenant { + return issuer == base+"/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0" + } + if tenant != defaultTenant && tenant != organizationsTenant { + return issuer == base+"/"+tenant+"/v2.0" + } + + prefix := base + "/" + suffix := "/v2.0" + if !strings.HasPrefix(issuer, prefix) || !strings.HasSuffix(issuer, suffix) { + return false + } + tenantID := strings.TrimSuffix(strings.TrimPrefix(issuer, prefix), suffix) + return isGUID(tenantID) +} + +func isGUID(value string) bool { + if len(value) != 36 { + return false + } + for i, ch := range value { + switch i { + case 8, 13, 18, 23: + if ch != '-' { + return false + } + default: + if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'f') && (ch < 'A' || ch > 'F') { + return false + } + } + } + return true +} + func (m *microsoftProvider) Name() string { return "microsoft" } @@ -75,11 +173,11 @@ func (m *microsoftProvider) OAuth2Config() (*oauth2.Config, []oauth2.AuthCodeOpt return m.oauthConfig, authOpts } -func (m *microsoftProvider) GetUserInfo(_ context.Context, token *oauth.TokenResponse) (*oauth.ProviderUserInfo, error) { +func (m *microsoftProvider) GetUserInfo(ctx context.Context, token *oauth.TokenResponse) (*oauth.ProviderUserInfo, error) { if token.IDToken == "" { return nil, errors.New("microsoft: id_token required; include openid scope") } - claims, err := oauth.DecodeIDTokenClaims(token.IDToken) + claims, err := m.config.verifyIDToken(ctx, token.IDToken) if err != nil { return nil, fmt.Errorf("microsoft: %w", err) } @@ -91,11 +189,12 @@ func (m *microsoftProvider) GetUserInfo(_ context.Context, token *oauth.TokenRes email := extractEmail(claims) name, _ := claims["name"].(string) + emailVerified := microsoftEmailVerified(claims, email) return &oauth.ProviderUserInfo{ ID: oid, Email: email, - EmailVerified: email != "", + EmailVerified: emailVerified, Name: name, Raw: claims, }, nil @@ -105,9 +204,41 @@ func (m *microsoftProvider) GetUserInfo(_ context.Context, token *oauth.TokenRes // "email" is preferred; falls back to "preferred_username" which Microsoft // typically populates with the user's UPN or email address. func extractEmail(claims map[string]any) string { - if email, _ := claims["email"].(string); email != "" { + if email, _ := claims[emailClaim].(string); email != "" { return email } upn, _ := claims["preferred_username"].(string) return upn } + +func microsoftEmailVerified(claims map[string]any, email string) bool { + if verified, ok := claims["email_verified"].(bool); ok { + return verified + } + if verified, ok := claims["email_verified"].(string); ok { + return verified == "true" || verified == "1" + } + if email == "" { + return false + } + return stringSliceClaimContains(claims["verified_primary_email"], email) || + stringSliceClaimContains(claims["verified_secondary_email"], email) +} + +func stringSliceClaimContains(value any, needle string) bool { + switch values := value.(type) { + case []string: + for _, value := range values { + if value == needle { + return true + } + } + case []any: + for _, value := range values { + if value == needle { + return true + } + } + } + return false +} diff --git a/plugins/oauth-microsoft/provider_test.go b/plugins/oauth-microsoft/provider_test.go new file mode 100644 index 0000000..3ffb810 --- /dev/null +++ b/plugins/oauth-microsoft/provider_test.go @@ -0,0 +1,121 @@ +package oauthmicrosoft + +import ( + "context" + "testing" + + "github.com/thecodearcher/limen/plugins/oauth" +) + +func TestGetUserInfo_UsesIDTokenVerifier(t *testing.T) { + t.Parallel() + + called := false + provider := New( + WithClientID("client-id"), + WithIDTokenVerifier(func(_ context.Context, idToken string) (map[string]any, error) { + called = true + if idToken != "id-token" { + t.Fatalf("unexpected id token: %s", idToken) + } + return map[string]any{ + "oid": "microsoft-user-1", + "email": "user@example.com", + "email_verified": true, + "name": "Test User", + }, nil + }), + ) + + info, err := provider.GetUserInfo(context.Background(), &oauth.TokenResponse{IDToken: "id-token"}) + if err != nil { + t.Fatalf("GetUserInfo: %v", err) + } + if !called { + t.Fatal("expected verifier to be called") + } + if info.ID != "microsoft-user-1" || info.Email != "user@example.com" || !info.EmailVerified { + t.Fatalf("unexpected user info: %#v", info) + } +} + +func TestGetUserInfo_EmailVerifiedRequiresTrustedClaim(t *testing.T) { + t.Parallel() + + provider := New( + WithClientID("client-id"), + WithIDTokenVerifier(func(_ context.Context, _ string) (map[string]any, error) { + return map[string]any{ + "oid": "microsoft-user-1", + "email": "user@example.com", + }, nil + }), + ) + + info, err := provider.GetUserInfo(context.Background(), &oauth.TokenResponse{IDToken: "id-token"}) + if err != nil { + t.Fatalf("GetUserInfo: %v", err) + } + if info.EmailVerified { + t.Fatalf("expected email to be unverified without email_verified claim: %#v", info) + } +} + +func TestMicrosoftEmailVerified_UsesVerifiedEmailArrays(t *testing.T) { + t.Parallel() + + claims := map[string]any{ + "verified_primary_email": []any{"user@example.com"}, + } + + if !microsoftEmailVerified(claims, "user@example.com") { + t.Fatal("expected verified primary email to mark the email verified") + } +} + +func TestMicrosoftTenantIssuerValidation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + tenant string + issuer string + want bool + }{ + { + name: "common accepts tenant GUID issuer", + tenant: "common", + issuer: "https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/v2.0", + want: true, + }, + { + name: "organizations rejects non-GUID tenant", + tenant: "organizations", + issuer: "https://login.microsoftonline.com/evil/v2.0", + want: false, + }, + { + name: "consumers accepts Microsoft consumer tenant", + tenant: "consumers", + issuer: "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0", + want: true, + }, + { + name: "specific tenant requires exact issuer", + tenant: "contoso.onmicrosoft.com", + issuer: "https://login.microsoftonline.com/contoso.onmicrosoft.com/v2.0", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := isMicrosoftTenantIssuer(defaultAuthority, tt.tenant, tt.issuer) + if got != tt.want { + t.Fatalf("isMicrosoftTenantIssuer() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/plugins/oauth-spotify/go.mod b/plugins/oauth-spotify/go.mod index c781be3..0a7f3c7 100644 --- a/plugins/oauth-spotify/go.mod +++ b/plugins/oauth-spotify/go.mod @@ -4,10 +4,12 @@ go 1.25.0 require ( github.com/thecodearcher/limen/plugins/oauth v0.1.0 - golang.org/x/oauth2 v0.35.0 + golang.org/x/oauth2 v0.36.0 ) require ( + github.com/coreos/go-oidc/v3 v3.18.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.4 // indirect github.com/thecodearcher/limen v0.1.1 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/sys v0.41.0 // indirect diff --git a/plugins/oauth-spotify/go.sum b/plugins/oauth-spotify/go.sum index 533dc93..f1c3de1 100644 --- a/plugins/oauth-spotify/go.sum +++ b/plugins/oauth-spotify/go.sum @@ -1,13 +1,17 @@ +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= -golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/oauth-spotify/provider.go b/plugins/oauth-spotify/provider.go index d6d1caa..7d8ca67 100644 --- a/plugins/oauth-spotify/provider.go +++ b/plugins/oauth-spotify/provider.go @@ -13,6 +13,7 @@ import ( "github.com/thecodearcher/limen/plugins/oauth" ) +//nolint:gosec // OAuth endpoint URL, not a credential. var spotifyEndpoint = oauth2.Endpoint{ AuthURL: "https://accounts.spotify.com/authorize", TokenURL: "https://accounts.spotify.com/api/token", @@ -85,9 +86,10 @@ func (s *spotifyProvider) GetUserInfo(ctx context.Context, token *oauth.TokenRes } return &oauth.ProviderUserInfo{ - ID: id, - Email: email, - EmailVerified: email != "", + ID: id, + Email: email, + // Spotify /v1/me exposes email but does not verify ownership of it. + EmailVerified: false, Name: name, AvatarURL: extractAvatarURL(raw), Raw: raw, diff --git a/plugins/oauth-spotify/provider_test.go b/plugins/oauth-spotify/provider_test.go new file mode 100644 index 0000000..266d61b --- /dev/null +++ b/plugins/oauth-spotify/provider_test.go @@ -0,0 +1,49 @@ +package oauthspotify + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + + "github.com/thecodearcher/limen/plugins/oauth" +) + +func TestGetUserInfo_EmailIsNotVerified(t *testing.T) { + t.Parallel() + + provider := New().(*spotifyProvider) + provider.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + if got := req.Header.Get("Authorization"); got != "Bearer access-token" { + t.Fatalf("Authorization = %q", got) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(`{ + "id": "spotify-user-1", + "display_name": "Test User", + "email": "user@example.com", + "images": [{"url": "https://example.com/avatar.png"}] + }`)), + }, nil + })} + + info, err := provider.GetUserInfo(context.Background(), &oauth.TokenResponse{AccessToken: "access-token"}) + if err != nil { + t.Fatalf("GetUserInfo: %v", err) + } + if info.ID != "spotify-user-1" || info.Email != "user@example.com" { + t.Fatalf("unexpected user info: %#v", info) + } + if info.EmailVerified { + t.Fatalf("spotify email should not be marked verified: %#v", info) + } +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} diff --git a/plugins/oauth-twitch/go.mod b/plugins/oauth-twitch/go.mod index 146f868..ecdd3f4 100644 --- a/plugins/oauth-twitch/go.mod +++ b/plugins/oauth-twitch/go.mod @@ -4,10 +4,12 @@ go 1.25.0 require ( github.com/thecodearcher/limen/plugins/oauth v0.1.0 - golang.org/x/oauth2 v0.35.0 + golang.org/x/oauth2 v0.36.0 ) require ( + github.com/coreos/go-oidc/v3 v3.18.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.4 // indirect github.com/thecodearcher/limen v0.1.1 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/sys v0.41.0 // indirect diff --git a/plugins/oauth-twitch/go.sum b/plugins/oauth-twitch/go.sum index 533dc93..f1c3de1 100644 --- a/plugins/oauth-twitch/go.sum +++ b/plugins/oauth-twitch/go.sum @@ -1,13 +1,17 @@ +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= -golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/oauth-twitch/options.go b/plugins/oauth-twitch/options.go index 601c85d..c51e6a6 100644 --- a/plugins/oauth-twitch/options.go +++ b/plugins/oauth-twitch/options.go @@ -1,14 +1,17 @@ package oauthtwitch +import "github.com/thecodearcher/limen/plugins/oauth" + // ConfigOption configures the Twitch OAuth plugin. type ConfigOption func(*config) type config struct { - clientID string - clientSecret string - redirectURL string - scopes []string - options map[string]string + clientID string + clientSecret string + redirectURL string + scopes []string + options map[string]string + verifyIDToken oauth.IDTokenVerifier } // WithClientID sets the Twitch OAuth2 client ID. @@ -51,3 +54,9 @@ func WithOption(key, value string) ConfigOption { c.options[key] = value } } + +func WithIDTokenVerifier(verifier oauth.IDTokenVerifier) ConfigOption { + return func(c *config) { + c.verifyIDToken = verifier + } +} diff --git a/plugins/oauth-twitch/provider.go b/plugins/oauth-twitch/provider.go index 2751a83..4ac32ce 100644 --- a/plugins/oauth-twitch/provider.go +++ b/plugins/oauth-twitch/provider.go @@ -38,6 +38,9 @@ type twitchProvider struct { } func newTwitchProvider(cfg *config) *twitchProvider { + if cfg.verifyIDToken == nil { + cfg.verifyIDToken = oauth.NewIDTokenVerifier("https://id.twitch.tv/oauth2", cfg.clientID) + } config := &oauth2.Config{ ClientID: cfg.clientID, ClientSecret: cfg.clientSecret, @@ -62,11 +65,11 @@ func (t *twitchProvider) OAuth2Config() (*oauth2.Config, []oauth2.AuthCodeOption return t.oauthConfig, authOpts } -func (t *twitchProvider) GetUserInfo(_ context.Context, token *oauth.TokenResponse) (*oauth.ProviderUserInfo, error) { +func (t *twitchProvider) GetUserInfo(ctx context.Context, token *oauth.TokenResponse) (*oauth.ProviderUserInfo, error) { if token.IDToken == "" { return nil, errors.New("twitch: id_token required; include openid scope") } - claims, err := oauth.DecodeIDTokenClaims(token.IDToken) + claims, err := t.config.verifyIDToken(ctx, token.IDToken) if err != nil { return nil, fmt.Errorf("twitch: %w", err) } diff --git a/plugins/oauth-twitch/provider_test.go b/plugins/oauth-twitch/provider_test.go new file mode 100644 index 0000000..e26b8dd --- /dev/null +++ b/plugins/oauth-twitch/provider_test.go @@ -0,0 +1,40 @@ +package oauthtwitch + +import ( + "context" + "testing" + + "github.com/thecodearcher/limen/plugins/oauth" +) + +func TestGetUserInfo_UsesIDTokenVerifier(t *testing.T) { + t.Parallel() + + called := false + provider := New( + WithClientID("client-id"), + WithIDTokenVerifier(func(_ context.Context, idToken string) (map[string]any, error) { + called = true + if idToken != "id-token" { + t.Fatalf("unexpected id token: %s", idToken) + } + return map[string]any{ + "sub": "twitch-user-1", + "email": "user@example.com", + "email_verified": true, + "preferred_username": "testuser", + }, nil + }), + ) + + info, err := provider.GetUserInfo(context.Background(), &oauth.TokenResponse{IDToken: "id-token"}) + if err != nil { + t.Fatalf("GetUserInfo: %v", err) + } + if !called { + t.Fatal("expected verifier to be called") + } + if info.ID != "twitch-user-1" || info.Email != "user@example.com" || !info.EmailVerified { + t.Fatalf("unexpected user info: %#v", info) + } +} diff --git a/plugins/oauth-twitter/go.mod b/plugins/oauth-twitter/go.mod index ac61210..12c5cf3 100644 --- a/plugins/oauth-twitter/go.mod +++ b/plugins/oauth-twitter/go.mod @@ -4,10 +4,12 @@ go 1.25.0 require ( github.com/thecodearcher/limen/plugins/oauth v0.1.0 - golang.org/x/oauth2 v0.35.0 + golang.org/x/oauth2 v0.36.0 ) require ( + github.com/coreos/go-oidc/v3 v3.18.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.4 // indirect github.com/thecodearcher/limen v0.1.1 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/sys v0.41.0 // indirect diff --git a/plugins/oauth-twitter/go.sum b/plugins/oauth-twitter/go.sum index 533dc93..f1c3de1 100644 --- a/plugins/oauth-twitter/go.sum +++ b/plugins/oauth-twitter/go.sum @@ -1,13 +1,17 @@ +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= -golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/oauth/account_linker.go b/plugins/oauth/account_linker.go index c04cf2e..f3a5c23 100644 --- a/plugins/oauth/account_linker.go +++ b/plugins/oauth/account_linker.go @@ -30,6 +30,13 @@ func (o *oauthPlugin) CreateOrLinkAccount(ctx context.Context, info *limen.OAuth } if user != nil { + // Implicitly linking by email requires proof that the provider owns that email. + if !info.EmailVerified { + return nil, ErrOAuthEmailNotVerified + } + if user.EmailVerifiedAt == nil { + return nil, ErrOAuthLocalEmailNotVerified + } return o.linkAccountToUser(ctx, user, info) } diff --git a/plugins/oauth/account_linker_test.go b/plugins/oauth/account_linker_test.go index 4af1dd0..6376356 100644 --- a/plugins/oauth/account_linker_test.go +++ b/plugins/oauth/account_linker_test.go @@ -45,11 +45,13 @@ func TestCreateOrLinkAccount(t *testing.T) { ctx := context.Background() user := seedOAuthUser(t, l, "existing@example.com") + verifyOAuthUser(t, plugin, user) profile := &limen.OAuthAccountProfile{ Provider: "test", ProviderAccountID: "prov-456", Email: "existing@example.com", + EmailVerified: true, AccessToken: "at", } @@ -71,6 +73,7 @@ func TestCreateOrLinkAccount(t *testing.T) { Provider: "test", ProviderAccountID: "prov-789", Email: "update@example.com", + EmailVerified: true, AccessToken: "new-at", RefreshToken: "new-rt", } @@ -90,12 +93,96 @@ func TestCreateOrLinkAccount(t *testing.T) { Provider: "test", ProviderAccountID: "prov-new", Email: "noexist@example.com", + EmailVerified: true, AccessToken: "at", } _, err := plugin.CreateOrLinkAccount(ctx, profile) assert.ErrorIs(t, err, ErrAccountNotFound) }) + + t.Run("unverified email creates unverified new user", func(t *testing.T) { + t.Parallel() + + _, plugin := newTestOAuthPlugin(t, WithDisableTokensEncryption()) + ctx := context.Background() + + profile := &limen.OAuthAccountProfile{ + Provider: "test", + ProviderAccountID: "prov-unverified-new", + Email: "unverified-new@example.com", + EmailVerified: false, + AccessToken: "at", + } + + result, err := plugin.CreateOrLinkAccount(ctx, profile) + require.NoError(t, err) + require.NotNil(t, result.User) + assert.Nil(t, result.User.EmailVerifiedAt) + }) + + t.Run("unverified email updates existing account", func(t *testing.T) { + t.Parallel() + + l, plugin := newTestOAuthPlugin(t, WithDisableTokensEncryption()) + ctx := context.Background() + + user := seedOAuthUser(t, l, "existing-unverified@example.com") + seedOAuthAccount(t, plugin, user.ID, "test", "prov-unverified-existing") + + profile := &limen.OAuthAccountProfile{ + Provider: "test", + ProviderAccountID: "prov-unverified-existing", + Email: "existing-unverified@example.com", + EmailVerified: false, + AccessToken: "at", + } + + result, err := plugin.CreateOrLinkAccount(ctx, profile) + require.NoError(t, err) + assert.Equal(t, user.ID, result.User.ID) + }) + + t.Run("unverified provider email rejected for implicit link", func(t *testing.T) { + t.Parallel() + + l, plugin := newTestOAuthPlugin(t, WithDisableTokensEncryption()) + ctx := context.Background() + + user := seedOAuthUser(t, l, "implicit-unverified@example.com") + verifyOAuthUser(t, plugin, user) + + profile := &limen.OAuthAccountProfile{ + Provider: "test", + ProviderAccountID: "prov-implicit-unverified", + Email: "implicit-unverified@example.com", + EmailVerified: false, + AccessToken: "at", + } + + _, err := plugin.CreateOrLinkAccount(ctx, profile) + assert.ErrorIs(t, err, ErrOAuthEmailNotVerified) + }) + + t.Run("implicit link rejects unverified local email", func(t *testing.T) { + t.Parallel() + + l, plugin := newTestOAuthPlugin(t, WithDisableTokensEncryption()) + ctx := context.Background() + + seedOAuthUser(t, l, "local-unverified@example.com") + + profile := &limen.OAuthAccountProfile{ + Provider: "test", + ProviderAccountID: "prov-local-unverified", + Email: "local-unverified@example.com", + EmailVerified: true, + AccessToken: "at", + } + + _, err := plugin.CreateOrLinkAccount(ctx, profile) + assert.ErrorIs(t, err, ErrOAuthLocalEmailNotVerified) + }) } func TestLinkAccountToCurrentUser(t *testing.T) { @@ -113,6 +200,7 @@ func TestLinkAccountToCurrentUser(t *testing.T) { Provider: "test", ProviderAccountID: "prov-link", Email: "link@example.com", + EmailVerified: true, AccessToken: "at", } @@ -134,6 +222,7 @@ func TestLinkAccountToCurrentUser(t *testing.T) { Provider: "test", ProviderAccountID: "prov-same", Email: "same@example.com", + EmailVerified: true, AccessToken: "new-at", } @@ -157,6 +246,7 @@ func TestLinkAccountToCurrentUser(t *testing.T) { Provider: "test", ProviderAccountID: "prov-taken", Email: "other@example.com", + EmailVerified: true, AccessToken: "at", } @@ -176,6 +266,7 @@ func TestLinkAccountToCurrentUser(t *testing.T) { Provider: "test", ProviderAccountID: "prov-diff-email", Email: "different@example.com", + EmailVerified: true, AccessToken: "at", } @@ -195,6 +286,28 @@ func TestLinkAccountToCurrentUser(t *testing.T) { Provider: "test", ProviderAccountID: "prov-cross", Email: "different@example.com", + EmailVerified: true, + AccessToken: "at", + } + + result, err := plugin.LinkAccountToCurrentUser(ctx, user, profile) + require.NoError(t, err) + assert.Equal(t, user.ID, result.User.ID) + }) + + t.Run("unverified email can be linked by current user", func(t *testing.T) { + t.Parallel() + + l, plugin := newTestOAuthPlugin(t, WithDisableTokensEncryption()) + ctx := context.Background() + + user := seedOAuthUser(t, l, "current-unverified@example.com") + + profile := &limen.OAuthAccountProfile{ + Provider: "test", + ProviderAccountID: "prov-current-unverified", + Email: "current-unverified@example.com", + EmailVerified: false, AccessToken: "at", } diff --git a/plugins/oauth/authentication.go b/plugins/oauth/authentication.go index 82d5e94..c626c76 100644 --- a/plugins/oauth/authentication.go +++ b/plugins/oauth/authentication.go @@ -35,11 +35,14 @@ func (o *oauthPlugin) constructProviderRedirectURL(provider Provider, config *oa return o.core.GetBaseURLWithPluginPath(limen.PluginOAuth, fmt.Sprintf("%s/callback", provider.Name())) } -func (o *oauthPlugin) buildAuthorizationURL(ctx context.Context, provider Provider, stateToken, verifier string) (string, error) { +func (o *oauthPlugin) buildAuthorizationURL(ctx context.Context, provider Provider, stateToken, verifier, nonce string) (string, error) { if pkce, ok := provider.(PKCEEnabledProvider); ok && !pkce.PKCEEnabled() { verifier = "" } config, authOpts := o.getProviderConfig(provider) + if nonce != "" { + authOpts = append(authOpts, oauth2.SetAuthURLParam("nonce", nonce)) + } if builder, ok := provider.(AuthorizationURLBuilder); ok { return builder.BuildAuthorizationURL(ctx, stateToken, verifier, config.RedirectURL) } @@ -94,6 +97,10 @@ func (o *oauthPlugin) GetAuthorizationURL(ctx context.Context, providerName stri } verifier := generateCodeVerifier() + nonce := "" + if p, ok := provider.(IDTokenNonceProvider); ok && p.IDTokenNonceEnabled() { + nonce = generateRandomString() + } data := map[string]any{ pkceDataKey: verifier, @@ -101,13 +108,16 @@ func (o *oauthPlugin) GetAuthorizationURL(ctx context.Context, providerName stri redirectURIKey: redirectURI, errorRedirectURIKey: errorRedirectURI, } + if nonce != "" { + data[nonceDataKey] = nonce + } stateToken, cookieValue, err := o.stateStore.Generate(ctx, data) if err != nil { return "", "", err } - url, err := o.buildAuthorizationURL(ctx, provider, stateToken, verifier) + url, err := o.buildAuthorizationURL(ctx, provider, stateToken, verifier, nonce) if err != nil { return "", "", err } @@ -187,6 +197,9 @@ func (o *oauthPlugin) HandleOAuthCallback(ctx context.Context, providerName, cod if callbackErr != nil { return nil, stateData, callbackErr.ToLimenError() } + if nonce, _ := stateData[nonceDataKey].(string); nonce != "" { + ctx = ContextWithIDTokenNonce(ctx, nonce) + } if code == "" { return nil, stateData, limen.NewLimenError("authorization code is required", http.StatusBadRequest, nil) diff --git a/plugins/oauth/authentication_test.go b/plugins/oauth/authentication_test.go index 5810fce..7b1f735 100644 --- a/plugins/oauth/authentication_test.go +++ b/plugins/oauth/authentication_test.go @@ -51,6 +51,14 @@ func (p *responseModeProvider) ResponseMode() ResponseMode { return ResponseModeFormPost } +type nonceProvider struct { + testProvider +} + +func (p *nonceProvider) IDTokenNonceEnabled() bool { + return true +} + func TestGetAuthorizationURL(t *testing.T) { t.Parallel() @@ -147,6 +155,29 @@ func TestGetAuthorizationURL(t *testing.T) { require.NoError(t, parseErr) assert.NotEmpty(t, parsed.Query().Get("state")) }) + + t.Run("nonce provider includes nonce param and stores expected nonce", func(t *testing.T) { + t.Parallel() + + provider := &nonceProvider{testProvider: testProvider{name: "nonce-provider"}} + l, plugin := newTestOAuthPlugin(t, WithProviders(provider)) + _ = l.Handler() + + authURL, cookieValue, err := plugin.GetAuthorizationURL(context.Background(), "nonce-provider", &OAuthAuthorizeURLData{}) + require.NoError(t, err) + require.NotEmpty(t, cookieValue) + + parsed, parseErr := url.Parse(authURL) + require.NoError(t, parseErr) + nonce := parsed.Query().Get("nonce") + stateToken := parsed.Query().Get("state") + require.NotEmpty(t, nonce) + require.NotEmpty(t, stateToken) + + stateData, err := plugin.stateStore.Validate(context.Background(), stateToken, cookieValue) + require.NoError(t, err) + assert.Equal(t, nonce, stateData[nonceDataKey]) + }) } func TestExchangeAuthorizationCodeForTokens(t *testing.T) { diff --git a/plugins/oauth/constants.go b/plugins/oauth/constants.go index f3a4f5d..8ee8f8b 100644 --- a/plugins/oauth/constants.go +++ b/plugins/oauth/constants.go @@ -2,7 +2,10 @@ package oauth const ( oauthStateAction = "oauth_state" + formPostCookieName = "limen_oauth_form_post" + formPostQueryKey = "_limen_form_post" pkceDataKey = "pkce_verifier" + nonceDataKey = "nonce" additionalDataKey = "additional_data" linkUserIdKey = "_limen_link_user_id" redirectURIKey = "redirect_uri" diff --git a/plugins/oauth/errors.go b/plugins/oauth/errors.go index 1c83bc6..d66271f 100644 --- a/plugins/oauth/errors.go +++ b/plugins/oauth/errors.go @@ -16,5 +16,7 @@ var ( ErrPKCEVerifierNotFound = limen.NewLimenError("PKCE verifier not found", http.StatusBadRequest, nil) ErrAccountAlreadyLinkedToAnotherUser = limen.NewLimenError("this provider account is already linked to another user", http.StatusConflict, nil) ErrAccountCannotBeLinkedToDifferentEmail = limen.NewLimenError("user cannot be linked to this provider account because the email does not match", http.StatusConflict, nil) + ErrOAuthEmailNotVerified = limen.NewLimenError("provider email is not verified", http.StatusUnauthorized, nil) + ErrOAuthLocalEmailNotVerified = limen.NewLimenError("local email must be verified before account linking", http.StatusUnauthorized, nil) ErrNoRefreshToken = limen.NewLimenError("no refresh token available for this account", http.StatusBadRequest, nil) ) diff --git a/plugins/oauth/go.mod b/plugins/oauth/go.mod index 92e186b..168bed8 100644 --- a/plugins/oauth/go.mod +++ b/plugins/oauth/go.mod @@ -3,9 +3,11 @@ module github.com/thecodearcher/limen/plugins/oauth go 1.25.0 require ( + github.com/coreos/go-oidc/v3 v3.18.0 + github.com/go-jose/go-jose/v4 v4.1.4 github.com/stretchr/testify v1.11.1 github.com/thecodearcher/limen v0.1.1 - golang.org/x/oauth2 v0.35.0 + golang.org/x/oauth2 v0.36.0 ) require ( diff --git a/plugins/oauth/go.sum b/plugins/oauth/go.sum index 99bd19a..67b8ad4 100644 --- a/plugins/oauth/go.sum +++ b/plugins/oauth/go.sum @@ -1,6 +1,10 @@ +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -13,8 +17,8 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= -golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/plugins/oauth/handlers.go b/plugins/oauth/handlers.go index b4c8591..ec7b901 100644 --- a/plugins/oauth/handlers.go +++ b/plugins/oauth/handlers.go @@ -1,6 +1,7 @@ package oauth import ( + "fmt" "net/http" "net/url" @@ -38,9 +39,14 @@ func (h *oauthHandlers) SignInWithOAuth(w http.ResponseWriter, r *http.Request) func (h *oauthHandlers) Callback(w http.ResponseWriter, r *http.Request) { providerName := limen.GetParam(r, "provider") - code := r.URL.Query().Get("code") - state := r.URL.Query().Get("state") - callbackErr := callbackErrorFromQuery(r.URL.Query()) + callbackParams, err := h.callbackParams(w, r) + if err != nil { + h.handleCallbackResponse(w, r, nil, nil, nil, err) + return + } + code := callbackParams.Get("code") + state := callbackParams.Get("state") + callbackErr := callbackErrorFromQuery(callbackParams) cookieValue, err := h.plugin.core.Cookies().Get(r, h.plugin.config.cookieName) if err != nil { @@ -50,7 +56,7 @@ func (h *oauthHandlers) Callback(w http.ResponseWriter, r *http.Request) { h.clearStateCookie(w) - ctx := ContextWithCallbackParams(r.Context(), r.URL.Query()) + ctx := ContextWithCallbackParams(r.Context(), callbackParams) result, stateData, err := h.plugin.AuthenticateWithProvider(ctx, providerName, code, state, cookieValue, callbackErr) if err != nil { h.handleCallbackResponse(w, r, stateData, nil, nil, err) @@ -208,24 +214,62 @@ func (h *oauthHandlers) buildErrorRedirectURL(redirectURI string, err error) str } // FormPostCallback handles OAuth callbacks delivered via response_mode=form_post. -// The IdP POSTs code/state/error as application/x-www-form-urlencoded. Rather than -// processing the POST directly (which lacks cookies), we extract -// the form values and 303 redirect to the same path as a GET with query parameters. -// The browser follows the redirect as a same-site navigation, attaching cookies normally. +// The IdP POSTs code/state/error as application/x-www-form-urlencoded. Cross-site +// POST callbacks may not include SameSite=Lax cookies, so we store the POST body +// in an encrypted short-lived cookie and redirect to the GET callback with only a +// marker query parameter. The browser follows that same-site GET with cookies. func (h *oauthHandlers) FormPostCallback(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { h.responder.Error(w, r, err) return } - // Preserve all existing query params and merge all form-post params. + params := r.URL.Query() + for key, values := range r.PostForm { + params.Del(key) + for _, value := range values { + params.Add(key, value) + } + } + + encrypted, err := limen.EncryptXChaCha(params.Encode(), h.plugin.config.secret, nil) + if err != nil { + h.responder.Error(w, r, fmt.Errorf("oauth: failed to store form_post callback params: %w", err)) + return + } + h.plugin.core.Cookies().Set(w, formPostCookieName, encrypted, 60) + target := url.URL{ - Path: r.URL.Path, - RawQuery: r.Form.Encode(), + Path: r.URL.Path, } + query := target.Query() + query.Set(formPostQueryKey, "1") + target.RawQuery = query.Encode() h.responder.Redirect(w, r, target.String(), http.StatusSeeOther) } +func (h *oauthHandlers) callbackParams(w http.ResponseWriter, r *http.Request) (url.Values, error) { + if r.URL.Query().Get(formPostQueryKey) != "1" { + return r.URL.Query(), nil + } + + cookieValue, err := h.plugin.core.Cookies().Get(r, formPostCookieName) + if err != nil { + return nil, limen.NewLimenError("missing OAuth form_post callback cookie", http.StatusBadRequest, err) + } + h.plugin.core.Cookies().Delete(w, formPostCookieName) + + raw, err := limen.DecryptXChaCha(cookieValue, h.plugin.config.secret, nil) + if err != nil { + return nil, limen.NewLimenError("invalid OAuth form_post callback cookie", http.StatusBadRequest, err) + } + params, err := url.ParseQuery(raw) + if err != nil { + return nil, limen.NewLimenError("invalid OAuth form_post callback params", http.StatusBadRequest, err) + } + return params, nil +} + func (h *oauthHandlers) setStateCookie(w http.ResponseWriter, value string) { h.plugin.core.Cookies().Set(w, h.plugin.config.cookieName, value, int(h.plugin.config.cookieTTL.Seconds())) } diff --git a/plugins/oauth/handlers_test.go b/plugins/oauth/handlers_test.go index 84be188..a81d319 100644 --- a/plugins/oauth/handlers_test.go +++ b/plugins/oauth/handlers_test.go @@ -9,6 +9,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/thecodearcher/limen" ) func TestFormPostCallback(t *testing.T) { @@ -21,56 +23,53 @@ func TestFormPostCallback(t *testing.T) { contentType string wantStatus int wantPath string - wantQuery map[string]string - absentKeys []string + wantStoredParams map[string]string assertNoLocation bool }{ { - name: "redirects POST form params to GET query string", + name: "stores POST form params outside redirect URL", requestURL: "/oauth/test/callback", body: url.Values{"code": {"auth-code-123"}, "state": {"state-token-456"}}.Encode(), contentType: "application/x-www-form-urlencoded", wantStatus: http.StatusSeeOther, wantPath: "/oauth/test/callback", - wantQuery: map[string]string{ + wantStoredParams: map[string]string{ "code": "auth-code-123", "state": "state-token-456", }, }, { - name: "forwards error params from form body", + name: "stores error params from form body", requestURL: "/oauth/test/callback", body: url.Values{"state": {"state-token"}, "error": {"access_denied"}, "error_description": {"user canceled"}}.Encode(), contentType: "application/x-www-form-urlencoded", wantStatus: http.StatusSeeOther, wantPath: "/oauth/test/callback", - wantQuery: map[string]string{ + wantStoredParams: map[string]string{ "state": "state-token", "error": "access_denied", "error_description": "user canceled", }, - absentKeys: []string{"code"}, }, { - name: "empty form values are not included in redirect", + name: "stores state-only form body", requestURL: "/oauth/test/callback", body: url.Values{"state": {"state-only"}}.Encode(), contentType: "application/x-www-form-urlencoded", wantStatus: http.StatusSeeOther, wantPath: "/oauth/test/callback", - wantQuery: map[string]string{ + wantStoredParams: map[string]string{ "state": "state-only", }, - absentKeys: []string{"code", "error", "error_description"}, }, { - name: "preserves existing query params and appends all form params", + name: "stores existing query params and all form params", requestURL: "/oauth/test/callback?client_hint=abc&foo=bar", body: url.Values{"code": {"auth-code-123"}, "state": {"state-token-456"}, "custom_param": {"custom-value"}}.Encode(), contentType: "application/x-www-form-urlencoded", wantStatus: http.StatusSeeOther, wantPath: "/oauth/test/callback", - wantQuery: map[string]string{ + wantStoredParams: map[string]string{ "client_hint": "abc", "foo": "bar", "code": "auth-code-123", @@ -110,17 +109,54 @@ func TestFormPostCallback(t *testing.T) { parsed, err := url.Parse(location) require.NoError(t, err) assert.Equal(t, tt.wantPath, parsed.Path) - - for key, expected := range tt.wantQuery { - assert.Equal(t, expected, parsed.Query().Get(key)) + assert.Equal(t, "1", parsed.Query().Get(formPostQueryKey)) + for _, key := range []string{"code", "state", "error", "error_description", "custom_param", "user"} { + assert.False(t, parsed.Query().Has(key), "sensitive callback param %q leaked into redirect URL", key) } - for _, key := range tt.absentKeys { - assert.False(t, parsed.Query().Has(key)) + + params := decryptFormPostCookie(t, handlers, rec.Result().Cookies()) + for key, expected := range tt.wantStoredParams { + assert.Equal(t, expected, params.Get(key)) } }) } } +func TestCallbackParamsConsumesFormPostCookie(t *testing.T) { + t.Parallel() + + handlers := newOAuthHandlersForTest(t) + params := url.Values{"code": {"auth-code"}, "state": {"state-token"}} + encrypted, err := limen.EncryptXChaCha(params.Encode(), handlers.plugin.config.secret, nil) + require.NoError(t, err) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/oauth/test/callback?"+formPostQueryKey+"=1", nil) + req.AddCookie(&http.Cookie{Name: formPostCookieName, Value: encrypted}) + rec := httptest.NewRecorder() + + got, err := handlers.callbackParams(rec, req) + require.NoError(t, err) + assert.Equal(t, "auth-code", got.Get("code")) + assert.Equal(t, "state-token", got.Get("state")) + assert.Contains(t, rec.Header().Values("Set-Cookie")[0], formPostCookieName+"=") +} + +func decryptFormPostCookie(t *testing.T, handlers *oauthHandlers, cookies []*http.Cookie) url.Values { + t.Helper() + for _, cookie := range cookies { + if cookie.Name != formPostCookieName { + continue + } + raw, err := limen.DecryptXChaCha(cookie.Value, handlers.plugin.config.secret, nil) + require.NoError(t, err) + params, err := url.ParseQuery(raw) + require.NoError(t, err) + return params + } + t.Fatalf("missing %s cookie", formPostCookieName) + return nil +} + func newOAuthHandlersForTest(t *testing.T) *oauthHandlers { t.Helper() l, plugin := newTestOAuthPlugin(t) diff --git a/plugins/oauth/id_token.go b/plugins/oauth/id_token.go new file mode 100644 index 0000000..bb4a4f0 --- /dev/null +++ b/plugins/oauth/id_token.go @@ -0,0 +1,69 @@ +package oauth + +import ( + "context" + "fmt" + "sync" + + "github.com/coreos/go-oidc/v3/oidc" +) + +type IDTokenVerifier func(ctx context.Context, idToken string) (map[string]any, error) + +func NewIDTokenVerifier(issuer, clientID string) IDTokenVerifier { + var mu sync.Mutex + var verifier *oidc.IDTokenVerifier + + return func(ctx context.Context, idToken string) (map[string]any, error) { + mu.Lock() + if verifier == nil { + created, err := newOIDCIDTokenVerifier(ctx, issuer, clientID) + if err != nil { + mu.Unlock() + return nil, err + } + verifier = created + } + current := verifier + mu.Unlock() + + return verifyIDTokenClaims(ctx, current, idToken) + } +} + +func VerifyIDTokenClaims(ctx context.Context, issuer, clientID, idToken string) (map[string]any, error) { + verifier, err := newOIDCIDTokenVerifier(ctx, issuer, clientID) + if err != nil { + return nil, err + } + return verifyIDTokenClaims(ctx, verifier, idToken) +} + +func newOIDCIDTokenVerifier(ctx context.Context, issuer, clientID string) (*oidc.IDTokenVerifier, error) { + if issuer == "" { + return nil, fmt.Errorf("id token issuer is required") + } + if clientID == "" { + return nil, fmt.Errorf("id token client ID is required") + } + + provider, err := oidc.NewProvider(ctx, issuer) + if err != nil { + return nil, fmt.Errorf("id token provider discovery failed: %w", err) + } + + return provider.Verifier(&oidc.Config{ClientID: clientID}), nil +} + +func verifyIDTokenClaims(ctx context.Context, verifier *oidc.IDTokenVerifier, idToken string) (map[string]any, error) { + verified, err := verifier.Verify(ctx, idToken) + if err != nil { + return nil, fmt.Errorf("id token verification failed: %w", err) + } + + var claims map[string]any + if err := verified.Claims(&claims); err != nil { + return nil, fmt.Errorf("id token claims decode failed: %w", err) + } + return claims, nil +} diff --git a/plugins/oauth/id_token_test.go b/plugins/oauth/id_token_test.go new file mode 100644 index 0000000..4fab6f0 --- /dev/null +++ b/plugins/oauth/id_token_test.go @@ -0,0 +1,120 @@ +package oauth + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + gojose "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" +) + +func TestVerifyIDTokenClaims_VerifiesDiscoveredJWKS(t *testing.T) { + t.Parallel() + + issuer, token := newTestOIDCProviderAndToken(t, "client-id", time.Now().Add(time.Hour)) + + claims, err := VerifyIDTokenClaims(t.Context(), issuer, "client-id", token) + if err != nil { + t.Fatalf("VerifyIDTokenClaims: %v", err) + } + if claims["sub"] != "user-1" || claims["email"] != "user@example.com" || claims["email_verified"] != true { + t.Fatalf("unexpected claims: %#v", claims) + } +} + +func TestVerifyIDTokenClaims_RejectsWrongAudience(t *testing.T) { + t.Parallel() + + issuer, token := newTestOIDCProviderAndToken(t, "other-client-id", time.Now().Add(time.Hour)) + + _, err := VerifyIDTokenClaims(t.Context(), issuer, "client-id", token) + if err == nil { + t.Fatal("expected wrong audience to be rejected") + } +} + +func TestVerifyIDTokenClaims_RejectsExpiredToken(t *testing.T) { + t.Parallel() + + issuer, token := newTestOIDCProviderAndToken(t, "client-id", time.Now().Add(-time.Hour)) + + _, err := VerifyIDTokenClaims(t.Context(), issuer, "client-id", token) + if err == nil { + t.Fatal("expected expired token to be rejected") + } +} + +func newTestOIDCProviderAndToken(t *testing.T, audience string, expiresAt time.Time) (string, string) { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + + publicJWK := gojose.JSONWebKey{ + Key: &privateKey.PublicKey, + KeyID: "test-key", + Algorithm: string(gojose.RS256), + Use: "sig", + } + + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid-configuration": + writeJSON(t, w, map[string]any{ + "issuer": server.URL, + "jwks_uri": server.URL + "/jwks", + "id_token_signing_alg_values_supported": []string{"RS256"}, + }) + case "/jwks": + writeJSON(t, w, map[string]any{ + "keys": []gojose.JSONWebKey{publicJWK}, + }) + default: + http.NotFound(w, r) + } + })) + t.Cleanup(server.Close) + + signer, err := gojose.NewSigner( + gojose.SigningKey{Algorithm: gojose.RS256, Key: privateKey}, + (&gojose.SignerOptions{}).WithHeader("kid", "test-key").WithType("JWT"), + ) + if err != nil { + t.Fatalf("new signer: %v", err) + } + + token, err := jwt.Signed(signer). + Claims(jwt.Claims{ + Issuer: server.URL, + Subject: "user-1", + Audience: jwt.Audience{audience}, + Expiry: jwt.NewNumericDate(expiresAt), + IssuedAt: jwt.NewNumericDate(time.Now()), + }). + Claims(map[string]any{ + "email": "user@example.com", + "email_verified": true, + }). + Serialize() + if err != nil { + t.Fatalf("sign token: %v", err) + } + + return server.URL, token +} + +func writeJSON(t *testing.T, w http.ResponseWriter, payload any) { + t.Helper() + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(payload); err != nil { + t.Fatalf("writeJSON: %v", err) + } +} diff --git a/plugins/oauth/provider.go b/plugins/oauth/provider.go index 7717e85..471a7ea 100644 --- a/plugins/oauth/provider.go +++ b/plugins/oauth/provider.go @@ -62,3 +62,10 @@ const ( type ResponseModeProvider interface { ResponseMode() ResponseMode } + +// IDTokenNonceProvider is optional. If a Provider implements it and returns true, +// the base module sends a nonce authorization parameter and exposes the expected +// nonce to GetUserInfo for provider-side ID token claim validation. +type IDTokenNonceProvider interface { + IDTokenNonceEnabled() bool +} diff --git a/plugins/oauth/testutil_test.go b/plugins/oauth/testutil_test.go index ca86327..8806fba 100644 --- a/plugins/oauth/testutil_test.go +++ b/plugins/oauth/testutil_test.go @@ -63,6 +63,18 @@ func seedOAuthUser(t *testing.T, l *limen.Limen, email string) *limen.User { return limen.SeedTestUser(t, l, email) } +func verifyOAuthUser(t *testing.T, plugin *oauthPlugin, user *limen.User) { + t.Helper() + ctx := context.Background() + now := time.Now() + if err := plugin.core.DBAction.UpdateUser(ctx, &limen.User{EmailVerifiedAt: &now}, []limen.Where{ + limen.Eq(plugin.core.Schema.User.GetIDField(), user.ID), + }); err != nil { + t.Fatalf("verifyOAuthUser: %v", err) + } + user.EmailVerifiedAt = &now +} + func seedOAuthAccount(t *testing.T, plugin *oauthPlugin, userID any, provider, providerAccountID string) { t.Helper() ctx := context.Background() diff --git a/plugins/oauth/util.go b/plugins/oauth/util.go index fa90231..ebe6caa 100644 --- a/plugins/oauth/util.go +++ b/plugins/oauth/util.go @@ -3,6 +3,8 @@ package oauth import ( "context" "crypto/rand" + "crypto/sha256" + "crypto/subtle" "encoding/base64" "encoding/hex" "encoding/json" @@ -18,6 +20,7 @@ import ( ) type callbackParamsContextKey struct{} +type idTokenNonceContextKey struct{} // ContextWithCallbackParams returns a child context carrying the raw callback // query parameters. Providers can retrieve them via CallbackParams inside @@ -32,6 +35,38 @@ func CallbackParams(ctx context.Context) url.Values { return v } +// ContextWithIDTokenNonce returns a child context carrying the expected OIDC nonce. +func ContextWithIDTokenNonce(ctx context.Context, nonce string) context.Context { + return context.WithValue(ctx, idTokenNonceContextKey{}, nonce) +} + +// IDTokenNonce retrieves the expected OIDC nonce stored in ctx, or an empty string. +func IDTokenNonce(ctx context.Context) string { + v, _ := ctx.Value(idTokenNonceContextKey{}).(string) + return v +} + +// VerifyIDTokenNonce verifies the nonce claim against the expected OAuth state nonce. +// Apple may return the SHA-256 hex digest of the sent nonce; raw equality is accepted too. +func VerifyIDTokenNonce(claims map[string]any, expected string) error { + if expected == "" { + return fmt.Errorf("id token nonce is required") + } + claim, _ := claims["nonce"].(string) + if claim == "" { + return fmt.Errorf("id token missing nonce claim") + } + if subtle.ConstantTimeCompare([]byte(claim), []byte(expected)) == 1 { + return nil + } + sum := sha256.Sum256([]byte(expected)) + digest := hex.EncodeToString(sum[:]) + if subtle.ConstantTimeCompare([]byte(claim), []byte(digest)) == 1 { + return nil + } + return fmt.Errorf("id token nonce mismatch") +} + // BuildAuthCodeURL builds the OAuth2 authorization URL using the provider's config. // state and verifier are required for CSRF and PKCE; authOpts add provider-specific params (e.g. AccessTypeOffline). func BuildAuthCodeURL(config *oauth2.Config, state, verifier string, authOpts ...oauth2.AuthCodeOption) string { @@ -152,14 +187,18 @@ func DecodeIDTokenClaims(idToken string) (map[string]any, error) { // generateCodeVerifier creates a cryptographically random PKCE code_verifier func generateCodeVerifier() string { b := make([]byte, 32) - rand.Read(b) + if _, err := rand.Read(b); err != nil { + panic(fmt.Sprintf("oauth: crypto random read failed: %v", err)) + } return base64.RawURLEncoding.EncodeToString(b) } // generateRandomString generates a cryptographically secure random string func generateRandomString() string { randomBytes := make([]byte, 32) - rand.Read(randomBytes) + if _, err := rand.Read(randomBytes); err != nil { + panic(fmt.Sprintf("oauth: crypto random read failed: %v", err)) + } return hex.EncodeToString(randomBytes) } diff --git a/plugins/session-jwt/access_token.go b/plugins/session-jwt/access_token.go index 9eeb9e3..718c9bc 100644 --- a/plugins/session-jwt/access_token.go +++ b/plugins/session-jwt/access_token.go @@ -98,10 +98,17 @@ func (p *sessionJWTPlugin) performRefresh(ctx context.Context, rawRefreshToken s return nil, nil, ErrInvalidRefreshToken } - if active, _ := p.FamilyHasActiveTokens(ctx, family); active { - _ = p.DeleteRefreshTokenFamily(ctx, family) + active, activeErr := p.FamilyHasActiveTokens(ctx, family) + if activeErr != nil { + return nil, nil, activeErr + } + if active { + if err := p.DeleteRefreshTokenFamily(ctx, family); err != nil { + return nil, nil, err + } return nil, nil, ErrRefreshTokenReuse } + return nil, nil, ErrInvalidRefreshToken } if time.Now().After(rt.ExpiresAt) { diff --git a/plugins/session-jwt/access_token_test.go b/plugins/session-jwt/access_token_test.go index 903c689..aa12d7b 100644 --- a/plugins/session-jwt/access_token_test.go +++ b/plugins/session-jwt/access_token_test.go @@ -7,6 +7,7 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thecodearcher/limen" ) @@ -143,3 +144,40 @@ func TestParseAccessTokenLenient_ExpiredButValid(t *testing.T) { assert.NotNil(t, claims) assert.Equal(t, "user-1", claims.Subject) } + +func TestPerformRefresh_MissingTokenWithFamilyAndNoActiveTokens(t *testing.T) { + t.Parallel() + + plugin := New() + limen.NewTestLimen(t, plugin) + + _, _, err := plugin.performRefresh(t.Context(), "missing-refresh-token", "family-1") + + assert.ErrorIs(t, err, ErrInvalidRefreshToken) +} + +func TestPerformRefresh_MissingTokenWithActiveFamilyRevokesFamily(t *testing.T) { + t.Parallel() + + plugin := New() + limen.NewTestLimen(t, plugin) + _, err := plugin.CreateRefreshToken(t.Context(), "user-1", "jti-1", "family-1", nil) + require.NoError(t, err) + + _, _, err = plugin.performRefresh(t.Context(), "missing-refresh-token", "family-1") + + assert.ErrorIs(t, err, ErrRefreshTokenReuse) + active, err := plugin.FamilyHasActiveTokens(t.Context(), "family-1") + require.NoError(t, err) + assert.False(t, active) +} + +func TestRotateRefreshToken_NilOldToken(t *testing.T) { + t.Parallel() + + plugin := newTestPlugin() + + _, err := plugin.RotateRefreshToken(t.Context(), nil, "new-jti") + + assert.ErrorIs(t, err, ErrInvalidRefreshToken) +} diff --git a/plugins/session-jwt/refresh_token.go b/plugins/session-jwt/refresh_token.go index 6dfe42a..3bfc233 100644 --- a/plugins/session-jwt/refresh_token.go +++ b/plugins/session-jwt/refresh_token.go @@ -92,6 +92,10 @@ func (p *sessionJWTPlugin) FamilyHasActiveTokens(ctx context.Context, family str // RotateRefreshToken deletes the old refresh token and creates a new one in // the same family. Returns the new refresh token. func (p *sessionJWTPlugin) RotateRefreshToken(ctx context.Context, old *RefreshToken, newJWTID string) (*RefreshToken, error) { + if old == nil { + return nil, ErrInvalidRefreshToken + } + var newRT *RefreshToken err := p.core.WithTransaction(ctx, func(txCtx context.Context) error { if err := p.core.Delete(txCtx, p.refreshTokenSchema, []limen.Where{ diff --git a/plugins/session-jwt/util.go b/plugins/session-jwt/util.go index e3c40fc..19fe32e 100644 --- a/plugins/session-jwt/util.go +++ b/plugins/session-jwt/util.go @@ -3,6 +3,7 @@ package sessionjwt import ( "crypto/rand" "encoding/base64" + "fmt" "net/http" "strings" @@ -11,7 +12,9 @@ import ( func generateOpaqueToken() string { b := make([]byte, 32) - rand.Read(b) + if _, err := rand.Read(b); err != nil { + panic(fmt.Sprintf("session-jwt: crypto random read failed: %v", err)) + } return base64.RawURLEncoding.EncodeToString(b) } @@ -29,7 +32,9 @@ func parseRefreshTokenValue(raw string) (token, family string) { func generateJTI() string { b := make([]byte, 16) - rand.Read(b) + if _, err := rand.Read(b); err != nil { + panic(fmt.Sprintf("session-jwt: crypto random read failed: %v", err)) + } return base64.RawURLEncoding.EncodeToString(b) } diff --git a/rate_limiter_test.go b/rate_limiter_test.go index 4064f12..b7dc0d3 100644 --- a/rate_limiter_test.go +++ b/rate_limiter_test.go @@ -108,6 +108,18 @@ func TestRateLimiter_Check_WindowReset(t *testing.T) { }) } +func TestDefaultRateLimiterKeyGenerator_UsesRemoteAddr(t *testing.T) { + t.Parallel() + + config := NewDefaultRateLimiterConfig() + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody) + req.RemoteAddr = "203.0.113.10:1234" + req.Header.Set("X-Forwarded-For", "198.51.100.10") + req.Header.Set("X-Real-IP", "198.51.100.11") + + assert.Equal(t, "203.0.113.10", config.KeyGenerator(req)) +} + // --------------------------------------------------------------------------- // Rule matching // --------------------------------------------------------------------------- diff --git a/session_config.go b/session_config.go index 94d84e9..5e2c159 100644 --- a/session_config.go +++ b/session_config.go @@ -2,6 +2,7 @@ package limen import ( "fmt" + "net" "net/http" "time" ) @@ -56,6 +57,14 @@ func NewDefaultSessionConfig(opts ...SessionConfigOption) *sessionConfig { return config } +func ipExtractorFromRemoteAddr(request *http.Request) string { + ip, _, err := net.SplitHostPort(request.RemoteAddr) + if err != nil { + return request.RemoteAddr + } + return ip +} + func (c *sessionConfig) validate() error { if c.UpdateAge > c.Duration { return fmt.Errorf("update age cannot be greater than duration") diff --git a/session_config_test.go b/session_config_test.go new file mode 100644 index 0000000..5352710 --- /dev/null +++ b/session_config_test.go @@ -0,0 +1,31 @@ +package limen + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultSessionIPAddressExtractor_UsesRemoteAddr(t *testing.T) { + t.Parallel() + + config := NewDefaultSessionConfig() + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody) + req.RemoteAddr = "203.0.113.20:1234" + req.Header.Set("X-Forwarded-For", "198.51.100.20") + req.Header.Set("X-Real-IP", "198.51.100.21") + + assert.Equal(t, "203.0.113.20", config.IPAddressExtractor(req)) +} + +func TestDefaultSessionIPAddressExtractor_AllowsRemoteAddrWithoutPort(t *testing.T) { + t.Parallel() + + config := NewDefaultSessionConfig() + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody) + req.RemoteAddr = "203.0.113.20" + + assert.Equal(t, "203.0.113.20", config.IPAddressExtractor(req)) +} diff --git a/utils.go b/utils.go index 6559fec..97a5d40 100644 --- a/utils.go +++ b/utils.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "log" - "net" "net/http" "net/url" "os" @@ -34,7 +33,9 @@ var ( // generateCryptoSecureRandomString generates a cryptographically secure random string func generateCryptoSecureRandomString() string { buf := make([]byte, 32) - _, _ = rand.Read(buf) + if _, err := rand.Read(buf); err != nil { + panic(fmt.Sprintf("crypto random read failed: %v", err)) + } return base64.RawURLEncoding.EncodeToString(buf) } @@ -46,24 +47,15 @@ func GenerateRandomString(length int, charSetType ...CharSetType) string { charCount := len(chars) expectedBytes := make([]byte, length) - _, _ = rand.Read(expectedBytes) + if _, err := rand.Read(expectedBytes); err != nil { + panic(fmt.Sprintf("crypto random read failed: %v", err)) + } for i := range length { expectedBytes[i] = chars[int(expectedBytes[i])%charCount] } return string(expectedBytes) } -func ipExtractorFromRemoteAddr(request *http.Request) string { - if ip := request.Header.Get("X-Forwarded-For"); ip != "" { - return ip - } - if ip := request.Header.Get("X-Real-IP"); ip != "" { - return ip - } - ip, _, _ := net.SplitHostPort(request.RemoteAddr) - return ip -} - // compileRateLimitPattern compiles a rate limit pattern to a regex // Returns the compiled regex and an error if compilation fails func compileRateLimitPattern(pattern string) (*regexp.Regexp, error) {