diff --git a/service/internal/auth/authz/casbin/casbin.go b/service/internal/auth/authz/casbin/casbin.go index 24657e1565..fc63dd999d 100644 --- a/service/internal/auth/authz/casbin/casbin.go +++ b/service/internal/auth/authz/casbin/casbin.go @@ -360,6 +360,13 @@ func (a *Authorizer) extractSubjects(req *authz.Request) []string { if username := a.extractUsernameFromToken(req.Token); username != "" { subjects = append(subjects, username) } + + // Treat the configured OAuth client ID claim as a first-class subject. + // This allows policy to authorize service clients directly, e.g. + // p, kas-a, /policy.kasregistry.KeyAccessServerRegistryService/ListKeys, kas_uri=http://kas-a, allow + if clientID := a.extractClientIDFromToken(req.Token); clientID != "" { + subjects = append(subjects, clientID) + } } // Extract roles from userInfo @@ -375,6 +382,33 @@ func (a *Authorizer) extractSubjects(req *authz.Request) []string { return subjects } +// extractClientIDFromToken extracts and validates the configured OAuth client ID claim. +func (a *Authorizer) extractClientIDFromToken(token jwt.Token) string { + if token == nil || a.baseConfig.ClientIDClaim == "" { + return "" + } + + claim, found := token.Get(a.baseConfig.ClientIDClaim) + if !found { + return "" + } + + clientID, ok := claim.(string) + if !ok || clientID == "" { + return "" + } + + if strings.HasPrefix(clientID, rolePrefix) { + a.logger.Warn("ignoring client ID subject with reserved role prefix", + slog.String("claim", a.baseConfig.ClientIDClaim), + slog.String("prefix", rolePrefix), + ) + return "" + } + + return clientID +} + // extractUsernameFromToken extracts and validates username subject from token. func (a *Authorizer) extractUsernameFromToken(token jwt.Token) string { if token == nil || a.baseConfig.UserNameClaim == "" { diff --git a/service/internal/auth/authz/casbin/casbin_test.go b/service/internal/auth/authz/casbin/casbin_test.go index 6c275db891..26985e61ef 100644 --- a/service/internal/auth/authz/casbin/casbin_test.go +++ b/service/internal/auth/authz/casbin/casbin_test.go @@ -862,6 +862,99 @@ func (s *CasbinAuthorizerSuite) TestAuthorizeV1_PathHandlingHeuristic() { s.Equal("/http/path", receivedResources[1], "HTTP path should keep leading slash") } +func (s *CasbinAuthorizerSuite) TestAuthorizeV2_ClientIDSubjectKASKeyScope() { + cfg := authz.Config{ + Version: "v2", + PolicyConfig: authz.PolicyConfig{ + ClientIDClaim: "azp", + Csv: `p, kas-a, /policy.kasregistry.KeyAccessServerRegistryService/ListKeys, kas_uri=http://localhost:9081, allow +p, kas-a, /policy.kasregistry.KeyAccessServerRegistryService/GetKey, kas_uri=http://localhost:9081, allow +p, kas-b, /policy.kasregistry.KeyAccessServerRegistryService/ListKeys, kas_uri=http://localhost:9082, allow +p, kas-b, /policy.kasregistry.KeyAccessServerRegistryService/GetKey, kas_uri=http://localhost:9082, allow`, + }, + Logger: s.logger, + } + + authorizer, err := NewAuthorizer(cfg) + s.Require().NoError(err) + + kasAToken := createTestToken(s.T(), map[string]interface{}{"azp": "kas-a"}) + kasBToken := createTestToken(s.T(), map[string]interface{}{"azp": "kas-b"}) + + kasAKeysReq := &authz.Request{ + Token: kasAToken, + RPC: "/policy.kasregistry.KeyAccessServerRegistryService/ListKeys", + ResourceContext: &authz.ResolverContext{ + Resources: []*authz.ResolverResource{ + {"kas_uri": "http://localhost:9081"}, + }, + }, + } + decision, err := authorizer.Authorize(context.Background(), kasAKeysReq) + s.Require().NoError(err) + s.True(decision.Allowed, "kas-a should list kas-a keys") + s.Equal("kas-a", decision.MatchedPolicy) + + kasBKeysReq := &authz.Request{ + Token: kasAToken, + RPC: "/policy.kasregistry.KeyAccessServerRegistryService/ListKeys", + ResourceContext: &authz.ResolverContext{ + Resources: []*authz.ResolverResource{ + {"kas_uri": "http://localhost:9082"}, + }, + }, + } + decision, err = authorizer.Authorize(context.Background(), kasBKeysReq) + s.Require().NoError(err) + s.False(decision.Allowed, "kas-a should not list kas-b keys") + + unscopedListReq := &authz.Request{ + Token: kasAToken, + RPC: "/policy.kasregistry.KeyAccessServerRegistryService/ListKeys", + } + decision, err = authorizer.Authorize(context.Background(), unscopedListReq) + s.Require().NoError(err) + s.False(decision.Allowed, "kas-a should not perform an unscoped key list") + + kasBGetReq := &authz.Request{ + Token: kasBToken, + RPC: "/policy.kasregistry.KeyAccessServerRegistryService/GetKey", + ResourceContext: &authz.ResolverContext{ + Resources: []*authz.ResolverResource{ + {"kas_uri": "http://localhost:9082"}, + }, + }, + } + decision, err = authorizer.Authorize(context.Background(), kasBGetReq) + s.Require().NoError(err) + s.True(decision.Allowed, "kas-b should get kas-b keys") + s.Equal("kas-b", decision.MatchedPolicy) +} + +func (s *CasbinAuthorizerSuite) TestAuthorizeV2_ClientIDSubjectWithReservedRolePrefixIsIgnored() { + cfg := authz.Config{ + Version: "v2", + PolicyConfig: authz.PolicyConfig{ + ClientIDClaim: "azp", + Csv: `p, role:admin, /policy.kasregistry.KeyAccessServerRegistryService/ListKeys, *, allow`, + }, + Logger: s.logger, + } + + authorizer, err := NewAuthorizer(cfg) + s.Require().NoError(err) + + token := createTestToken(s.T(), map[string]interface{}{"azp": "role:admin"}) + req := &authz.Request{ + Token: token, + RPC: "/policy.kasregistry.KeyAccessServerRegistryService/ListKeys", + } + + decision, err := authorizer.Authorize(context.Background(), req) + s.Require().NoError(err) + s.False(decision.Allowed, "client ID with reserved role prefix must not match role subjects") +} + // Helper function to create test JWT tokens func createTestToken(t *testing.T, claims map[string]interface{}) jwt.Token { t.Helper() diff --git a/service/internal/server/server.go b/service/internal/server/server.go index 5888d821b0..949b1401fc 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -22,6 +22,7 @@ import ( "github.com/opentdf/platform/sdk" sdkAudit "github.com/opentdf/platform/sdk/audit" "github.com/opentdf/platform/service/internal/auth" + "github.com/opentdf/platform/service/internal/auth/authz" "github.com/opentdf/platform/service/internal/security" "github.com/opentdf/platform/service/internal/server/memhttp" "github.com/opentdf/platform/service/logger" @@ -48,6 +49,18 @@ func (e Error) Error() string { return string(e) } +type openTDFServerOptions struct { + authzResolverRegistry *authz.ResolverRegistry +} + +type OpenTDFServerOption func(*openTDFServerOptions) + +func WithAuthzResolverRegistry(registry *authz.ResolverRegistry) OpenTDFServerOption { + return func(options *openTDFServerOptions) { + options.authzResolverRegistry = registry + } +} + // Configurations for the server type Config struct { Auth auth.Config `mapstructure:"auth" json:"auth"` @@ -251,20 +264,29 @@ type inProcessServer struct { *ConnectRPC } -func NewOpenTDFServer(config Config, logger *logger.Logger, cacheManager *cache.Manager) (*OpenTDFServer, error) { +func NewOpenTDFServer(config Config, logger *logger.Logger, cacheManager *cache.Manager, opts ...OpenTDFServerOption) (*OpenTDFServer, error) { var ( authN *auth.Authentication err error ) + options := openTDFServerOptions{} + for _, opt := range opts { + opt(&options) + } // Add authN interceptor // TODO Remove this conditional once we move to the hardening phase (https://github.com/opentdf/platform/issues/381) if config.Auth.Enabled { + authOpts := []auth.AuthenticatorOption{} + if options.authzResolverRegistry != nil { + authOpts = append(authOpts, auth.WithAuthzResolverRegistry(options.authzResolverRegistry)) + } authN, err = auth.NewAuthenticator( context.Background(), config.Auth, logger, config.WellKnownConfigRegister, + authOpts..., ) if err != nil { return nil, fmt.Errorf("failed to create authentication interceptor: %w", err) diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index 4bb168832c..e50d2a6a25 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -177,10 +177,14 @@ func Start(f ...StartOptions) error { cfg.Server.CORS.AdditionalExposedHeaders = append(cfg.Server.CORS.AdditionalExposedHeaders, startConfig.additionalCORSExposedHeaders...) } + // Create the global authz resolver registry. + // It is shared by the auth interceptor and by services registering scoped resolvers. + authzResolverRegistry := authz.NewResolverRegistry() + // Create new server for grpc & http. Also will support in process grpc potentially too logger.Debug("initializing opentdf server") cfg.Server.WellKnownConfigRegister = wellknown.RegisterConfiguration - otdf, err := server.NewOpenTDFServer(cfg.Server, logger, cacheManager) + otdf, err := server.NewOpenTDFServer(cfg.Server, logger, cacheManager, server.WithAuthzResolverRegistry(authzResolverRegistry)) if err != nil { logger.Error("issue creating opentdf server", slog.String("error", err.Error())) return fmt.Errorf("issue creating opentdf server: %w", err) @@ -271,10 +275,6 @@ func Start(f ...StartOptions) error { defer client.Close() - // Create the global authz resolver registry - // Services will receive scoped registries that can only register resolvers for their own methods - authzResolverRegistry := authz.NewResolverRegistry() - logger.Info("starting services") gatewayCleanup, err := startServices(ctx, startServicesParams{ cfg: cfg, diff --git a/service/policy/kasregistry/key_access_server_registry.go b/service/policy/kasregistry/key_access_server_registry.go index 53d12a84c6..dcf323db29 100644 --- a/service/policy/kasregistry/key_access_server_registry.go +++ b/service/policy/kasregistry/key_access_server_registry.go @@ -10,6 +10,7 @@ import ( "github.com/opentdf/platform/protocol/go/policy" kasr "github.com/opentdf/platform/protocol/go/policy/kasregistry" "github.com/opentdf/platform/protocol/go/policy/kasregistry/kasregistryconnect" + "github.com/opentdf/platform/service/internal/auth/authz" "github.com/opentdf/platform/service/logger" "github.com/opentdf/platform/service/logger/audit" "github.com/opentdf/platform/service/pkg/config" @@ -30,6 +31,12 @@ var ( ErrUnsupportedCurve = errors.New("unsupported curve") ) +const ( + authzDimKASID = "kas_id" + authzDimKASName = "kas_name" + authzDimKASURI = "kas_uri" +) + type KeyAccessServerRegistry struct { dbClient policydb.PolicyDBClient logger *logger.Logger @@ -80,18 +87,272 @@ func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *servicer } kasrSvc.config = cfg + kasrSvc.registerAuthzResolvers(srp.AuthzResolverRegistry) return kasrSvc, nil }, }, } } +func (s *KeyAccessServerRegistry) registerAuthzResolvers(registry *authz.ScopedResolverRegistry) { + if registry == nil { + return + } + + registry.MustRegister("GetKeyAccessServer", s.getKeyAccessServerAuthzResolver) + registry.MustRegister("CreateKey", s.createKeyAuthzResolver) + registry.MustRegister("GetKey", s.getKeyAuthzResolver) + registry.MustRegister("ListKeys", s.listKeysAuthzResolver) + registry.MustRegister("UpdateKey", s.updateKeyAuthzResolver) + registry.MustRegister("RotateKey", s.rotateKeyAuthzResolver) + registry.MustRegister("SetBaseKey", s.setBaseKeyAuthzResolver) + registry.MustRegister("ListKeyMappings", s.listKeyMappingsAuthzResolver) +} + // Close gracefully shuts down the service, closing the database client. func (s *KeyAccessServerRegistry) Close() { s.logger.Info("gracefully shutting down key access server registry service") s.dbClient.Close() } +func (s KeyAccessServerRegistry) getKeyAccessServerAuthzResolver(ctx context.Context, req connect.AnyRequest) (authz.ResolverContext, error) { + resolverCtx := authz.NewResolverContext() + msg, ok := req.Any().(*kasr.GetKeyAccessServerRequest) + if !ok { + return resolverCtx, fmt.Errorf("unexpected request type: %T", req.Any()) + } + + res := resolverCtx.NewResource() + if msg.GetId() != "" { //nolint:staticcheck // Id can still be used until removed + if err := s.addResolvedKASDimensions(ctx, res, msg.GetId()); err != nil { //nolint:staticcheck // Id can still be used until removed + return resolverCtx, err + } + return resolverCtx, nil + } + + if err := s.addResolvedKASDimensions(ctx, res, msg.GetIdentifier()); err != nil { + return resolverCtx, err + } + + return resolverCtx, nil +} + +func (s KeyAccessServerRegistry) createKeyAuthzResolver(ctx context.Context, req connect.AnyRequest) (authz.ResolverContext, error) { + resolverCtx := authz.NewResolverContext() + msg, ok := req.Any().(*kasr.CreateKeyRequest) + if !ok { + return resolverCtx, fmt.Errorf("unexpected request type: %T", req.Any()) + } + + res := resolverCtx.NewResource() + if err := s.addResolvedKASDimensions(ctx, res, &kasr.GetKeyAccessServerRequest_KasId{KasId: msg.GetKasId()}); err != nil { + return resolverCtx, err + } + + return resolverCtx, nil +} + +func (s KeyAccessServerRegistry) getKeyAuthzResolver(ctx context.Context, req connect.AnyRequest) (authz.ResolverContext, error) { + resolverCtx := authz.NewResolverContext() + msg, ok := req.Any().(*kasr.GetKeyRequest) + if !ok { + return resolverCtx, fmt.Errorf("unexpected request type: %T", req.Any()) + } + + res := resolverCtx.NewResource() + if err := s.addKeyRequestDimensions(ctx, res, msg.GetIdentifier()); err != nil { + return resolverCtx, err + } + + return resolverCtx, nil +} + +func (s KeyAccessServerRegistry) listKeysAuthzResolver(ctx context.Context, req connect.AnyRequest) (authz.ResolverContext, error) { + resolverCtx := authz.NewResolverContext() + msg, ok := req.Any().(*kasr.ListKeysRequest) + if !ok { + return resolverCtx, fmt.Errorf("unexpected request type: %T", req.Any()) + } + + res := resolverCtx.NewResource() + if err := s.addListKeysFilterDimensions(ctx, res, msg.GetKasFilter()); err != nil { + return resolverCtx, err + } + + return resolverCtx, nil +} + +func (s KeyAccessServerRegistry) updateKeyAuthzResolver(ctx context.Context, req connect.AnyRequest) (authz.ResolverContext, error) { + resolverCtx := authz.NewResolverContext() + msg, ok := req.Any().(*kasr.UpdateKeyRequest) + if !ok { + return resolverCtx, fmt.Errorf("unexpected request type: %T", req.Any()) + } + + res := resolverCtx.NewResource() + if err := s.addKeyDimensionsByID(ctx, res, msg.GetId()); err != nil { + return resolverCtx, err + } + + return resolverCtx, nil +} + +func (s KeyAccessServerRegistry) rotateKeyAuthzResolver(ctx context.Context, req connect.AnyRequest) (authz.ResolverContext, error) { + resolverCtx := authz.NewResolverContext() + msg, ok := req.Any().(*kasr.RotateKeyRequest) + if !ok { + return resolverCtx, fmt.Errorf("unexpected request type: %T", req.Any()) + } + + res := resolverCtx.NewResource() + switch active := msg.GetActiveKey().(type) { + case *kasr.RotateKeyRequest_Id: + if err := s.addKeyDimensionsByID(ctx, res, active.Id); err != nil { + return resolverCtx, err + } + case *kasr.RotateKeyRequest_Key: + if err := s.addKASKeyIdentifierDimensions(ctx, res, active.Key); err != nil { + return resolverCtx, err + } + default: + return resolverCtx, errors.New("no active key identifier provided") + } + + return resolverCtx, nil +} + +func (s KeyAccessServerRegistry) setBaseKeyAuthzResolver(ctx context.Context, req connect.AnyRequest) (authz.ResolverContext, error) { + resolverCtx := authz.NewResolverContext() + msg, ok := req.Any().(*kasr.SetBaseKeyRequest) + if !ok { + return resolverCtx, fmt.Errorf("unexpected request type: %T", req.Any()) + } + + res := resolverCtx.NewResource() + switch active := msg.GetActiveKey().(type) { + case *kasr.SetBaseKeyRequest_Id: + if err := s.addKeyDimensionsByID(ctx, res, active.Id); err != nil { + return resolverCtx, err + } + case *kasr.SetBaseKeyRequest_Key: + if err := s.addKASKeyIdentifierDimensions(ctx, res, active.Key); err != nil { + return resolverCtx, err + } + default: + return resolverCtx, errors.New("no active key identifier provided") + } + + return resolverCtx, nil +} + +func (s KeyAccessServerRegistry) listKeyMappingsAuthzResolver(ctx context.Context, req connect.AnyRequest) (authz.ResolverContext, error) { + resolverCtx := authz.NewResolverContext() + msg, ok := req.Any().(*kasr.ListKeyMappingsRequest) + if !ok { + return resolverCtx, fmt.Errorf("unexpected request type: %T", req.Any()) + } + + res := resolverCtx.NewResource() + switch identifier := msg.GetIdentifier().(type) { + case *kasr.ListKeyMappingsRequest_Id: + if err := s.addKeyDimensionsByID(ctx, res, identifier.Id); err != nil { + return resolverCtx, err + } + case *kasr.ListKeyMappingsRequest_Key: + if err := s.addKASKeyIdentifierDimensions(ctx, res, identifier.Key); err != nil { + return resolverCtx, err + } + case nil: + // No dimensions means only wildcard-dimension policy can list all mappings. + default: + return resolverCtx, fmt.Errorf("unexpected key mapping identifier type: %T", identifier) + } + + return resolverCtx, nil +} + +func (s KeyAccessServerRegistry) addKeyRequestDimensions(ctx context.Context, res *authz.ResolverResource, identifier any) error { + switch keyIdentifier := identifier.(type) { + case *kasr.GetKeyRequest_Id: + return s.addKeyDimensionsByID(ctx, res, keyIdentifier.Id) + case *kasr.GetKeyRequest_Key: + return s.addKASKeyIdentifierDimensions(ctx, res, keyIdentifier.Key) + default: + return fmt.Errorf("unexpected key identifier type: %T", identifier) + } +} + +func (s KeyAccessServerRegistry) addKeyDimensionsByID(ctx context.Context, res *authz.ResolverResource, id string) error { + key, err := s.dbClient.GetKey(ctx, &kasr.GetKeyRequest_Id{Id: id}) + if err != nil { + return fmt.Errorf("failed to resolve key for authz: %w", err) + } + addKASKeyDimensions(res, key) + return nil +} + +func (s KeyAccessServerRegistry) addKASKeyIdentifierDimensions(ctx context.Context, res *authz.ResolverResource, identifier *kasr.KasKeyIdentifier) error { + if identifier == nil { + return errors.New("key identifier is required") + } + + switch kasIdentifier := identifier.GetIdentifier().(type) { + case *kasr.KasKeyIdentifier_KasId: + return s.addResolvedKASDimensions(ctx, res, &kasr.GetKeyAccessServerRequest_KasId{KasId: kasIdentifier.KasId}) + case *kasr.KasKeyIdentifier_Name: + return s.addResolvedKASDimensions(ctx, res, &kasr.GetKeyAccessServerRequest_Name{Name: kasIdentifier.Name}) + case *kasr.KasKeyIdentifier_Uri: + return s.addResolvedKASDimensions(ctx, res, &kasr.GetKeyAccessServerRequest_Uri{Uri: kasIdentifier.Uri}) + default: + return fmt.Errorf("unexpected KAS identifier type: %T", kasIdentifier) + } +} + +func (s KeyAccessServerRegistry) addListKeysFilterDimensions(ctx context.Context, res *authz.ResolverResource, filter any) error { + switch kasFilter := filter.(type) { + case *kasr.ListKeysRequest_KasId: + return s.addResolvedKASDimensions(ctx, res, &kasr.GetKeyAccessServerRequest_KasId{KasId: kasFilter.KasId}) + case *kasr.ListKeysRequest_KasName: + return s.addResolvedKASDimensions(ctx, res, &kasr.GetKeyAccessServerRequest_Name{Name: kasFilter.KasName}) + case *kasr.ListKeysRequest_KasUri: + return s.addResolvedKASDimensions(ctx, res, &kasr.GetKeyAccessServerRequest_Uri{Uri: kasFilter.KasUri}) + case nil: + // No dimensions means only wildcard-dimension policy can list all keys. + return nil + default: + return fmt.Errorf("unexpected KAS filter type: %T", kasFilter) + } +} + +func (s KeyAccessServerRegistry) addResolvedKASDimensions(ctx context.Context, res *authz.ResolverResource, identifier any) error { + kas, err := s.dbClient.GetKeyAccessServer(ctx, identifier) + if err != nil { + return fmt.Errorf("failed to resolve KAS for authz: %w", err) + } + + addKASDimensions(res, kas.GetId(), kas.GetName(), kas.GetUri()) + return nil +} + +func addKASKeyDimensions(res *authz.ResolverResource, key *policy.KasKey) { + if key == nil { + return + } + addKASDimensions(res, key.GetKasId(), "", key.GetKasUri()) +} + +func addKASDimensions(res *authz.ResolverResource, id, name, uri string) { + if id != "" { + res.AddDimension(authzDimKASID, id) + } + if name != "" { + res.AddDimension(authzDimKASName, name) + } + if uri != "" { + res.AddDimension(authzDimKASURI, uri) + } +} + func (s KeyAccessServerRegistry) CreateKeyAccessServer(ctx context.Context, req *connect.Request[kasr.CreateKeyAccessServerRequest], ) (*connect.Response[kasr.CreateKeyAccessServerResponse], error) {