diff --git a/example/go.mod b/example/go.mod index b02fbb93..0369bc0e 100644 --- a/example/go.mod +++ b/example/go.mod @@ -151,6 +151,9 @@ require ( github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.2.0 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/detectors/gcp v1.39.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect diff --git a/example/go.sum b/example/go.sum index a907e5fd..f97c46e6 100644 --- a/example/go.sum +++ b/example/go.sum @@ -407,6 +407,12 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs= +github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= @@ -499,6 +505,7 @@ golang.org/x/term v0.5.0/go.mod 
h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= diff --git a/go.mod b/go.mod index 68ef4e93..db1134f3 100644 --- a/go.mod +++ b/go.mod @@ -207,6 +207,9 @@ require ( github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect github.com/tliron/commonlog v0.2.8 // indirect github.com/tliron/kutil v0.3.11 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.2.0 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect diff --git a/go.sum b/go.sum index add50e32..97979362 100644 --- a/go.sum +++ b/go.sum @@ -506,6 +506,12 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/wsxiaoys/terminal v0.0.0-20160513160801-0940f3fc43a0 h1:3UeQBvD0TFrlVjOeLOBz+CPAI8dnbqNSVwUwRrkp7vQ= github.com/wsxiaoys/terminal v0.0.0-20160513160801-0940f3fc43a0/go.mod h1:IXCdmsXIht47RaVFLEdVnh1t+pgYtTAhQGj73kz+2DM= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs= +github.com/xdg-go/scram v1.2.0/go.mod 
h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= @@ -606,6 +612,7 @@ golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= diff --git a/module/auth_token_blacklist.go b/module/auth_token_blacklist.go new file mode 100644 index 00000000..1da3ceaf --- /dev/null +++ b/module/auth_token_blacklist.go @@ -0,0 +1,163 @@ +package module + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/CrisisTextLine/modular" + "github.com/redis/go-redis/v9" +) + +// TokenBlacklist is the interface for checking and adding revoked JWT IDs. +type TokenBlacklist interface { + Add(jti string, expiresAt time.Time) + IsBlacklisted(jti string) bool +} + +// TokenBlacklistModule maintains a set of revoked JWT IDs (JTIs). +// It supports two backends: "memory" (default) and "redis". 
+type TokenBlacklistModule struct { + name string + backend string + redisURL string + cleanupInterval time.Duration + + // memory backend + entries sync.Map // jti (string) -> expiry (time.Time) + + // redis backend + redisClient *redis.Client + + logger modular.Logger + stopCh chan struct{} +} + +// NewTokenBlacklistModule creates a new TokenBlacklistModule. +func NewTokenBlacklistModule(name, backend, redisURL string, cleanupInterval time.Duration) *TokenBlacklistModule { + if backend == "" { + backend = "memory" + } + if cleanupInterval <= 0 { + cleanupInterval = 5 * time.Minute + } + return &TokenBlacklistModule{ + name: name, + backend: backend, + redisURL: redisURL, + cleanupInterval: cleanupInterval, + stopCh: make(chan struct{}), + } +} + +// Name returns the module name. +func (m *TokenBlacklistModule) Name() string { return m.name } + +// Init initializes the module. +func (m *TokenBlacklistModule) Init(app modular.Application) error { + m.logger = app.Logger() + return nil +} + +// Start connects to Redis (if configured) and starts the cleanup goroutine. +func (m *TokenBlacklistModule) Start(ctx context.Context) error { + if m.backend == "redis" { + if m.redisURL == "" { + return fmt.Errorf("auth.token-blacklist %q: redis_url is required for redis backend", m.name) + } + opts, err := redis.ParseURL(m.redisURL) + if err != nil { + return fmt.Errorf("auth.token-blacklist %q: invalid redis_url: %w", m.name, err) + } + m.redisClient = redis.NewClient(opts) + if err := m.redisClient.Ping(ctx).Err(); err != nil { + _ = m.redisClient.Close() + m.redisClient = nil + return fmt.Errorf("auth.token-blacklist %q: redis ping failed: %w", m.name, err) + } + m.logger.Info("token blacklist started", "name", m.name, "backend", "redis") + return nil + } + + // memory backend: start cleanup goroutine + go m.runCleanup() + m.logger.Info("token blacklist started", "name", m.name, "backend", "memory") + return nil +} + +// Stop shuts down the module. 
+func (m *TokenBlacklistModule) Stop(_ context.Context) error { + select { + case <-m.stopCh: + // already closed + default: + close(m.stopCh) + } + if m.redisClient != nil { + return m.redisClient.Close() + } + return nil +} + +// Add marks a JTI as revoked until expiresAt. +func (m *TokenBlacklistModule) Add(jti string, expiresAt time.Time) { + if m.backend == "redis" && m.redisClient != nil { + ttl := time.Until(expiresAt) + if ttl <= 0 { + return // already expired, nothing to blacklist + } + _ = m.redisClient.Set(context.Background(), m.redisKey(jti), "1", ttl).Err() + return + } + m.entries.Store(jti, expiresAt) +} + +// IsBlacklisted returns true if the JTI is revoked and has not yet expired. +func (m *TokenBlacklistModule) IsBlacklisted(jti string) bool { + if m.backend == "redis" && m.redisClient != nil { + n, err := m.redisClient.Exists(context.Background(), m.redisKey(jti)).Result() + return err == nil && n > 0 + } + val, ok := m.entries.Load(jti) + if !ok { + return false + } + expiry, ok := val.(time.Time) + return ok && time.Now().Before(expiry) +} + +func (m *TokenBlacklistModule) redisKey(jti string) string { + return "blacklist:" + jti +} + +func (m *TokenBlacklistModule) runCleanup() { + ticker := time.NewTicker(m.cleanupInterval) + defer ticker.Stop() + for { + select { + case <-m.stopCh: + return + case <-ticker.C: + now := time.Now() + m.entries.Range(func(key, value any) bool { + if expiry, ok := value.(time.Time); ok && now.After(expiry) { + m.entries.Delete(key) + } + return true + }) + } + } +} + +// ProvidesServices registers this module as a service. +func (m *TokenBlacklistModule) ProvidesServices() []modular.ServiceProvider { + return []modular.ServiceProvider{ + {Name: m.name, Description: "JWT token blacklist", Instance: m}, + } +} + +// RequiresServices returns service dependencies (none). 
+func (m *TokenBlacklistModule) RequiresServices() []modular.ServiceDependency { + return nil +} diff --git a/module/auth_token_blacklist_test.go b/module/auth_token_blacklist_test.go new file mode 100644 index 00000000..bb32985d --- /dev/null +++ b/module/auth_token_blacklist_test.go @@ -0,0 +1,89 @@ +package module + +import ( + "context" + "testing" + "time" +) + +func TestTokenBlacklistMemory_AddAndCheck(t *testing.T) { + bl := NewTokenBlacklistModule("test-bl", "memory", "", time.Minute) + + if bl.IsBlacklisted("jti-1") { + t.Fatal("expected jti-1 to not be blacklisted initially") + } + + bl.Add("jti-1", time.Now().Add(time.Hour)) + if !bl.IsBlacklisted("jti-1") { + t.Fatal("expected jti-1 to be blacklisted after Add") + } + + if bl.IsBlacklisted("jti-unknown") { + t.Fatal("expected jti-unknown to not be blacklisted") + } +} + +func TestTokenBlacklistMemory_ExpiredEntry(t *testing.T) { + bl := NewTokenBlacklistModule("test-bl", "memory", "", time.Minute) + + // Add a JTI that has already expired. + bl.Add("jti-expired", time.Now().Add(-time.Second)) + if bl.IsBlacklisted("jti-expired") { + t.Fatal("expected already-expired JTI to not be blacklisted") + } +} + +func TestTokenBlacklistMemory_Cleanup(t *testing.T) { + bl := NewTokenBlacklistModule("test-bl", "memory", "", 50*time.Millisecond) + + app := NewMockApplication() + if err := bl.Init(app); err != nil { + t.Fatalf("Init: %v", err) + } + if err := bl.Start(context.Background()); err != nil { + t.Fatalf("Start: %v", err) + } + defer func() { _ = bl.Stop(context.Background()) }() + + // Add a JTI that expires in 10ms. + bl.Add("jti-cleanup", time.Now().Add(10*time.Millisecond)) + + // Give it time to expire and be cleaned up by the cleanup goroutine. 
+ time.Sleep(300 * time.Millisecond) + + if bl.IsBlacklisted("jti-cleanup") { + t.Fatal("expected cleaned-up JTI to no longer be blacklisted") + } +} + +func TestTokenBlacklistModule_StopIdempotent(t *testing.T) { + bl := NewTokenBlacklistModule("test-bl", "memory", "", time.Minute) + app := NewMockApplication() + if err := bl.Init(app); err != nil { + t.Fatalf("Init: %v", err) + } + if err := bl.Start(context.Background()); err != nil { + t.Fatalf("Start: %v", err) + } + // Calling Stop twice must not panic. + if err := bl.Stop(context.Background()); err != nil { + t.Fatalf("first Stop: %v", err) + } + if err := bl.Stop(context.Background()); err != nil { + t.Fatalf("second Stop: %v", err) + } +} + +func TestTokenBlacklistModule_ProvidesServices(t *testing.T) { + bl := NewTokenBlacklistModule("my-bl", "memory", "", time.Minute) + svcs := bl.ProvidesServices() + if len(svcs) != 1 { + t.Fatalf("expected 1 service, got %d", len(svcs)) + } + if svcs[0].Name != "my-bl" { + t.Fatalf("expected service name 'my-bl', got %q", svcs[0].Name) + } + if _, ok := svcs[0].Instance.(TokenBlacklist); !ok { + t.Fatal("expected service instance to implement TokenBlacklist") + } +} diff --git a/module/cache_redis.go b/module/cache_redis.go index 9c754770..8d750d7e 100644 --- a/module/cache_redis.go +++ b/module/cache_redis.go @@ -7,6 +7,7 @@ import ( "time" "github.com/CrisisTextLine/modular" + "github.com/GoCodeAlone/workflow/pkg/tlsutil" "github.com/redis/go-redis/v9" ) @@ -30,10 +31,11 @@ type RedisClient interface { // RedisCacheConfig holds configuration for the cache.redis module. 
type RedisCacheConfig struct { Address string - Password string //nolint:gosec // G117: config struct field, not a hardcoded secret + Password string //nolint:gosec // G117: config struct field, not a hardcoded secret DB int Prefix string DefaultTTL time.Duration + TLS tlsutil.TLSConfig `yaml:"tls" json:"tls"` } // RedisCache is a module that connects to a Redis instance and exposes @@ -87,6 +89,14 @@ func (r *RedisCache) Start(ctx context.Context) error { opts.Password = r.cfg.Password } + if r.cfg.TLS.Enabled { + tlsCfg, err := tlsutil.LoadTLSConfig(r.cfg.TLS) + if err != nil { + return fmt.Errorf("cache.redis %q: TLS config: %w", r.name, err) + } + opts.TLSConfig = tlsCfg + } + r.client = redis.NewClient(opts) if err := r.client.Ping(ctx).Err(); err != nil { diff --git a/module/database.go b/module/database.go index 8b30917e..e07b1119 100644 --- a/module/database.go +++ b/module/database.go @@ -25,14 +25,22 @@ func validateIdentifier(name string) error { return nil } +// DatabaseTLSConfig holds TLS settings for database connections. +type DatabaseTLSConfig struct { + // Mode controls SSL behaviour: disable | require | verify-ca | verify-full (PostgreSQL naming). 
+ Mode string `json:"mode" yaml:"mode"` + CAFile string `json:"ca_file" yaml:"ca_file"` +} + // DatabaseConfig holds configuration for the workflow database module type DatabaseConfig struct { - Driver string `json:"driver" yaml:"driver"` - DSN string `json:"dsn" yaml:"dsn"` - MaxOpenConns int `json:"maxOpenConns" yaml:"maxOpenConns"` - MaxIdleConns int `json:"maxIdleConns" yaml:"maxIdleConns"` - ConnMaxLifetime time.Duration `json:"connMaxLifetime" yaml:"connMaxLifetime"` - MigrationsDir string `json:"migrationsDir" yaml:"migrationsDir"` + Driver string `json:"driver" yaml:"driver"` + DSN string `json:"dsn" yaml:"dsn"` + MaxOpenConns int `json:"maxOpenConns" yaml:"maxOpenConns"` + MaxIdleConns int `json:"maxIdleConns" yaml:"maxIdleConns"` + ConnMaxLifetime time.Duration `json:"connMaxLifetime" yaml:"connMaxLifetime"` + MigrationsDir string `json:"migrationsDir" yaml:"migrationsDir"` + TLS DatabaseTLSConfig `json:"tls" yaml:"tls"` } // QueryResult represents the result of a query @@ -85,6 +93,28 @@ func (w *WorkflowDatabase) RequiresServices() []modular.ServiceDependency { return nil } +// buildDSN returns the DSN with TLS parameters appended for supported drivers. +func (w *WorkflowDatabase) buildDSN() string { + dsn := w.config.DSN + mode := w.config.TLS.Mode + if mode == "" || mode == "disable" { + return dsn + } + + switch w.config.Driver { + case "postgres", "pgx", "pgx/v5": + sep := "?" 
+ if strings.ContainsRune(dsn, '?') { + sep = "&" + } + dsn += sep + "sslmode=" + mode + if w.config.TLS.CAFile != "" { + dsn += "&sslrootcert=" + w.config.TLS.CAFile + } + } + return dsn +} + // Open opens the database connection using config func (w *WorkflowDatabase) Open() (*sql.DB, error) { w.mu.Lock() @@ -94,7 +124,7 @@ func (w *WorkflowDatabase) Open() (*sql.DB, error) { return w.db, nil } - db, err := sql.Open(w.config.Driver, w.config.DSN) + db, err := sql.Open(w.config.Driver, w.buildDSN()) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } diff --git a/module/field_protection.go b/module/field_protection.go new file mode 100644 index 00000000..d0045d88 --- /dev/null +++ b/module/field_protection.go @@ -0,0 +1,159 @@ +package module + +import ( + "context" + "fmt" + "os" + + "github.com/CrisisTextLine/modular" + "github.com/GoCodeAlone/workflow/pkg/fieldcrypt" +) + +// ProtectedFieldManager bundles the registry and key ring for field protection. +type ProtectedFieldManager struct { + Registry *fieldcrypt.Registry + KeyRing fieldcrypt.KeyRing + TenantIsolation bool + ScanDepth int + ScanArrays bool + defaultTenantID string + // rawMasterKey is retained for legacy enc:: decryption (version 0). + rawMasterKey []byte +} + +// resolveTenant returns the effective tenant ID, applying isolation policy. +func (m *ProtectedFieldManager) resolveTenant(tenantID string) string { + if !m.TenantIsolation { + return m.defaultTenantID + } + if tenantID == "" { + return m.defaultTenantID + } + return tenantID +} + +// EncryptMap encrypts protected fields in the data map in-place. +func (m *ProtectedFieldManager) EncryptMap(ctx context.Context, tenantID string, data map[string]any) error { + tid := m.resolveTenant(tenantID) + return fieldcrypt.ScanAndEncrypt(data, m.Registry, func() ([]byte, int, error) { + return m.KeyRing.CurrentKey(ctx, tid) + }, m.ScanDepth) +} + +// DecryptMap decrypts protected fields in the data map in-place. 
+// Version 0 is the legacy enc:: format and uses the raw master key. +func (m *ProtectedFieldManager) DecryptMap(ctx context.Context, tenantID string, data map[string]any) error { + tid := m.resolveTenant(tenantID) + return fieldcrypt.ScanAndDecrypt(data, m.Registry, func(version int) ([]byte, error) { + if version == 0 { + // Legacy enc:: values were encrypted with sha256(masterKey). + // encrypt.go's decryptLegacy calls keyFn(0) expecting the raw key + // bytes, then SHA256-hashes them before use. + return m.rawMasterKey, nil + } + return m.KeyRing.KeyByVersion(ctx, tid, version) + }, m.ScanDepth) +} + +// MaskMap returns a deep copy of data with protected fields masked for logging. +func (m *ProtectedFieldManager) MaskMap(data map[string]any) map[string]any { + return fieldcrypt.ScanAndMask(data, m.Registry, m.ScanDepth) +} + +// FieldProtectionModule implements modular.Module for security.field-protection. +type FieldProtectionModule struct { + name string + manager *ProtectedFieldManager +} + +// NewFieldProtectionModule parses config and creates a FieldProtectionModule. +func NewFieldProtectionModule(name string, cfg map[string]any) (*FieldProtectionModule, error) { + // Parse protected fields. + var fields []fieldcrypt.ProtectedField + if raw, ok := cfg["protected_fields"].([]any); ok { + for _, item := range raw { + if m, ok := item.(map[string]any); ok { + pf := fieldcrypt.ProtectedField{} + if v, ok := m["name"].(string); ok { + pf.Name = v + } + if v, ok := m["classification"].(string); ok { + pf.Classification = fieldcrypt.FieldClassification(v) + } + if v, ok := m["encryption"].(bool); ok { + pf.Encryption = v + } + if v, ok := m["log_behavior"].(string); ok { + pf.LogBehavior = fieldcrypt.LogBehavior(v) + } + if v, ok := m["mask_pattern"].(string); ok { + pf.MaskPattern = v + } + fields = append(fields, pf) + } + } + } + + // Resolve master key. 
+ masterKeyStr, _ := cfg["master_key"].(string) + if masterKeyStr == "" { + masterKeyStr = os.Getenv("FIELD_ENCRYPTION_KEY") + } + + if masterKeyStr == "" { + return nil, fmt.Errorf("field-protection: master_key config or FIELD_ENCRYPTION_KEY env var is required") + } + masterKey := []byte(masterKeyStr) + + scanDepth := 10 + if v, ok := cfg["scan_depth"].(int); ok && v > 0 { + scanDepth = v + } + + scanArrays := true + if v, ok := cfg["scan_arrays"].(bool); ok { + scanArrays = v + } + + tenantIsolation := false + if v, ok := cfg["tenant_isolation"].(bool); ok { + tenantIsolation = v + } + + defaultTenantID := "default" + + mgr := &ProtectedFieldManager{ + Registry: fieldcrypt.NewRegistry(fields), + KeyRing: fieldcrypt.NewLocalKeyRing(masterKey), + TenantIsolation: tenantIsolation, + ScanDepth: scanDepth, + ScanArrays: scanArrays, + defaultTenantID: defaultTenantID, + rawMasterKey: masterKey, + } + + return &FieldProtectionModule{name: name, manager: mgr}, nil +} + +// Name returns the module name. +func (m *FieldProtectionModule) Name() string { return m.name } + +// Init initializes the module. +func (m *FieldProtectionModule) Init(_ modular.Application) error { return nil } + +// ProvidesServices returns the services provided by this module. +func (m *FieldProtectionModule) ProvidesServices() []modular.ServiceProvider { + return []modular.ServiceProvider{ + {Name: m.name, Description: "Protected field manager for encryption/masking", Instance: m.manager}, + } +} + +// RequiresServices returns service dependencies (none). +func (m *FieldProtectionModule) RequiresServices() []modular.ServiceDependency { + return nil +} + +// Manager returns the ProtectedFieldManager. 
+func (m *FieldProtectionModule) Manager() *ProtectedFieldManager { + return m.manager +} diff --git a/module/field_protection_test.go b/module/field_protection_test.go new file mode 100644 index 00000000..d6ad3090 --- /dev/null +++ b/module/field_protection_test.go @@ -0,0 +1,206 @@ +package module + +import ( + "context" + "testing" + + "github.com/GoCodeAlone/workflow/pkg/fieldcrypt" +) + +func TestFieldProtectionModuleConfigParsing(t *testing.T) { + cfg := map[string]any{ + "master_key": "test-master-key-that-is-long-en", + "tenant_isolation": true, + "scan_depth": 5, + "scan_arrays": false, + "protected_fields": []any{ + map[string]any{ + "name": "ssn", + "classification": "pii", + "encryption": true, + "log_behavior": "redact", + }, + map[string]any{ + "name": "email", + "classification": "pii", + "encryption": true, + "log_behavior": "mask", + }, + }, + } + + mod, err := NewFieldProtectionModule("test-fp", cfg) + if err != nil { + t.Fatalf("NewFieldProtectionModule: %v", err) + } + if mod.Name() != "test-fp" { + t.Errorf("name = %q", mod.Name()) + } + + mgr := mod.Manager() + if mgr == nil { + t.Fatal("manager is nil") + } + if !mgr.TenantIsolation { + t.Error("tenant_isolation should be true") + } + if mgr.ScanDepth != 5 { + t.Errorf("scan_depth = %d", mgr.ScanDepth) + } + if mgr.ScanArrays { + t.Error("scan_arrays should be false") + } + if !mgr.Registry.IsProtected("ssn") { + t.Error("ssn should be protected") + } + if !mgr.Registry.IsProtected("email") { + t.Error("email should be protected") + } + if mgr.Registry.IsProtected("name") { + t.Error("name should not be protected") + } +} + +func TestFieldProtectionEncryptDecryptRoundTrip(t *testing.T) { + cfg := map[string]any{ + "master_key": "test-key-for-encrypt-decrypt-rt!", + "protected_fields": []any{ + map[string]any{ + "name": "ssn", + "encryption": true, + }, + map[string]any{ + "name": "email", + "encryption": true, + }, + }, + } + + mod, err := NewFieldProtectionModule("fp", cfg) + if err != nil { 
+ t.Fatal(err) + } + mgr := mod.Manager() + ctx := context.Background() + + data := map[string]any{ + "ssn": "123-45-6789", + "email": "user@example.com", + "name": "John Doe", + } + + if err := mgr.EncryptMap(ctx, "tenant1", data); err != nil { + t.Fatalf("EncryptMap: %v", err) + } + + if !fieldcrypt.IsEncrypted(data["ssn"].(string)) { + t.Error("ssn should be encrypted") + } + if !fieldcrypt.IsEncrypted(data["email"].(string)) { + t.Error("email should be encrypted") + } + if data["name"] != "John Doe" { + t.Error("name should not be modified") + } + + if err := mgr.DecryptMap(ctx, "tenant1", data); err != nil { + t.Fatalf("DecryptMap: %v", err) + } + + if data["ssn"] != "123-45-6789" { + t.Errorf("ssn = %q", data["ssn"]) + } + if data["email"] != "user@example.com" { + t.Errorf("email = %q", data["email"]) + } +} + +func TestFieldProtectionMaskMap(t *testing.T) { + cfg := map[string]any{ + "master_key": "mask-test-key-32-chars-exactly!!", + "protected_fields": []any{ + map[string]any{ + "name": "ssn", + "log_behavior": "redact", + }, + map[string]any{ + "name": "email", + "log_behavior": "mask", + }, + map[string]any{ + "name": "phone", + "log_behavior": "hash", + }, + }, + } + + mod, err := NewFieldProtectionModule("fp", cfg) + if err != nil { + t.Fatal(err) + } + mgr := mod.Manager() + + data := map[string]any{ + "ssn": "123-45-6789", + "email": "user@example.com", + "phone": "555-123-4567", + "name": "John Doe", + } + + masked := mgr.MaskMap(data) + + if masked["ssn"] != "[REDACTED]" { + t.Errorf("ssn mask = %q", masked["ssn"]) + } + if masked["email"] == "user@example.com" { + t.Error("email should be masked") + } + if masked["phone"] == "555-123-4567" { + t.Error("phone should be hashed") + } + if masked["name"] != "John Doe" { + t.Error("name should be unchanged") + } + + // Original data should be unmodified. 
+ if data["ssn"] != "123-45-6789" { + t.Error("original ssn was modified") + } +} + +func TestFieldProtectionProvidesServices(t *testing.T) { + cfg := map[string]any{ + "master_key": "svc-test-key-32-chars-exactly!!", + "protected_fields": []any{}, + } + + mod, err := NewFieldProtectionModule("my-fp", cfg) + if err != nil { + t.Fatal(err) + } + + svcs := mod.ProvidesServices() + if len(svcs) != 1 { + t.Fatalf("expected 1 service, got %d", len(svcs)) + } + if svcs[0].Name != "my-fp" { + t.Errorf("service name = %q", svcs[0].Name) + } + if _, ok := svcs[0].Instance.(*ProtectedFieldManager); !ok { + t.Error("service instance should be *ProtectedFieldManager") + } +} + +func TestFieldProtectionRequiresMasterKey(t *testing.T) { + cfg := map[string]any{ + "protected_fields": []any{}, + } + + // Unset env var to ensure it's not picked up. + t.Setenv("FIELD_ENCRYPTION_KEY", "") + + _, err := NewFieldProtectionModule("fp-nokey", cfg) + if err == nil { + t.Fatal("expected error when no master_key is provided") + } +} diff --git a/module/http_server.go b/module/http_server.go index a93d79ee..f562becc 100644 --- a/module/http_server.go +++ b/module/http_server.go @@ -2,14 +2,26 @@ package module import ( "context" + "crypto/tls" "errors" "fmt" "net/http" "time" "github.com/CrisisTextLine/modular" + "github.com/GoCodeAlone/workflow/pkg/tlsutil" + "golang.org/x/crypto/acme/autocert" ) +// HTTPServerTLSConfig holds TLS configuration for the HTTP server. 
+type HTTPServerTLSConfig struct { + Mode string `yaml:"mode" json:"mode"` // manual | autocert | disabled + Manual tlsutil.TLSConfig `yaml:"manual" json:"manual"` + Autocert tlsutil.AutocertConfig `yaml:"autocert" json:"autocert"` + ClientCAFile string `yaml:"client_ca_file" json:"client_ca_file"` + ClientAuth string `yaml:"client_auth" json:"client_auth"` // require | request | none +} + // StandardHTTPServer implements the HTTPServer interface and modular.Module interfaces type StandardHTTPServer struct { name string @@ -20,6 +32,7 @@ type StandardHTTPServer struct { readTimeout time.Duration writeTimeout time.Duration idleTimeout time.Duration + tlsCfg HTTPServerTLSConfig } // NewStandardHTTPServer creates a new HTTP server with the given name and address @@ -38,6 +51,11 @@ func (s *StandardHTTPServer) SetTimeouts(read, write, idle time.Duration) { s.idleTimeout = idle } +// SetTLSConfig configures TLS for the HTTP server. +func (s *StandardHTTPServer) SetTLSConfig(cfg HTTPServerTLSConfig) { + s.tlsCfg = cfg +} + // Name returns the unique identifier for this module func (s *StandardHTTPServer) Name() string { return s.name @@ -87,14 +105,90 @@ func (s *StandardHTTPServer) Start(ctx context.Context) error { IdleTimeout: timeoutOrDefault(s.idleTimeout, 120*time.Second), } - // Start the server in a goroutine + switch s.tlsCfg.Mode { + case "autocert": + return s.startAutocert(ctx) + case "manual": + return s.startManualTLS(ctx) + default: + // Plain HTTP + go func() { + if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.logger.Error("HTTP server error", "error", err) + } + }() + s.logger.Info("HTTP server started", "address", s.address) + return nil + } +} + +// startManualTLS starts the server with manually configured TLS certificates. 
+func (s *StandardHTTPServer) startManualTLS(ctx context.Context) error { + manualCfg := s.tlsCfg.Manual + manualCfg.Enabled = true + + // Overlay mTLS settings from the top-level fields when set + if s.tlsCfg.ClientCAFile != "" { + manualCfg.CAFile = s.tlsCfg.ClientCAFile + } + if s.tlsCfg.ClientAuth != "" { + manualCfg.ClientAuth = s.tlsCfg.ClientAuth + } + + tlsConfig, err := tlsutil.LoadTLSConfig(manualCfg) + if err != nil { + return fmt.Errorf("http server TLS config: %w", err) + } + s.server.TLSConfig = tlsConfig + go func() { - if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - s.logger.Error("HTTP server error", "error", err) + if err := s.server.ListenAndServeTLS(manualCfg.CertFile, manualCfg.KeyFile); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.logger.Error("HTTPS server error", "error", err) } }() + s.logger.Info("HTTPS server started (manual TLS)", "address", s.address) + return nil +} - s.logger.Info("HTTP server started", "address", s.address) +// startAutocert starts the server using Let's Encrypt via autocert. 
+func (s *StandardHTTPServer) startAutocert(ctx context.Context) error { + ac := s.tlsCfg.Autocert + if len(ac.Domains) == 0 { + return fmt.Errorf("http server autocert: at least one domain is required") + } + + m := &autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(ac.Domains...), + Email: ac.Email, + } + if ac.CacheDir != "" { + m.Cache = autocert.DirCache(ac.CacheDir) + } + + s.server.TLSConfig = &tls.Config{ + GetCertificate: m.GetCertificate, + MinVersion: tls.VersionTLS12, + } + + // ACME HTTP-01 challenge listener on :80 + go func() { + httpSrv := &http.Server{ + Addr: ":80", + Handler: m.HTTPHandler(nil), + ReadHeaderTimeout: 10 * time.Second, + } + if err := httpSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.logger.Error("autocert HTTP-01 listener error", "error", err) + } + }() + + go func() { + if err := s.server.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.logger.Error("HTTPS server error (autocert)", "error", err) + } + }() + s.logger.Info("HTTPS server started (autocert)", "address", s.address, "domains", ac.Domains) return nil } diff --git a/module/jwt_auth.go b/module/jwt_auth.go index f047e0d1..d6762266 100644 --- a/module/jwt_auth.go +++ b/module/jwt_auth.go @@ -13,6 +13,7 @@ import ( "github.com/CrisisTextLine/modular" "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" "golang.org/x/crypto/bcrypt" ) @@ -44,6 +45,7 @@ type JWTAuthModule struct { persistence *PersistenceStore // optional write-through backend userStore *UserStore // optional external user store (from auth.user-store module) allowRegistration bool // when true, any visitor may self-register + tokenBlacklist TokenBlacklist // optional revocation check (wired by auth plugin) } // NewJWTAuthModule creates a new JWT auth module @@ -84,6 +86,12 @@ func (j *JWTAuthModule) SetAllowRegistration(allow bool) { j.allowRegistration = allow } +// SetTokenBlacklist wires a 
TokenBlacklist to this module so that revoked +// tokens are rejected during Authenticate. +func (j *JWTAuthModule) SetTokenBlacklist(bl TokenBlacklist) { + j.tokenBlacklist = bl +} + // Name returns the module name func (j *JWTAuthModule) Name() string { return j.name @@ -139,6 +147,15 @@ func (j *JWTAuthModule) Authenticate(tokenStr string) (bool, map[string]any, err return false, nil, nil } + // Check token revocation if a blacklist is wired. + if j.tokenBlacklist != nil { + if jti, ok := claims["jti"].(string); ok && jti != "" { + if j.tokenBlacklist.IsBlacklisted(jti) { + return false, nil, nil + } + } + } + result := make(map[string]any) maps.Copy(result, claims) return true, result, nil @@ -446,6 +463,7 @@ func (j *JWTAuthModule) generateToken(user *User) (string, error) { "email": user.Email, "name": user.Name, "iss": j.issuer, + "jti": uuid.NewString(), "iat": time.Now().Unix(), "exp": time.Now().Add(j.tokenExpiry).Unix(), } diff --git a/module/kafka_broker.go b/module/kafka_broker.go index e27177d0..4cdfd92e 100644 --- a/module/kafka_broker.go +++ b/module/kafka_broker.go @@ -2,13 +2,28 @@ package module import ( "context" + "encoding/json" "fmt" "sync" "github.com/CrisisTextLine/modular" "github.com/IBM/sarama" + "github.com/GoCodeAlone/workflow/pkg/tlsutil" ) +// KafkaSASLConfig holds SASL authentication configuration for Kafka. +type KafkaSASLConfig struct { + Mechanism string `yaml:"mechanism" json:"mechanism"` // PLAIN | SCRAM-SHA-256 | SCRAM-SHA-512 + Username string `yaml:"username" json:"username"` + Password string `yaml:"password" json:"password"` //nolint:gosec // G117: config struct field +} + +// KafkaTLSConfig holds TLS and SASL configuration for the Kafka broker. +type KafkaTLSConfig struct { + tlsutil.TLSConfig `yaml:",inline" json:",inline"` + SASL KafkaSASLConfig `yaml:"sasl" json:"sasl"` +} + // KafkaBroker implements the MessageBroker interface using Apache Kafka via Sarama. 
type KafkaBroker struct { name string @@ -24,7 +39,9 @@ type KafkaBroker struct { logger modular.Logger healthy bool healthMsg string - encryptor *FieldEncryptor + encryptor *FieldEncryptor + fieldProtector *ProtectedFieldManager + tlsCfg KafkaTLSConfig } // NewKafkaBroker creates a new Kafka message broker. @@ -93,6 +110,22 @@ func (b *KafkaBroker) SetGroupID(groupID string) { b.groupID = groupID } +// SetFieldProtection sets the field-level encryption manager for this broker. +// When set, individual protected fields are encrypted/decrypted in JSON payloads +// before the legacy whole-message encryptor runs. +func (b *KafkaBroker) SetFieldProtection(mgr *ProtectedFieldManager) { + b.mu.Lock() + defer b.mu.Unlock() + b.fieldProtector = mgr +} + +// SetTLSConfig sets the TLS and SASL configuration for the Kafka broker. +func (b *KafkaBroker) SetTLSConfig(cfg KafkaTLSConfig) { + b.mu.Lock() + defer b.mu.Unlock() + b.tlsCfg = cfg +} + // HealthStatus implements the HealthCheckable interface. 
func (b *KafkaBroker) HealthStatus() HealthCheckResult { b.mu.RLock() @@ -146,6 +179,38 @@ func (b *KafkaBroker) Start(ctx context.Context) error { config.Consumer.Group.Rebalance.GroupStrategies = []sarama.BalanceStrategy{sarama.NewBalanceStrategyRoundRobin()} config.Consumer.Offsets.Initial = sarama.OffsetNewest + // Apply TLS configuration + if b.tlsCfg.Enabled { + tlsCfg, tlsErr := tlsutil.LoadTLSConfig(b.tlsCfg.TLSConfig) + if tlsErr != nil { + return fmt.Errorf("kafka broker %q: TLS config: %w", b.name, tlsErr) + } + config.Net.TLS.Enable = true + config.Net.TLS.Config = tlsCfg + } + + // Apply SASL configuration + sasl := b.tlsCfg.SASL + if sasl.Username != "" { + config.Net.SASL.Enable = true + config.Net.SASL.User = sasl.Username + config.Net.SASL.Password = sasl.Password + switch sasl.Mechanism { + case "SCRAM-SHA-256": + config.Net.SASL.Mechanism = sarama.SASLTypeSCRAMSHA256 + config.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient { + return &xDGSCRAMClient{HashGeneratorFcn: SHA256} + } + case "SCRAM-SHA-512": + config.Net.SASL.Mechanism = sarama.SASLTypeSCRAMSHA512 + config.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient { + return &xDGSCRAMClient{HashGeneratorFcn: SHA512} + } + default: // PLAIN + config.Net.SASL.Mechanism = sarama.SASLTypePlaintext + } + } + // Create sync producer producer, err := sarama.NewSyncProducer(b.brokers, config) if err != nil { @@ -244,16 +309,31 @@ func (p *kafkaProducerAdapter) SendMessage(topic string, message []byte) error { p.broker.mu.RLock() producer := p.broker.producer encryptor := p.broker.encryptor + fieldProt := p.broker.fieldProtector p.broker.mu.RUnlock() if producer == nil { return fmt.Errorf("kafka producer not initialized; call Start first") } - // Encrypt the message payload if encryption is enabled payload := message + + // Field-level encryption: encrypt individual protected fields in JSON payloads. 
+ if fieldProt != nil { + var data map[string]any + if err := json.Unmarshal(payload, &data); err == nil { + if encErr := fieldProt.EncryptMap(context.Background(), "", data); encErr != nil { + return fmt.Errorf("failed to field-encrypt kafka message for topic %q: %w", topic, encErr) + } + if out, err := json.Marshal(data); err == nil { + payload = out + } + } + } + + // Legacy whole-message encryption (if configured via ENCRYPTION_KEY). if encryptor != nil && encryptor.Enabled() { - encrypted, err := encryptor.EncryptJSON(message) + encrypted, err := encryptor.EncryptJSON(payload) if err != nil { return fmt.Errorf("failed to encrypt kafka message for topic %q: %w", topic, err) } @@ -313,10 +393,11 @@ func (h *kafkaGroupHandler) ConsumeClaim(session sarama.ConsumerGroupSession, cl h.broker.mu.RLock() handler, ok := h.broker.handlers[msg.Topic] encryptor := h.broker.encryptor + fieldProt := h.broker.fieldProtector h.broker.mu.RUnlock() if ok { - // Decrypt message payload if encryption is enabled + // Legacy whole-message decryption first. payload := msg.Value if encryptor != nil && encryptor.Enabled() { decrypted, err := encryptor.DecryptJSON(payload) @@ -327,6 +408,17 @@ func (h *kafkaGroupHandler) ConsumeClaim(session sarama.ConsumerGroupSession, cl } payload = decrypted } + // Field-level decryption: decrypt individual protected fields. 
+ if fieldProt != nil { + var data map[string]any + if err := json.Unmarshal(payload, &data); err == nil { + if decErr := fieldProt.DecryptMap(context.Background(), "", data); decErr == nil { + if out, err := json.Marshal(data); err == nil { + payload = out + } + } + } + } if err := handler.HandleMessage(payload); err != nil { h.broker.logger.Error("Error handling Kafka message", "topic", msg.Topic, "error", err) diff --git a/module/kafka_scram.go b/module/kafka_scram.go new file mode 100644 index 00000000..9f455cf0 --- /dev/null +++ b/module/kafka_scram.go @@ -0,0 +1,44 @@ +package module + +import ( + "crypto/sha256" + "crypto/sha512" + "hash" + + "github.com/xdg-go/scram" +) + +// SHA256 and SHA512 are hash generator functions used by the SCRAM client. +var ( + SHA256 scram.HashGeneratorFcn = sha256.New + SHA512 scram.HashGeneratorFcn = sha512.New +) + +// xDGSCRAMClient implements sarama.SCRAMClient using the xdg-go/scram package. +type xDGSCRAMClient struct { + *scram.Client + *scram.ClientConversation + scram.HashGeneratorFcn +} + +func (x *xDGSCRAMClient) Begin(userName, password, authzID string) error { + client, err := x.NewClient(userName, password, authzID) + if err != nil { + return err + } + x.Client = client + x.ClientConversation = client.NewConversation() + return nil +} + +func (x *xDGSCRAMClient) Step(challenge string) (string, error) { + return x.ClientConversation.Step(challenge) +} + +func (x *xDGSCRAMClient) Done() bool { + return x.ClientConversation.Done() +} + +// ensure hash functions satisfy the interface (compile-time check) +var _ func() hash.Hash = sha256.New +var _ func() hash.Hash = sha512.New diff --git a/module/nats_broker.go b/module/nats_broker.go index 1ef24451..ddb7eaea 100644 --- a/module/nats_broker.go +++ b/module/nats_broker.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/CrisisTextLine/modular" + "github.com/GoCodeAlone/workflow/pkg/tlsutil" "github.com/nats-io/nats.go" ) @@ -20,6 +21,7 @@ type NATSBroker struct { producer 
*natsProducer consumer *natsConsumer logger modular.Logger + tlsCfg tlsutil.TLSConfig } // NewNATSBroker creates a new NATS message broker. @@ -80,6 +82,13 @@ func (b *NATSBroker) SetURL(url string) { b.url = url } +// SetTLSConfig configures TLS for the NATS broker connection. +func (b *NATSBroker) SetTLSConfig(cfg tlsutil.TLSConfig) { + b.mu.Lock() + defer b.mu.Unlock() + b.tlsCfg = cfg +} + // Producer returns the message producer interface. func (b *NATSBroker) Producer() MessageProducer { return b.producer @@ -100,7 +109,16 @@ func (b *NATSBroker) Start(ctx context.Context) error { b.mu.Lock() defer b.mu.Unlock() - conn, err := nats.Connect(b.url) + var opts []nats.Option + if b.tlsCfg.Enabled { + tlsCfg, tlsErr := tlsutil.LoadTLSConfig(b.tlsCfg) + if tlsErr != nil { + return fmt.Errorf("nats broker %q: TLS config: %w", b.name, tlsErr) + } + opts = append(opts, nats.Secure(tlsCfg)) + } + + conn, err := nats.Connect(b.url, opts...) if err != nil { return fmt.Errorf("failed to connect to NATS at %s: %w", b.url, err) } diff --git a/module/pipeline_step_field_reencrypt.go b/module/pipeline_step_field_reencrypt.go new file mode 100644 index 00000000..27d3173e --- /dev/null +++ b/module/pipeline_step_field_reencrypt.go @@ -0,0 +1,73 @@ +package module + +import ( + "context" + "fmt" + + "github.com/CrisisTextLine/modular" +) + +// FieldReencryptStep re-encrypts pipeline context data with the latest key version. +type FieldReencryptStep struct { + name string + module string // field-protection module name to look up + tenantID string // template expression for tenant ID + tmpl *TemplateEngine + app modular.Application +} + +// NewFieldReencryptStepFactory returns a StepFactory for step.field_reencrypt. 
+func NewFieldReencryptStepFactory() StepFactory { + return func(name string, config map[string]any, app modular.Application) (PipelineStep, error) { + moduleName, _ := config["module"].(string) + if moduleName == "" { + return nil, fmt.Errorf("field_reencrypt step %q: 'module' is required", name) + } + tenantID, _ := config["tenant_id"].(string) + + return &FieldReencryptStep{ + name: name, + module: moduleName, + tenantID: tenantID, + tmpl: NewTemplateEngine(), + app: app, + }, nil + } +} + +// Name returns the step name. +func (s *FieldReencryptStep) Name() string { return s.name } + +// Execute re-encrypts data by decrypting with the old key and encrypting with the current key. +func (s *FieldReencryptStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) { + // Look up the ProtectedFieldManager from the service registry. + var manager *ProtectedFieldManager + if err := s.app.GetService(s.module, &manager); err != nil { + return nil, fmt.Errorf("field_reencrypt step %q: module %q not found: %w", s.name, s.module, err) + } + + // Resolve tenant ID from template expression. + tenantID := s.tenantID + if tenantID != "" { + resolved, err := s.tmpl.Resolve(tenantID, pc) + if err == nil && resolved != "" { + tenantID = resolved + } + } + + // Decrypt with old key version, then re-encrypt with current key. 
+ data := pc.Current + if data == nil { + return &StepResult{Output: map[string]any{"reencrypted": false, "reason": "no data"}}, nil + } + + if err := manager.DecryptMap(ctx, tenantID, data); err != nil { + return nil, fmt.Errorf("field_reencrypt step %q: decrypt: %w", s.name, err) + } + + if err := manager.EncryptMap(ctx, tenantID, data); err != nil { + return nil, fmt.Errorf("field_reencrypt step %q: encrypt: %w", s.name, err) + } + + return &StepResult{Output: map[string]any{"reencrypted": true}}, nil +} diff --git a/module/pipeline_step_sandbox_exec.go b/module/pipeline_step_sandbox_exec.go new file mode 100644 index 00000000..d63a4a88 --- /dev/null +++ b/module/pipeline_step_sandbox_exec.go @@ -0,0 +1,231 @@ +package module + +import ( + "context" + "fmt" + "time" + + "github.com/CrisisTextLine/modular" + "github.com/GoCodeAlone/workflow/sandbox" +) + +const defaultSandboxImage = "cgr.dev/chainguard/wolfi-base:latest" + +// SandboxExecStep runs a command in a hardened Docker sandbox container. +type SandboxExecStep struct { + name string + image string + command []string + securityProfile string + memoryLimit int64 + cpuLimit float64 + timeout time.Duration + network string + env map[string]string + mounts []sandbox.Mount + failOnError bool +} + +// NewSandboxExecStepFactory returns a StepFactory for step.sandbox_exec. 
+func NewSandboxExecStepFactory() StepFactory { + return func(name string, cfg map[string]any, _ modular.Application) (PipelineStep, error) { + step := &SandboxExecStep{ + name: name, + image: defaultSandboxImage, + securityProfile: "strict", + failOnError: true, + } + + if img, ok := cfg["image"].(string); ok && img != "" { + step.image = img + } + + // command + switch v := cfg["command"].(type) { + case []any: + for i, c := range v { + s, ok := c.(string) + if !ok { + return nil, fmt.Errorf("sandbox_exec step %q: command[%d] must be a string", name, i) + } + step.command = append(step.command, s) + } + case []string: + step.command = v + case nil: + // allowed — step may be used without a command for future use + default: + return nil, fmt.Errorf("sandbox_exec step %q: 'command' must be a list of strings", name) + } + + if profile, ok := cfg["security_profile"].(string); ok && profile != "" { + switch profile { + case "strict", "standard", "permissive": + step.securityProfile = profile + default: + return nil, fmt.Errorf("sandbox_exec step %q: security_profile must be strict, standard, or permissive", name) + } + } + + if ms, ok := cfg["memory_limit"].(string); ok && ms != "" { + limit, err := parseMemoryLimit(ms) + if err != nil { + return nil, fmt.Errorf("sandbox_exec step %q: invalid memory_limit: %w", name, err) + } + step.memoryLimit = limit + } + + if cpu, ok := cfg["cpu_limit"].(float64); ok { + step.cpuLimit = cpu + } + + if ts, ok := cfg["timeout"].(string); ok && ts != "" { + d, err := time.ParseDuration(ts) + if err != nil { + return nil, fmt.Errorf("sandbox_exec step %q: invalid timeout %q: %w", name, ts, err) + } + step.timeout = d + } + + if net, ok := cfg["network"].(string); ok && net != "" { + step.network = net + } + + if envRaw, ok := cfg["env"].(map[string]any); ok { + step.env = make(map[string]string, len(envRaw)) + for k, v := range envRaw { + step.env[k] = fmt.Sprintf("%v", v) + } + } + + if mountsRaw, ok := cfg["mounts"].([]any); ok { + 
for i, m := range mountsRaw { + mmap, ok := m.(map[string]any) + if !ok { + return nil, fmt.Errorf("sandbox_exec step %q: mounts[%d] must be a map", name, i) + } + src, _ := mmap["source"].(string) + tgt, _ := mmap["target"].(string) + ro, _ := mmap["read_only"].(bool) + step.mounts = append(step.mounts, sandbox.Mount{Source: src, Target: tgt, ReadOnly: ro}) + } + } + + if foe, ok := cfg["fail_on_error"].(bool); ok { + step.failOnError = foe + } + + return step, nil + } +} + +// Name returns the step name. +func (s *SandboxExecStep) Name() string { return s.name } + +// Execute runs the configured command in a Docker sandbox. +func (s *SandboxExecStep) Execute(ctx context.Context, _ *PipelineContext) (*StepResult, error) { + sbCfg := s.buildSandboxConfig() + + sb, err := sandbox.NewDockerSandbox(sbCfg) + if err != nil { + return nil, fmt.Errorf("sandbox_exec step %q: failed to create sandbox: %w", s.name, err) + } + defer sb.Close() + + result, err := sb.Exec(ctx, s.command) + if err != nil { + return nil, fmt.Errorf("sandbox_exec step %q: execution failed: %w", s.name, err) + } + + output := map[string]any{ + "exit_code": result.ExitCode, + "stdout": result.Stdout, + "stderr": result.Stderr, + } + + if result.ExitCode != 0 && s.failOnError { + return &StepResult{Output: output, Stop: true}, nil + } + + return &StepResult{Output: output}, nil +} + +// buildSandboxConfig constructs a SandboxConfig based on the security profile +// and any explicit overrides provided in the step config. 
+func (s *SandboxExecStep) buildSandboxConfig() sandbox.SandboxConfig { + var cfg sandbox.SandboxConfig + + switch s.securityProfile { + case "permissive": + cfg = sandbox.SandboxConfig{ + Image: s.image, + NetworkMode: "bridge", + } + case "standard": + cfg = sandbox.SandboxConfig{ + Image: s.image, + MemoryLimit: 256 * 1024 * 1024, + CPULimit: 0.5, + NetworkMode: "bridge", + CapDrop: []string{"NET_ADMIN", "SYS_ADMIN", "SYS_PTRACE", "SETUID", "SETGID"}, + CapAdd: []string{"NET_BIND_SERVICE"}, + NoNewPrivileges: true, + PidsLimit: 64, + Timeout: 5 * time.Minute, + } + default: // "strict" + cfg = sandbox.DefaultSecureSandboxConfig(s.image) + } + + // Apply explicit overrides + if s.memoryLimit > 0 { + cfg.MemoryLimit = s.memoryLimit + } + if s.cpuLimit > 0 { + cfg.CPULimit = s.cpuLimit + } + if s.timeout > 0 { + cfg.Timeout = s.timeout + } + if s.network != "" { + cfg.NetworkMode = s.network + } + if len(s.env) > 0 { + cfg.Env = s.env + } + if len(s.mounts) > 0 { + cfg.Mounts = s.mounts + } + + return cfg +} + +// parseMemoryLimit parses a human-readable memory string (e.g., "128m", "1g") into bytes. 
+func parseMemoryLimit(s string) (int64, error) { + if len(s) == 0 { + return 0, fmt.Errorf("empty memory limit") + } + last := s[len(s)-1] + var multiplier int64 = 1 + numStr := s + switch last { + case 'k', 'K': + multiplier = 1024 + numStr = s[:len(s)-1] + case 'm', 'M': + multiplier = 1024 * 1024 + numStr = s[:len(s)-1] + case 'g', 'G': + multiplier = 1024 * 1024 * 1024 + numStr = s[:len(s)-1] + case 'b', 'B': + numStr = s[:len(s)-1] + } + + var n int64 + _, err := fmt.Sscanf(numStr, "%d", &n) + if err != nil { + return 0, fmt.Errorf("invalid memory limit %q", s) + } + return n * multiplier, nil +} diff --git a/module/pipeline_step_sandbox_exec_test.go b/module/pipeline_step_sandbox_exec_test.go new file mode 100644 index 00000000..6a9f4e78 --- /dev/null +++ b/module/pipeline_step_sandbox_exec_test.go @@ -0,0 +1,302 @@ +package module + +import ( + "testing" + "time" +) + +func TestNewSandboxExecStepFactory_Defaults(t *testing.T) { + factory := NewSandboxExecStepFactory() + step, err := factory("test-step", map[string]any{ + "command": []any{"echo", "hello"}, + }, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + s := step.(*SandboxExecStep) + if s.name != "test-step" { + t.Fatalf("unexpected name: %s", s.name) + } + if s.image != defaultSandboxImage { + t.Fatalf("expected default image, got %s", s.image) + } + if s.securityProfile != "strict" { + t.Fatalf("expected strict profile, got %s", s.securityProfile) + } + if !s.failOnError { + t.Fatal("expected failOnError true by default") + } + if len(s.command) != 2 || s.command[0] != "echo" { + t.Fatalf("unexpected command: %v", s.command) + } +} + +func TestNewSandboxExecStepFactory_CustomImage(t *testing.T) { + factory := NewSandboxExecStepFactory() + step, err := factory("s", map[string]any{ + "image": "alpine:3.19", + "command": []any{"ls"}, + }, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + s := step.(*SandboxExecStep) + if s.image != "alpine:3.19" { + 
t.Fatalf("unexpected image: %s", s.image) + } +} + +func TestNewSandboxExecStepFactory_SecurityProfiles(t *testing.T) { + factory := NewSandboxExecStepFactory() + + for _, profile := range []string{"strict", "standard", "permissive"} { + step, err := factory("s", map[string]any{ + "security_profile": profile, + "command": []any{"ls"}, + }, nil) + if err != nil { + t.Fatalf("unexpected error for profile %q: %v", profile, err) + } + s := step.(*SandboxExecStep) + if s.securityProfile != profile { + t.Fatalf("expected profile %q, got %q", profile, s.securityProfile) + } + } +} + +func TestNewSandboxExecStepFactory_InvalidProfile(t *testing.T) { + factory := NewSandboxExecStepFactory() + _, err := factory("s", map[string]any{ + "security_profile": "unknown", + "command": []any{"ls"}, + }, nil) + if err == nil { + t.Fatal("expected error for invalid security_profile") + } +} + +func TestNewSandboxExecStepFactory_MemoryLimit(t *testing.T) { + factory := NewSandboxExecStepFactory() + tests := []struct { + input string + expected int64 + }{ + {"128m", 128 * 1024 * 1024}, + {"256M", 256 * 1024 * 1024}, + {"1g", 1024 * 1024 * 1024}, + {"512k", 512 * 1024}, + } + for _, tt := range tests { + step, err := factory("s", map[string]any{ + "command": []any{"ls"}, + "memory_limit": tt.input, + }, nil) + if err != nil { + t.Fatalf("input %q: unexpected error: %v", tt.input, err) + } + s := step.(*SandboxExecStep) + if s.memoryLimit != tt.expected { + t.Fatalf("input %q: expected %d, got %d", tt.input, tt.expected, s.memoryLimit) + } + } +} + +func TestNewSandboxExecStepFactory_Timeout(t *testing.T) { + factory := NewSandboxExecStepFactory() + step, err := factory("s", map[string]any{ + "command": []any{"ls"}, + "timeout": "30s", + }, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + s := step.(*SandboxExecStep) + if s.timeout != 30*time.Second { + t.Fatalf("expected 30s, got %s", s.timeout) + } +} + +func TestNewSandboxExecStepFactory_InvalidTimeout(t *testing.T) 
{ + factory := NewSandboxExecStepFactory() + _, err := factory("s", map[string]any{ + "command": []any{"ls"}, + "timeout": "not-a-duration", + }, nil) + if err == nil { + t.Fatal("expected error for invalid timeout") + } +} + +func TestNewSandboxExecStepFactory_FailOnError(t *testing.T) { + factory := NewSandboxExecStepFactory() + step, err := factory("s", map[string]any{ + "command": []any{"ls"}, + "fail_on_error": false, + }, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + s := step.(*SandboxExecStep) + if s.failOnError { + t.Fatal("expected failOnError false") + } +} + +func TestNewSandboxExecStepFactory_EnvAndNetwork(t *testing.T) { + factory := NewSandboxExecStepFactory() + step, err := factory("s", map[string]any{ + "command": []any{"env"}, + "env": map[string]any{"FOO": "bar", "NUM": 42}, + "network": "bridge", + }, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + s := step.(*SandboxExecStep) + if s.env["FOO"] != "bar" { + t.Fatalf("unexpected FOO: %s", s.env["FOO"]) + } + if s.env["NUM"] != "42" { + t.Fatalf("unexpected NUM: %s", s.env["NUM"]) + } + if s.network != "bridge" { + t.Fatalf("unexpected network: %s", s.network) + } +} + +func TestSandboxExecStep_Name(t *testing.T) { + s := &SandboxExecStep{name: "my-step"} + if s.Name() != "my-step" { + t.Fatalf("unexpected name: %s", s.Name()) + } +} + +func TestSandboxExecStep_BuildSandboxConfig_Strict(t *testing.T) { + s := &SandboxExecStep{ + image: "alpine:3.19", + securityProfile: "strict", + } + cfg := s.buildSandboxConfig() + + if cfg.NetworkMode != "none" { + t.Fatalf("strict: expected network none, got %s", cfg.NetworkMode) + } + if len(cfg.CapDrop) != 1 || cfg.CapDrop[0] != "ALL" { + t.Fatalf("strict: expected CapDrop ALL, got %v", cfg.CapDrop) + } + if !cfg.NoNewPrivileges { + t.Fatal("strict: expected NoNewPrivileges true") + } + if !cfg.ReadOnlyRootfs { + t.Fatal("strict: expected ReadOnlyRootfs true") + } + if cfg.PidsLimit != 64 { + t.Fatalf("strict: 
expected PidsLimit 64, got %d", cfg.PidsLimit) + } +} + +func TestSandboxExecStep_BuildSandboxConfig_Standard(t *testing.T) { + s := &SandboxExecStep{ + image: "alpine:3.19", + securityProfile: "standard", + } + cfg := s.buildSandboxConfig() + + if cfg.NetworkMode != "bridge" { + t.Fatalf("standard: expected network bridge, got %s", cfg.NetworkMode) + } + if len(cfg.CapAdd) == 0 { + t.Fatal("standard: expected NET_BIND_SERVICE in CapAdd") + } + if !cfg.NoNewPrivileges { + t.Fatal("standard: expected NoNewPrivileges true") + } +} + +func TestSandboxExecStep_BuildSandboxConfig_Permissive(t *testing.T) { + s := &SandboxExecStep{ + image: "alpine:3.19", + securityProfile: "permissive", + } + cfg := s.buildSandboxConfig() + + if cfg.NetworkMode != "bridge" { + t.Fatalf("permissive: expected network bridge, got %s", cfg.NetworkMode) + } + if len(cfg.CapDrop) > 0 { + t.Fatalf("permissive: expected no CapDrop, got %v", cfg.CapDrop) + } + if cfg.ReadOnlyRootfs { + t.Fatal("permissive: expected ReadOnlyRootfs false") + } +} + +func TestSandboxExecStep_BuildSandboxConfig_Overrides(t *testing.T) { + s := &SandboxExecStep{ + image: "alpine:3.19", + securityProfile: "strict", + memoryLimit: 512 * 1024 * 1024, + cpuLimit: 2.0, + timeout: 10 * time.Second, + network: "bridge", + } + cfg := s.buildSandboxConfig() + + if cfg.MemoryLimit != 512*1024*1024 { + t.Fatalf("unexpected MemoryLimit: %d", cfg.MemoryLimit) + } + if cfg.CPULimit != 2.0 { + t.Fatalf("unexpected CPULimit: %f", cfg.CPULimit) + } + if cfg.Timeout != 10*time.Second { + t.Fatalf("unexpected Timeout: %s", cfg.Timeout) + } + if cfg.NetworkMode != "bridge" { + t.Fatalf("unexpected NetworkMode: %s", cfg.NetworkMode) + } +} + +func TestParseMemoryLimit(t *testing.T) { + tests := []struct { + input string + expected int64 + wantErr bool + }{ + {"128m", 128 * 1024 * 1024, false}, + {"256M", 256 * 1024 * 1024, false}, + {"1g", 1024 * 1024 * 1024, false}, + {"2G", 2 * 1024 * 1024 * 1024, false}, + {"512k", 512 * 1024, false}, 
+ {"1024K", 1024 * 1024, false}, + {"1024", 1024, false}, + {"1024b", 1024, false}, + {"", 0, true}, + {"abc", 0, true}, + } + + for _, tt := range tests { + got, err := parseMemoryLimit(tt.input) + if tt.wantErr { + if err == nil { + t.Fatalf("input %q: expected error, got nil", tt.input) + } + continue + } + if err != nil { + t.Fatalf("input %q: unexpected error: %v", tt.input, err) + } + if got != tt.expected { + t.Fatalf("input %q: expected %d, got %d", tt.input, tt.expected, got) + } + } +} + +func TestNewSandboxExecStepFactory_InvalidCommandType(t *testing.T) { + factory := NewSandboxExecStepFactory() + _, err := factory("s", map[string]any{ + "command": "should-be-a-list", + }, nil) + if err == nil { + t.Fatal("expected error for non-list command") + } +} diff --git a/module/pipeline_step_secret_rotate.go b/module/pipeline_step_secret_rotate.go new file mode 100644 index 00000000..473c89cd --- /dev/null +++ b/module/pipeline_step_secret_rotate.go @@ -0,0 +1,83 @@ +package module + +import ( + "context" + "fmt" + + "github.com/CrisisTextLine/modular" + "github.com/GoCodeAlone/workflow/secrets" +) + +// SecretRotateStep rotates a secret in a RotationProvider and returns the new value. +type SecretRotateStep struct { + name string + provider string // service name of the secrets RotationProvider module + key string // the secret key to rotate + notifyModule string // optional module name to notify of rotation + app modular.Application +} + +// NewSecretRotateStepFactory returns a StepFactory for step.secret_rotate. 
+func NewSecretRotateStepFactory() StepFactory { + return func(name string, config map[string]any, app modular.Application) (PipelineStep, error) { + provider, _ := config["provider"].(string) + if provider == "" { + return nil, fmt.Errorf("secret_rotate step %q: 'provider' is required", name) + } + + key, _ := config["key"].(string) + if key == "" { + return nil, fmt.Errorf("secret_rotate step %q: 'key' is required", name) + } + + notifyModule, _ := config["notify_module"].(string) + + return &SecretRotateStep{ + name: name, + provider: provider, + key: key, + notifyModule: notifyModule, + app: app, + }, nil + } +} + +// Name returns the step name. +func (s *SecretRotateStep) Name() string { return s.name } + +// Execute rotates the secret by calling RotationProvider.Rotate and returns output +// indicating the rotation was successful. +func (s *SecretRotateStep) Execute(ctx context.Context, _ *PipelineContext) (*StepResult, error) { + if s.app == nil { + return nil, fmt.Errorf("secret_rotate step %q: no application context", s.name) + } + + rp, err := s.resolveProvider() + if err != nil { + return nil, err + } + + if _, err := rp.Rotate(ctx, s.key); err != nil { + return nil, fmt.Errorf("secret_rotate step %q: rotate failed: %w", s.name, err) + } + + return &StepResult{Output: map[string]any{ + "rotated": true, + "key": s.key, + "provider": s.provider, + }}, nil +} + +// resolveProvider looks up the secrets provider from the service registry and +// asserts it implements secrets.RotationProvider. 
+func (s *SecretRotateStep) resolveProvider() (secrets.RotationProvider, error) { + svc, ok := s.app.SvcRegistry()[s.provider] + if !ok { + return nil, fmt.Errorf("secret_rotate step %q: provider service %q not found", s.name, s.provider) + } + rp, ok := svc.(secrets.RotationProvider) + if !ok { + return nil, fmt.Errorf("secret_rotate step %q: service %q does not implement secrets.RotationProvider", s.name, s.provider) + } + return rp, nil +} diff --git a/module/pipeline_step_secret_rotate_test.go b/module/pipeline_step_secret_rotate_test.go new file mode 100644 index 00000000..11b7c63d --- /dev/null +++ b/module/pipeline_step_secret_rotate_test.go @@ -0,0 +1,184 @@ +package module + +import ( + "context" + "errors" + "testing" + + "github.com/GoCodeAlone/workflow/secrets" +) + +// mockRotationProvider is a mock secrets.RotationProvider for testing. +type mockRotationProvider struct { + rotateVal string + rotateErr error + prevVal string + prevErr error +} + +func (m *mockRotationProvider) Name() string { return "mock" } +func (m *mockRotationProvider) Get(_ context.Context, _ string) (string, error) { + return "", nil +} +func (m *mockRotationProvider) Set(_ context.Context, _, _ string) error { return nil } +func (m *mockRotationProvider) Delete(_ context.Context, _ string) error { return nil } +func (m *mockRotationProvider) List(_ context.Context) ([]string, error) { return nil, nil } +func (m *mockRotationProvider) Rotate(_ context.Context, _ string) (string, error) { + return m.rotateVal, m.rotateErr +} +func (m *mockRotationProvider) GetPrevious(_ context.Context, _ string) (string, error) { + return m.prevVal, m.prevErr +} + +// Compile-time check that mockRotationProvider satisfies secrets.RotationProvider. 
+var _ secrets.RotationProvider = (*mockRotationProvider)(nil) + +// ---- factory validation tests ---- + +func TestSecretRotateStep_MissingProvider(t *testing.T) { + factory := NewSecretRotateStepFactory() + _, err := factory("rotate-step", map[string]any{ + "key": "myapp/db-pass", + }, nil) + if err == nil { + t.Fatal("expected error when 'provider' is missing") + } +} + +func TestSecretRotateStep_MissingKey(t *testing.T) { + factory := NewSecretRotateStepFactory() + _, err := factory("rotate-step", map[string]any{ + "provider": "vault", + }, nil) + if err == nil { + t.Fatal("expected error when 'key' is missing") + } +} + +func TestSecretRotateStep_ValidConfig(t *testing.T) { + factory := NewSecretRotateStepFactory() + step, err := factory("rotate-step", map[string]any{ + "provider": "vault", + "key": "myapp/db-pass", + "notify_module": "slack-notifier", + }, nil) + if err != nil { + t.Fatalf("unexpected factory error: %v", err) + } + if step.Name() != "rotate-step" { + t.Errorf("expected name 'rotate-step', got %q", step.Name()) + } +} + +// ---- Execute tests ---- + +func TestSecretRotateStep_Execute_Success(t *testing.T) { + mock := &mockRotationProvider{rotateVal: "new-secret-abc123"} + app := NewMockApplication() + app.Services["vault"] = mock + + factory := NewSecretRotateStepFactory() + step, err := factory("rotate-step", map[string]any{ + "provider": "vault", + "key": "myapp/db-pass", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("Execute error: %v", err) + } + + if result.Output["rotated"] != true { + t.Errorf("expected rotated=true, got %v", result.Output["rotated"]) + } + if result.Output["key"] != "myapp/db-pass" { + t.Errorf("expected key='myapp/db-pass', got %v", result.Output["key"]) + } + if result.Output["provider"] != "vault" { + t.Errorf("expected provider='vault', got %v", 
result.Output["provider"]) + } +} + +func TestSecretRotateStep_Execute_RotateError(t *testing.T) { + mock := &mockRotationProvider{rotateErr: errors.New("vault unavailable")} + app := NewMockApplication() + app.Services["vault"] = mock + + factory := NewSecretRotateStepFactory() + step, err := factory("rotate-step", map[string]any{ + "provider": "vault", + "key": "myapp/db-pass", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error from Rotate failure") + } +} + +func TestSecretRotateStep_Execute_ProviderNotFound(t *testing.T) { + app := NewMockApplication() + // No services registered. + + factory := NewSecretRotateStepFactory() + step, err := factory("rotate-step", map[string]any{ + "provider": "vault", + "key": "myapp/db-pass", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for missing provider service") + } +} + +func TestSecretRotateStep_Execute_ProviderWrongType(t *testing.T) { + app := NewMockApplication() + // Register something that doesn't implement RotationProvider. + app.Services["vault"] = "not-a-rotation-provider" + + factory := NewSecretRotateStepFactory() + step, err := factory("rotate-step", map[string]any{ + "provider": "vault", + "key": "myapp/db-pass", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error when service does not implement RotationProvider") + } +} + +func TestSecretRotateStep_Execute_NoApp(t *testing.T) { + factory := NewSecretRotateStepFactory() + // Pass nil app at factory time is allowed; error comes at Execute time. 
+ step, err := factory("rotate-step", map[string]any{ + "provider": "vault", + "key": "myapp/db-pass", + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error when app is nil") + } +} diff --git a/module/pipeline_step_token_revoke.go b/module/pipeline_step_token_revoke.go new file mode 100644 index 00000000..2ff7504c --- /dev/null +++ b/module/pipeline_step_token_revoke.go @@ -0,0 +1,95 @@ +package module + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/CrisisTextLine/modular" + "github.com/golang-jwt/jwt/v5" +) + +// TokenRevokeStep extracts a JWT from the pipeline context, reads its jti and +// exp claims without signature validation, and adds the JTI to the configured +// token blacklist module. +type TokenRevokeStep struct { + name string + blacklistModule string // service name of the TokenBlacklist module + tokenSource string // dot-path to the token in pipeline context + app modular.Application +} + +// NewTokenRevokeStepFactory returns a StepFactory for step.token_revoke. +func NewTokenRevokeStepFactory() StepFactory { + return func(name string, config map[string]any, app modular.Application) (PipelineStep, error) { + blacklistModule, _ := config["blacklist_module"].(string) + if blacklistModule == "" { + return nil, fmt.Errorf("token_revoke step %q: 'blacklist_module' is required", name) + } + + tokenSource, _ := config["token_source"].(string) + if tokenSource == "" { + return nil, fmt.Errorf("token_revoke step %q: 'token_source' is required", name) + } + + return &TokenRevokeStep{ + name: name, + blacklistModule: blacklistModule, + tokenSource: tokenSource, + app: app, + }, nil + } +} + +// Name returns the step name. +func (s *TokenRevokeStep) Name() string { return s.name } + +// Execute revokes the JWT by extracting its JTI and adding it to the blacklist. 
+func (s *TokenRevokeStep) Execute(_ context.Context, pc *PipelineContext) (*StepResult, error) { + // 1. Extract token string from pipeline context. + rawToken := resolveBodyFrom(s.tokenSource, pc) + tokenStr, _ := rawToken.(string) + if tokenStr == "" { + return &StepResult{Output: map[string]any{"revoked": false, "error": "missing token"}}, nil + } + + // 2. Strip "Bearer " prefix if present. + tokenStr = strings.TrimPrefix(tokenStr, "Bearer ") + + // 3. Parse claims without signature validation (token is being revoked, not authenticated). + parser := jwt.NewParser() + token, _, parseErr := parser.ParseUnverified(tokenStr, jwt.MapClaims{}) + if parseErr != nil { + return &StepResult{Output: map[string]any{"revoked": false, "error": "invalid token format"}}, parseErr + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return &StepResult{Output: map[string]any{"revoked": false, "error": "invalid claims"}}, fmt.Errorf("invalid JWT claims type") + } + + jti, _ := claims["jti"].(string) + if jti == "" { + return &StepResult{Output: map[string]any{"revoked": false, "error": "token has no jti claim"}}, nil + } + + // 4. Determine token expiry from exp claim. + var expiresAt time.Time + switch exp := claims["exp"].(type) { + case float64: + expiresAt = time.Unix(int64(exp), 0) + default: + expiresAt = time.Now().Add(24 * time.Hour) // safe fallback + } + + // 5. Resolve the blacklist module and add the JTI. 
+ var blacklist TokenBlacklist + if err := s.app.GetService(s.blacklistModule, &blacklist); err != nil { + return nil, fmt.Errorf("token_revoke step %q: blacklist module %q not found: %w", s.name, s.blacklistModule, err) + } + + blacklist.Add(jti, expiresAt) + + return &StepResult{Output: map[string]any{"revoked": true, "jti": jti}}, nil +} diff --git a/module/pipeline_step_token_revoke_test.go b/module/pipeline_step_token_revoke_test.go new file mode 100644 index 00000000..b3b79a84 --- /dev/null +++ b/module/pipeline_step_token_revoke_test.go @@ -0,0 +1,299 @@ +package module + +import ( + "context" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// makeTestJWT creates a signed JWT with the given claims using HS256. +func makeTestJWT(t *testing.T, secret string, claims jwt.MapClaims) string { + t.Helper() + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + s, err := tok.SignedString([]byte(secret)) + if err != nil { + t.Fatalf("makeTestJWT: %v", err) + } + return s +} + +// mockBlacklist is a simple in-memory TokenBlacklist for testing. 
+type mockBlacklist struct { + entries map[string]time.Time +} + +func newMockBlacklist() *mockBlacklist { + return &mockBlacklist{entries: make(map[string]time.Time)} +} + +func (m *mockBlacklist) Add(jti string, expiresAt time.Time) { + m.entries[jti] = expiresAt +} + +func (m *mockBlacklist) IsBlacklisted(jti string) bool { + exp, ok := m.entries[jti] + return ok && time.Now().Before(exp) +} + +func newTokenRevokeApp(blName string, bl TokenBlacklist) *MockApplication { + app := NewMockApplication() + app.Services[blName] = bl + return app +} + +func TestTokenRevokeStep_RevokesToken(t *testing.T) { + bl := newMockBlacklist() + app := newTokenRevokeApp("my-blacklist", bl) + + factory := NewTokenRevokeStepFactory() + step, err := factory("revoke", map[string]any{ + "blacklist_module": "my-blacklist", + "token_source": "steps.parse.authorization", + }, app) + if err != nil { + t.Fatalf("factory: %v", err) + } + + tokenStr := makeTestJWT(t, "aaaabbbbccccddddeeeeffffgggghhhh", jwt.MapClaims{ + "jti": "test-jti-123", + "sub": "user-1", + "exp": float64(time.Now().Add(time.Hour).Unix()), + }) + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("parse", map[string]any{ + "authorization": "Bearer " + tokenStr, + }) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Output["revoked"] != true { + t.Errorf("expected revoked=true, got %v", result.Output["revoked"]) + } + if result.Output["jti"] != "test-jti-123" { + t.Errorf("expected jti=test-jti-123, got %v", result.Output["jti"]) + } + if !bl.IsBlacklisted("test-jti-123") { + t.Error("expected test-jti-123 to be in blacklist after revoke") + } +} + +func TestTokenRevokeStep_NoBearerPrefix(t *testing.T) { + bl := newMockBlacklist() + app := newTokenRevokeApp("bl", bl) + + factory := NewTokenRevokeStepFactory() + step, err := factory("revoke", map[string]any{ + "blacklist_module": "bl", + "token_source": "steps.parse.token", + }, app) + if err != 
nil { + t.Fatalf("factory: %v", err) + } + + // Token without Bearer prefix. + tokenStr := makeTestJWT(t, "aaaabbbbccccddddeeeeffffgggghhhh", jwt.MapClaims{ + "jti": "bare-jti", + "exp": float64(time.Now().Add(time.Hour).Unix()), + }) + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("parse", map[string]any{"token": tokenStr}) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Output["revoked"] != true { + t.Errorf("expected revoked=true, got %v", result.Output["revoked"]) + } + if !bl.IsBlacklisted("bare-jti") { + t.Error("expected bare-jti to be blacklisted") + } +} + +func TestTokenRevokeStep_MissingToken(t *testing.T) { + factory := NewTokenRevokeStepFactory() + app := newTokenRevokeApp("bl", newMockBlacklist()) + step, err := factory("revoke", map[string]any{ + "blacklist_module": "bl", + "token_source": "steps.parse.token", + }, app) + if err != nil { + t.Fatalf("factory: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Output["revoked"] != false { + t.Errorf("expected revoked=false for missing token, got %v", result.Output["revoked"]) + } +} + +func TestTokenRevokeStep_NoJTIClaim(t *testing.T) { + bl := newMockBlacklist() + app := newTokenRevokeApp("bl", bl) + + factory := NewTokenRevokeStepFactory() + step, err := factory("revoke", map[string]any{ + "blacklist_module": "bl", + "token_source": "steps.parse.token", + }, app) + if err != nil { + t.Fatalf("factory: %v", err) + } + + // Token without jti claim. 
+ tokenStr := makeTestJWT(t, "aaaabbbbccccddddeeeeffffgggghhhh", jwt.MapClaims{ + "sub": "user-1", + "exp": float64(time.Now().Add(time.Hour).Unix()), + }) + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("parse", map[string]any{"token": tokenStr}) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Output["revoked"] != false { + t.Errorf("expected revoked=false for token without jti, got %v", result.Output["revoked"]) + } + if result.Output["error"] != "token has no jti claim" { + t.Errorf("expected 'token has no jti claim' error, got %v", result.Output["error"]) + } +} + +func TestTokenRevokeStep_FactoryMissingBlacklistModule(t *testing.T) { + factory := NewTokenRevokeStepFactory() + _, err := factory("revoke", map[string]any{"token_source": "token"}, nil) + if err == nil { + t.Fatal("expected error for missing blacklist_module") + } +} + +func TestTokenRevokeStep_FactoryMissingTokenSource(t *testing.T) { + factory := NewTokenRevokeStepFactory() + _, err := factory("revoke", map[string]any{"blacklist_module": "bl"}, nil) + if err == nil { + t.Fatal("expected error for missing token_source") + } +} + +func TestTokenRevokeStep_BlacklistModuleNotFound(t *testing.T) { + app := NewMockApplication() // no services registered + factory := NewTokenRevokeStepFactory() + step, err := factory("revoke", map[string]any{ + "blacklist_module": "missing-bl", + "token_source": "steps.parse.token", + }, app) + if err != nil { + t.Fatalf("factory: %v", err) + } + + tokenStr := makeTestJWT(t, "aaaabbbbccccddddeeeeffffgggghhhh", jwt.MapClaims{ + "jti": "some-jti", + "exp": float64(time.Now().Add(time.Hour).Unix()), + }) + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("parse", map[string]any{"token": tokenStr}) + + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error when blacklist module is not registered") + } +} + +// --- Integration test --- + +// 
TestBlacklistedTokenFailsAuth tests the full flow: issue token, revoke it, +// then verify that JWTAuthModule rejects it. +func TestBlacklistedTokenFailsAuth(t *testing.T) { + secret := "a-very-long-secret-key-for-testing-purposes-1234" + jwtMod := NewJWTAuthModule("auth", secret, time.Hour, "test") + + app := NewMockApplication() + if err := jwtMod.Init(app); err != nil { + t.Fatalf("JWTAuthModule.Init: %v", err) + } + + // Issue a token. + user := &User{ID: "1", Email: "test@example.com", Name: "Test"} + tokenStr, err := jwtMod.generateToken(user) + if err != nil { + t.Fatalf("generateToken: %v", err) + } + + // Token should be valid before revocation. + valid, claims, err := jwtMod.Authenticate(tokenStr) + if err != nil { + t.Fatalf("Authenticate: %v", err) + } + if !valid { + t.Fatal("expected token to be valid before revocation") + } + + jti, ok := claims["jti"].(string) + if !ok || jti == "" { + t.Fatal("expected jti claim in token") + } + + // Wire a blacklist and revoke the token. + bl := NewTokenBlacklistModule("bl", "memory", "", time.Minute) + jwtMod.SetTokenBlacklist(bl) + bl.Add(jti, time.Now().Add(time.Hour)) + + // Token should now be rejected. + valid, _, err = jwtMod.Authenticate(tokenStr) + if err != nil { + t.Fatalf("Authenticate after revocation: %v", err) + } + if valid { + t.Fatal("expected token to be rejected after revocation") + } +} + +// TestNonRevokedTokenStillValid ensures that revoking one token does not +// affect other valid tokens. 
+func TestNonRevokedTokenStillValid(t *testing.T) { + secret := "a-very-long-secret-key-for-testing-purposes-5678" + jwtMod := NewJWTAuthModule("auth", secret, time.Hour, "test") + app := NewMockApplication() + if err := jwtMod.Init(app); err != nil { + t.Fatalf("JWTAuthModule.Init: %v", err) + } + + bl := NewTokenBlacklistModule("bl", "memory", "", time.Minute) + jwtMod.SetTokenBlacklist(bl) + + user := &User{ID: "2", Email: "other@example.com", Name: "Other"} + tokenStr, err := jwtMod.generateToken(user) + if err != nil { + t.Fatalf("generateToken: %v", err) + } + + // Revoke a *different* JTI. + bl.Add("some-other-jti", time.Now().Add(time.Hour)) + + valid, _, err := jwtMod.Authenticate(tokenStr) + if err != nil { + t.Fatalf("Authenticate: %v", err) + } + if !valid { + t.Fatal("expected un-revoked token to remain valid") + } +} + +// Compile-time check: TokenBlacklistModule satisfies TokenBlacklist. +var _ TokenBlacklist = (*TokenBlacklistModule)(nil) + +// Compile-time check: mockBlacklist satisfies TokenBlacklist. 
+var _ TokenBlacklist = (*mockBlacklist)(nil) + diff --git a/module/tls_config_test.go b/module/tls_config_test.go new file mode 100644 index 00000000..899fc480 --- /dev/null +++ b/module/tls_config_test.go @@ -0,0 +1,152 @@ +package module + +import ( + "testing" + + "github.com/GoCodeAlone/workflow/pkg/tlsutil" +) + +func TestRedisCacheConfig_TLSField(t *testing.T) { + cfg := RedisCacheConfig{ + Address: "localhost:6380", + TLS: tlsutil.TLSConfig{ + Enabled: true, + SkipVerify: true, + }, + } + if !cfg.TLS.Enabled { + t.Error("expected TLS.Enabled to be true") + } +} + +func TestKafkaBroker_SetTLSConfig(t *testing.T) { + broker := NewKafkaBroker("test-kafka") + broker.SetTLSConfig(KafkaTLSConfig{ + TLSConfig: tlsutil.TLSConfig{ + Enabled: true, + SkipVerify: true, + }, + SASL: KafkaSASLConfig{ + Mechanism: "PLAIN", + Username: "user", + Password: "pass", + }, + }) + + broker.mu.RLock() + defer broker.mu.RUnlock() + if !broker.tlsCfg.Enabled { + t.Error("expected TLS enabled") + } + if broker.tlsCfg.SASL.Mechanism != "PLAIN" { + t.Errorf("expected PLAIN, got %s", broker.tlsCfg.SASL.Mechanism) + } +} + +func TestNATSBroker_SetTLSConfig(t *testing.T) { + broker := NewNATSBroker("test-nats") + broker.SetTLSConfig(tlsutil.TLSConfig{ + Enabled: true, + SkipVerify: true, + }) + + broker.mu.RLock() + defer broker.mu.RUnlock() + if !broker.tlsCfg.Enabled { + t.Error("expected TLS enabled") + } +} + +func TestHTTPServer_SetTLSConfig(t *testing.T) { + srv := NewStandardHTTPServer("test", ":8443") + srv.SetTLSConfig(HTTPServerTLSConfig{ + Mode: "manual", + Manual: tlsutil.TLSConfig{ + CertFile: "/tmp/cert.pem", + KeyFile: "/tmp/key.pem", + }, + }) + if srv.tlsCfg.Mode != "manual" { + t.Errorf("expected mode 'manual', got %q", srv.tlsCfg.Mode) + } +} + +func TestDatabaseConfig_TLSField(t *testing.T) { + cfg := DatabaseConfig{ + Driver: "postgres", + DSN: "postgres://localhost:5432/mydb", + TLS: DatabaseTLSConfig{ + Mode: "verify-full", + CAFile: "/etc/ssl/ca.pem", + }, + } + 
+ db := NewWorkflowDatabase("test-db", cfg) + dsn := db.buildDSN() + + if dsn == cfg.DSN { + t.Error("expected DSN to be modified with TLS parameters") + } + if !contains(dsn, "sslmode=verify-full") { + t.Errorf("expected sslmode in DSN, got %s", dsn) + } + if !contains(dsn, "sslrootcert=/etc/ssl/ca.pem") { + t.Errorf("expected sslrootcert in DSN, got %s", dsn) + } +} + +func TestDatabaseConfig_TLSDisabled(t *testing.T) { + cfg := DatabaseConfig{ + Driver: "postgres", + DSN: "postgres://localhost:5432/mydb", + TLS: DatabaseTLSConfig{Mode: "disable"}, + } + + db := NewWorkflowDatabase("test-db", cfg) + dsn := db.buildDSN() + if dsn != cfg.DSN { + t.Errorf("expected unchanged DSN when TLS disabled, got %s", dsn) + } +} + +func TestDatabaseConfig_TLSDefault(t *testing.T) { + cfg := DatabaseConfig{ + Driver: "postgres", + DSN: "postgres://localhost:5432/mydb", + // TLS field zero value: Mode="" + } + + db := NewWorkflowDatabase("test-db", cfg) + dsn := db.buildDSN() + if dsn != cfg.DSN { + t.Errorf("expected unchanged DSN when TLS not configured, got %s", dsn) + } +} + +func TestDatabaseConfig_TLS_ExistingQueryString(t *testing.T) { + cfg := DatabaseConfig{ + Driver: "postgres", + DSN: "postgres://localhost:5432/mydb?connect_timeout=10", + TLS: DatabaseTLSConfig{Mode: "require"}, + } + + db := NewWorkflowDatabase("test-db", cfg) + dsn := db.buildDSN() + if !contains(dsn, "&sslmode=require") { + t.Errorf("expected & separator when query string exists, got %s", dsn) + } +} + +// contains is a simple substring check helper. 
+func contains(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(s) > 0 && searchSubstring(s, sub)) +} + +func searchSubstring(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/pkg/fieldcrypt/encrypt.go b/pkg/fieldcrypt/encrypt.go new file mode 100644 index 00000000..a7bb6e9d --- /dev/null +++ b/pkg/fieldcrypt/encrypt.go @@ -0,0 +1,118 @@ +package fieldcrypt + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + "io" + "strconv" + "strings" +) + +// Prefix is the marker for encrypted protected field values. +const Prefix = "epf:" + +// legacyPrefix is the old encryption prefix from the original FieldEncryptor. +const legacyPrefix = "enc::" + +// Encrypt encrypts plaintext with AES-256-GCM, returning "epf:v{version}:{base64(nonce+ciphertext)}". +func Encrypt(plaintext string, key []byte, version int) (string, error) { + block, err := aes.NewCipher(key) + if err != nil { + return "", fmt.Errorf("fieldcrypt: create cipher: %w", err) + } + aead, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("fieldcrypt: create GCM: %w", err) + } + nonce := make([]byte, aead.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", fmt.Errorf("fieldcrypt: generate nonce: %w", err) + } + ciphertext := aead.Seal(nonce, nonce, []byte(plaintext), nil) + encoded := base64.StdEncoding.EncodeToString(ciphertext) + return fmt.Sprintf("%sv%d:%s", Prefix, version, encoded), nil +} + +// Decrypt decrypts an epf:-prefixed value. It also handles legacy "enc::" prefix values. +// keyFn is called with the version number extracted from the prefix. +// For legacy enc:: values, keyFn(0) is called to obtain the raw master key, +// which is then SHA256-hashed to match the original FieldEncryptor behavior. 
+func Decrypt(ciphertext string, keyFn func(version int) ([]byte, error)) (string, error) { + if strings.HasPrefix(ciphertext, legacyPrefix) { + return decryptLegacy(ciphertext, keyFn) + } + if !strings.HasPrefix(ciphertext, Prefix) { + return "", fmt.Errorf("fieldcrypt: not an encrypted value") + } + + // Parse "epf:v{version}:{base64}" + rest := strings.TrimPrefix(ciphertext, Prefix) + idx := strings.Index(rest, ":") + if idx < 0 || !strings.HasPrefix(rest, "v") { + return "", fmt.Errorf("fieldcrypt: invalid format") + } + version, err := strconv.Atoi(rest[1:idx]) + if err != nil { + return "", fmt.Errorf("fieldcrypt: invalid version: %w", err) + } + encoded := rest[idx+1:] + + key, err := keyFn(version) + if err != nil { + return "", fmt.Errorf("fieldcrypt: key lookup: %w", err) + } + + raw, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return "", fmt.Errorf("fieldcrypt: base64 decode: %w", err) + } + + return decryptAESGCM(raw, key) +} + +// IsEncrypted checks if a value has the epf: prefix or legacy enc:: prefix. +func IsEncrypted(value string) bool { + return strings.HasPrefix(value, Prefix) || strings.HasPrefix(value, legacyPrefix) +} + +// decryptLegacy handles the old "enc::" prefix format. +func decryptLegacy(ciphertext string, keyFn func(version int) ([]byte, error)) (string, error) { + raw := strings.TrimPrefix(ciphertext, legacyPrefix) + decoded, err := base64.StdEncoding.DecodeString(raw) + if err != nil { + return "", fmt.Errorf("fieldcrypt: legacy base64 decode: %w", err) + } + // keyFn(0) returns the raw master key; hash it with SHA256 to match original behavior. + masterKey, err := keyFn(0) + if err != nil { + return "", fmt.Errorf("fieldcrypt: legacy key lookup: %w", err) + } + hash := sha256.Sum256(masterKey) + return decryptAESGCM(decoded, hash[:]) +} + +// decryptAESGCM decrypts raw bytes (nonce + ciphertext) with the given AES-256-GCM key. 
+func decryptAESGCM(raw, key []byte) (string, error) { + block, err := aes.NewCipher(key) + if err != nil { + return "", fmt.Errorf("fieldcrypt: create cipher: %w", err) + } + aead, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("fieldcrypt: create GCM: %w", err) + } + nonceSize := aead.NonceSize() + if len(raw) < nonceSize { + return "", fmt.Errorf("fieldcrypt: ciphertext too short") + } + nonce, ct := raw[:nonceSize], raw[nonceSize:] + plaintext, err := aead.Open(nil, nonce, ct, nil) + if err != nil { + return "", fmt.Errorf("fieldcrypt: decryption failed: %w", err) + } + return string(plaintext), nil +} diff --git a/pkg/fieldcrypt/fieldcrypt.go b/pkg/fieldcrypt/fieldcrypt.go new file mode 100644 index 00000000..9e47ef9f --- /dev/null +++ b/pkg/fieldcrypt/fieldcrypt.go @@ -0,0 +1,71 @@ +package fieldcrypt + +// FieldClassification defines the sensitivity level. +type FieldClassification string + +const ( + ClassPII FieldClassification = "pii" + ClassPHI FieldClassification = "phi" +) + +// LogBehavior defines how a field appears in logs. +type LogBehavior string + +const ( + LogMask LogBehavior = "mask" + LogRedact LogBehavior = "redact" + LogHash LogBehavior = "hash" + LogAllow LogBehavior = "allow" +) + +// ProtectedField defines a field that requires encryption/masking. +type ProtectedField struct { + Name string `yaml:"name"` + Classification FieldClassification `yaml:"classification"` + Encryption bool `yaml:"encryption"` + LogBehavior LogBehavior `yaml:"log_behavior"` + MaskPattern string `yaml:"mask_pattern"` +} + +// Registry holds the set of protected fields for lookup. +type Registry struct { + fields map[string]ProtectedField +} + +// NewRegistry creates a Registry from a slice of ProtectedField definitions. 
+func NewRegistry(fields []ProtectedField) *Registry { + m := make(map[string]ProtectedField, len(fields)) + for _, f := range fields { + m[f.Name] = f + } + return &Registry{fields: m} +} + +// IsProtected returns true if the given field name is in the registry. +func (r *Registry) IsProtected(fieldName string) bool { + _, ok := r.fields[fieldName] + return ok +} + +// GetField returns the ProtectedField definition for the given name. +func (r *Registry) GetField(fieldName string) (*ProtectedField, bool) { + f, ok := r.fields[fieldName] + if !ok { + return nil, false + } + return &f, true +} + +// Len returns the number of registered protected fields. +func (r *Registry) Len() int { + return len(r.fields) +} + +// ProtectedFields returns all registered protected fields. +func (r *Registry) ProtectedFields() []ProtectedField { + out := make([]ProtectedField, 0, len(r.fields)) + for _, f := range r.fields { + out = append(out, f) + } + return out +} diff --git a/pkg/fieldcrypt/fieldcrypt_test.go b/pkg/fieldcrypt/fieldcrypt_test.go new file mode 100644 index 00000000..d40152ba --- /dev/null +++ b/pkg/fieldcrypt/fieldcrypt_test.go @@ -0,0 +1,377 @@ +package fieldcrypt + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "io" + "strings" + "testing" +) + +func TestEncryptDecryptRoundTrip(t *testing.T) { + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + plaintext := "sensitive data here" + encrypted, err := Encrypt(plaintext, key, 1) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + + if !strings.HasPrefix(encrypted, "epf:v1:") { + t.Fatalf("expected epf:v1: prefix, got %q", encrypted) + } + + decrypted, err := Decrypt(encrypted, func(version int) ([]byte, error) { + if version != 1 { + t.Fatalf("expected version 1, got %d", version) + } + return key, nil + }) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if decrypted != plaintext { + t.Fatalf("expected 
%q, got %q", plaintext, decrypted) + } +} + +func TestEncryptDecryptVersion(t *testing.T) { + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + encrypted, err := Encrypt("hello", key, 42) + if err != nil { + t.Fatal(err) + } + if !strings.HasPrefix(encrypted, "epf:v42:") { + t.Fatalf("expected epf:v42: prefix, got %q", encrypted) + } + + decrypted, err := Decrypt(encrypted, func(version int) ([]byte, error) { + if version != 42 { + t.Fatalf("expected version 42, got %d", version) + } + return key, nil + }) + if err != nil { + t.Fatal(err) + } + if decrypted != "hello" { + t.Fatalf("expected hello, got %q", decrypted) + } +} + +func TestIsEncrypted(t *testing.T) { + tests := []struct { + value string + want bool + }{ + {"epf:v1:abc", true}, + {"enc::abc", true}, + {"plaintext", false}, + {"", false}, + } + for _, tt := range tests { + if got := IsEncrypted(tt.value); got != tt.want { + t.Errorf("IsEncrypted(%q) = %v, want %v", tt.value, got, tt.want) + } + } +} + +func TestLegacyEncDecrypt(t *testing.T) { + // Simulate the old FieldEncryptor: SHA256(masterKey) -> AES-256-GCM. 
+ masterKey := []byte("my-secret-key") + hash := sha256.Sum256(masterKey) + + block, err := aes.NewCipher(hash[:]) + if err != nil { + t.Fatal(err) + } + aead, err := cipher.NewGCM(block) + if err != nil { + t.Fatal(err) + } + nonce := make([]byte, aead.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + t.Fatal(err) + } + ct := aead.Seal(nonce, nonce, []byte("legacy secret"), nil) + legacyEncoded := "enc::" + base64.StdEncoding.EncodeToString(ct) + + decrypted, err := Decrypt(legacyEncoded, func(version int) ([]byte, error) { + if version != 0 { + t.Fatalf("expected version 0 for legacy, got %d", version) + } + return masterKey, nil + }) + if err != nil { + t.Fatalf("Decrypt legacy: %v", err) + } + if decrypted != "legacy secret" { + t.Fatalf("expected 'legacy secret', got %q", decrypted) + } +} + +func TestMaskEmail(t *testing.T) { + got := MaskEmail("john@example.com") + if got != "j***@e***.com" { + t.Errorf("MaskEmail = %q, want %q", got, "j***@e***.com") + } +} + +func TestMaskPhone(t *testing.T) { + got := MaskPhone("555-123-4567") + if got != "***-***-4567" { + t.Errorf("MaskPhone = %q, want %q", got, "***-***-4567") + } +} + +func TestHashValue(t *testing.T) { + h := HashValue("test") + if len(h) != 64 { + t.Errorf("expected 64 char hex, got %d chars", len(h)) + } +} + +func TestRedactValue(t *testing.T) { + if got := RedactValue(); got != "[REDACTED]" { + t.Errorf("RedactValue = %q", got) + } +} + +func TestMaskValueBehaviors(t *testing.T) { + if MaskValue("secret", LogRedact, "") != "[REDACTED]" { + t.Error("LogRedact failed") + } + if MaskValue("secret", LogAllow, "") != "secret" { + t.Error("LogAllow failed") + } + h := MaskValue("secret", LogHash, "") + if len(h) != 64 { + t.Error("LogHash failed") + } +} + +func TestScanAndEncryptDecrypt(t *testing.T) { + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + registry := NewRegistry([]ProtectedField{ + {Name: "ssn", Classification: ClassPII, 
Encryption: true, LogBehavior: LogRedact}, + {Name: "email", Classification: ClassPII, Encryption: true, LogBehavior: LogMask}, + }) + + data := map[string]any{ + "ssn": "123-45-6789", + "email": "test@example.com", + "name": "John", + "nested": map[string]any{ + "ssn": "987-65-4321", + }, + "items": []any{ + map[string]any{"email": "a@b.com"}, + }, + } + + err := ScanAndEncrypt(data, registry, func() ([]byte, int, error) { + return key, 1, nil + }, 10) + if err != nil { + t.Fatalf("ScanAndEncrypt: %v", err) + } + + // Verify fields are encrypted. + if !IsEncrypted(data["ssn"].(string)) { + t.Error("ssn should be encrypted") + } + if !IsEncrypted(data["email"].(string)) { + t.Error("email should be encrypted") + } + if data["name"] != "John" { + t.Error("name should not be modified") + } + nested := data["nested"].(map[string]any) + if !IsEncrypted(nested["ssn"].(string)) { + t.Error("nested ssn should be encrypted") + } + items := data["items"].([]any) + item := items[0].(map[string]any) + if !IsEncrypted(item["email"].(string)) { + t.Error("array item email should be encrypted") + } + + // Now decrypt. 
+ err = ScanAndDecrypt(data, registry, func(version int) ([]byte, error) { + return key, nil + }, 10) + if err != nil { + t.Fatalf("ScanAndDecrypt: %v", err) + } + + if data["ssn"] != "123-45-6789" { + t.Errorf("ssn = %q, want 123-45-6789", data["ssn"]) + } + if data["email"] != "test@example.com" { + t.Errorf("email = %q, want test@example.com", data["email"]) + } + if nested["ssn"] != "987-65-4321" { + t.Errorf("nested ssn = %q", nested["ssn"]) + } + item = items[0].(map[string]any) + if item["email"] != "a@b.com" { + t.Errorf("array item email = %q", item["email"]) + } +} + +func TestScanAndMask(t *testing.T) { + registry := NewRegistry([]ProtectedField{ + {Name: "ssn", Classification: ClassPII, LogBehavior: LogRedact}, + {Name: "email", Classification: ClassPII, LogBehavior: LogMask}, + }) + + data := map[string]any{ + "ssn": "123-45-6789", + "email": "test@example.com", + "name": "John", + } + + masked := ScanAndMask(data, registry, 10) + + if masked["ssn"] != "[REDACTED]" { + t.Errorf("ssn mask = %q", masked["ssn"]) + } + if masked["email"] == "test@example.com" { + t.Error("email should be masked") + } + if masked["name"] != "John" { + t.Error("name should be unchanged") + } + + // Original should be unmodified. + if data["ssn"] != "123-45-6789" { + t.Error("original ssn was modified") + } +} + +func TestKeyRingTenantIsolation(t *testing.T) { + masterKey := make([]byte, 32) + if _, err := rand.Read(masterKey); err != nil { + t.Fatal(err) + } + + kr := NewLocalKeyRing(masterKey) + ctx := context.Background() + + keyA, verA, err := kr.CurrentKey(ctx, "tenant-a") + if err != nil { + t.Fatal(err) + } + if verA != 1 { + t.Fatalf("expected version 1, got %d", verA) + } + + keyB, verB, err := kr.CurrentKey(ctx, "tenant-b") + if err != nil { + t.Fatal(err) + } + if verB != 1 { + t.Fatalf("expected version 1, got %d", verB) + } + + // Keys should be different for different tenants. 
+ if string(keyA) == string(keyB) { + t.Error("tenant keys should differ") + } +} + +func TestKeyRingRotation(t *testing.T) { + masterKey := make([]byte, 32) + if _, err := rand.Read(masterKey); err != nil { + t.Fatal(err) + } + + kr := NewLocalKeyRing(masterKey) + ctx := context.Background() + + key1, ver1, err := kr.CurrentKey(ctx, "t1") + if err != nil { + t.Fatal(err) + } + + key2, ver2, err := kr.Rotate(ctx, "t1") + if err != nil { + t.Fatal(err) + } + if ver2 != ver1+1 { + t.Fatalf("expected version %d, got %d", ver1+1, ver2) + } + if string(key1) == string(key2) { + t.Error("rotated key should differ") + } + + // Old key should still be retrievable. + oldKey, err := kr.KeyByVersion(ctx, "t1", ver1) + if err != nil { + t.Fatal(err) + } + if string(oldKey) != string(key1) { + t.Error("old key mismatch") + } + + // Current should return new key. + curKey, curVer, err := kr.CurrentKey(ctx, "t1") + if err != nil { + t.Fatal(err) + } + if curVer != ver2 { + t.Fatalf("current version = %d, want %d", curVer, ver2) + } + if string(curKey) != string(key2) { + t.Error("current key mismatch") + } +} + +func TestKeyRingVersionLookup(t *testing.T) { + masterKey := []byte("deterministic-master-key-for-test") + kr := NewLocalKeyRing(masterKey) + ctx := context.Background() + + // Create v1 by calling CurrentKey. + _, _, err := kr.CurrentKey(ctx, "t") + if err != nil { + t.Fatal(err) + } + + // Rotate to v2. + _, _, err = kr.Rotate(ctx, "t") + if err != nil { + t.Fatal(err) + } + + // Lookup v1 by version. + k1, err := kr.KeyByVersion(ctx, "t", 1) + if err != nil { + t.Fatal(err) + } + + // Lookup v2 by version. 
+ k2, err := kr.KeyByVersion(ctx, "t", 2) + if err != nil { + t.Fatal(err) + } + + if string(k1) == string(k2) { + t.Error("v1 and v2 keys should differ") + } +} diff --git a/pkg/fieldcrypt/keyring.go b/pkg/fieldcrypt/keyring.go new file mode 100644 index 00000000..6a69f193 --- /dev/null +++ b/pkg/fieldcrypt/keyring.go @@ -0,0 +1,122 @@ +package fieldcrypt + +import ( + "context" + "crypto/sha256" + "fmt" + "io" + "sync" + + "golang.org/x/crypto/hkdf" +) + +// KeyRing manages versioned, tenant-isolated encryption keys. +type KeyRing interface { + CurrentKey(ctx context.Context, tenantID string) (key []byte, version int, err error) + KeyByVersion(ctx context.Context, tenantID string, version int) ([]byte, error) + Rotate(ctx context.Context, tenantID string) (key []byte, version int, err error) +} + +// LocalKeyRing stores keys in memory, keyed by tenant. +// Keys are derived from a master key using HKDF. +type LocalKeyRing struct { + masterKey []byte + mu sync.RWMutex + tenantVersions map[string]int // tenantID -> current version number + tenantKeys map[string][]byte // "tenantID:version" -> derived key +} + +// NewLocalKeyRing creates a new LocalKeyRing from a master key. +func NewLocalKeyRing(masterKey []byte) *LocalKeyRing { + return &LocalKeyRing{ + masterKey: masterKey, + tenantVersions: make(map[string]int), + tenantKeys: make(map[string][]byte), + } +} + +// CurrentKey returns the current key version for a tenant. +// If no key exists yet, generates version 1. +func (k *LocalKeyRing) CurrentKey(_ context.Context, tenantID string) ([]byte, int, error) { + k.mu.RLock() + ver, ok := k.tenantVersions[tenantID] + if ok { + cacheKey := fmt.Sprintf("%s:%d", tenantID, ver) + key := k.tenantKeys[cacheKey] + k.mu.RUnlock() + return key, ver, nil + } + k.mu.RUnlock() + + // No key yet; create version 1. + k.mu.Lock() + defer k.mu.Unlock() + + // Double-check after acquiring write lock. 
+ if ver, ok := k.tenantVersions[tenantID]; ok { + cacheKey := fmt.Sprintf("%s:%d", tenantID, ver) + return k.tenantKeys[cacheKey], ver, nil + } + + key, err := k.deriveKey(tenantID, 1) + if err != nil { + return nil, 0, err + } + k.tenantVersions[tenantID] = 1 + k.tenantKeys[fmt.Sprintf("%s:%d", tenantID, 1)] = key + return key, 1, nil +} + +// KeyByVersion returns the key for a specific tenant+version. +func (k *LocalKeyRing) KeyByVersion(_ context.Context, tenantID string, version int) ([]byte, error) { + cacheKey := fmt.Sprintf("%s:%d", tenantID, version) + + k.mu.RLock() + if key, ok := k.tenantKeys[cacheKey]; ok { + k.mu.RUnlock() + return key, nil + } + k.mu.RUnlock() + + k.mu.Lock() + defer k.mu.Unlock() + + // Double-check. + if key, ok := k.tenantKeys[cacheKey]; ok { + return key, nil + } + + key, err := k.deriveKey(tenantID, version) + if err != nil { + return nil, err + } + k.tenantKeys[cacheKey] = key + return key, nil +} + +// Rotate increments the version and derives a new key for the tenant. +func (k *LocalKeyRing) Rotate(_ context.Context, tenantID string) ([]byte, int, error) { + k.mu.Lock() + defer k.mu.Unlock() + + ver := k.tenantVersions[tenantID] + 1 + key, err := k.deriveKey(tenantID, ver) + if err != nil { + return nil, 0, err + } + k.tenantVersions[tenantID] = ver + k.tenantKeys[fmt.Sprintf("%s:%d", tenantID, ver)] = key + return key, ver, nil +} + +// deriveKey uses HKDF-SHA256 to derive a 32-byte key. +// Info = "fieldcrypt:" + tenantID + ":v" + version. 
func (k *LocalKeyRing) deriveKey(tenantID string, version int) ([]byte, error) {
	// The info string binds the derived key to both tenant and version, so
	// different tenants/versions can never yield the same key material.
	info := fmt.Sprintf("fieldcrypt:%s:v%d", tenantID, version)
	hkdfReader := hkdf.New(sha256.New, k.masterKey, nil, []byte(info))
	derived := make([]byte, 32)
	if _, err := io.ReadFull(hkdfReader, derived); err != nil {
		return nil, fmt.Errorf("fieldcrypt: key derivation failed: %w", err)
	}
	return derived, nil
}
diff --git a/pkg/fieldcrypt/mask.go b/pkg/fieldcrypt/mask.go new file mode 100644 index 00000000..ca4aa5ca --- /dev/null +++ b/pkg/fieldcrypt/mask.go @@ -0,0 +1,122 @@
package fieldcrypt

import (
	"crypto/sha256"
	"fmt"
	"strings"
)

// MaskValue applies masking based on LogBehavior and optional pattern.
// Unknown behaviors fall through to full redaction (fail closed).
func MaskValue(value string, behavior LogBehavior, pattern string) string {
	switch behavior {
	case LogRedact:
		return RedactValue()
	case LogHash:
		return HashValue(value)
	case LogAllow:
		return value
	case LogMask:
		if pattern != "" {
			return applyPattern(value, pattern)
		}
		// Auto-detect: try email first, then phone-like, then generic mask.
		if strings.Contains(value, "@") {
			return MaskEmail(value)
		}
		if looksLikePhone(value) {
			return MaskPhone(value)
		}
		return genericMask(value)
	default:
		return RedactValue()
	}
}

// MaskEmail masks an email: "j***@e***.com".
// The first character of the local part and of the domain are kept, and the
// final dot-suffix (TLD) is preserved; input without "@" gets the generic mask.
func MaskEmail(email string) string {
	at := strings.LastIndex(email, "@")
	if at < 0 {
		return genericMask(email)
	}
	local := email[:at]
	domain := email[at+1:]

	maskedLocal := maskPart(local)
	// Mask domain but keep TLD.
	dot := strings.LastIndex(domain, ".")
	if dot > 0 {
		maskedDomain := maskPart(domain[:dot]) + domain[dot:]
		return maskedLocal + "@" + maskedDomain
	}
	return maskedLocal + "@" + maskPart(domain)
}

// MaskPhone masks all but last 4 digits: "***-***-1234".
func MaskPhone(phone string) string {
	digits := extractDigits(phone)
	if len(digits) <= 4 {
		// Too few digits to safely reveal a suffix.
		return "****"
	}
	last4 := string(digits[len(digits)-4:])
	return "***-***-" + last4
}

// HashValue returns SHA256 hex of the value.
func HashValue(value string) string {
	h := sha256.Sum256([]byte(value))
	return fmt.Sprintf("%x", h)
}

// RedactValue returns "[REDACTED]".
func RedactValue() string {
	return "[REDACTED]"
}

// maskPart keeps first char and replaces rest with ***.
// Deliberately reveals one leading character for log readability.
func maskPart(s string) string {
	if len(s) == 0 {
		return ""
	}
	return string(s[0]) + "***"
}

// genericMask is the fallback mask: first char + "***"; empty or single-char
// input collapses to "***" so the whole value is never revealed.
func genericMask(s string) string {
	if len(s) <= 1 {
		return "***"
	}
	return string(s[0]) + "***"
}

// looksLikePhone reports whether s carries a plausible phone-number digit
// count (7 to 15 digits).
func looksLikePhone(s string) bool {
	digits := extractDigits(s)
	return len(digits) >= 7 && len(digits) <= 15
}

// extractDigits returns the ASCII digits of s, in order.
func extractDigits(s string) []byte {
	var digits []byte
	for i := 0; i < len(s); i++ {
		if s[i] >= '0' && s[i] <= '9' {
			digits = append(digits, s[i])
		}
	}
	return digits
}

// applyPattern substitutes the value's digits into '#' placeholders in
// pattern; literal pattern characters are copied through, and '#' is left
// in place once the value's digits run out.
func applyPattern(value, pattern string) string {
	digits := extractDigits(value)
	di := 0
	var result []byte
	for i := 0; i < len(pattern); i++ {
		if pattern[i] == '#' {
			if di < len(digits) {
				result = append(result, digits[di])
				di++
			} else {
				result = append(result, '#')
			}
		} else {
			result = append(result, pattern[i])
		}
	}
	return string(result)
}
diff --git a/pkg/fieldcrypt/scanner.go b/pkg/fieldcrypt/scanner.go new file mode 100644 index 00000000..7c8413e4 --- /dev/null +++ b/pkg/fieldcrypt/scanner.go @@ -0,0 +1,117 @@
package fieldcrypt

// ScanAndEncrypt recursively scans a map, encrypting protected fields that have Encryption=true.
// maxDepth limits recursion depth.
func ScanAndEncrypt(data map[string]any, registry *Registry, keyFn func() ([]byte, int, error), maxDepth int) error {
	return scanEncrypt(data, registry, keyFn, 0, maxDepth)
}

// scanEncrypt mutates data in place. keyFn supplies the current key and
// version for each field that needs encrypting.
func scanEncrypt(data map[string]any, registry *Registry, keyFn func() ([]byte, int, error), depth, maxDepth int) error {
	// NOTE(review): values nested deeper than maxDepth are silently left
	// unencrypted — callers must size maxDepth to cover their payloads.
	if depth >= maxDepth {
		return nil
	}
	for k, v := range data {
		switch val := v.(type) {
		case string:
			// Only encrypt registered fields, and never double-encrypt.
			if pf, ok := registry.GetField(k); ok && pf.Encryption && !IsEncrypted(val) {
				key, version, err := keyFn()
				if err != nil {
					return err
				}
				encrypted, err := Encrypt(val, key, version)
				if err != nil {
					return err
				}
				data[k] = encrypted
			}
		case map[string]any:
			if err := scanEncrypt(val, registry, keyFn, depth+1, maxDepth); err != nil {
				return err
			}
		case []any:
			// Only map elements are recursed into; scalar slice elements
			// (e.g. a []any of strings) are not scanned.
			for _, elem := range val {
				if m, ok := elem.(map[string]any); ok {
					if err := scanEncrypt(m, registry, keyFn, depth+1, maxDepth); err != nil {
						return err
					}
				}
			}
		}
	}
	return nil
}

// ScanAndDecrypt recursively scans a map, decrypting epf:-prefixed (and enc::-prefixed) protected fields.
func ScanAndDecrypt(data map[string]any, registry *Registry, keyFn func(version int) ([]byte, error), maxDepth int) error {
	return scanDecrypt(data, registry, keyFn, 0, maxDepth)
}

// scanDecrypt mutates data in place, mirroring scanEncrypt's traversal rules.
// keyFn resolves the key for the version recorded inside each ciphertext.
func scanDecrypt(data map[string]any, registry *Registry, keyFn func(version int) ([]byte, error), depth, maxDepth int) error {
	// Values nested deeper than maxDepth are left encrypted.
	if depth >= maxDepth {
		return nil
	}
	for k, v := range data {
		switch val := v.(type) {
		case string:
			if registry.IsProtected(k) && IsEncrypted(val) {
				decrypted, err := Decrypt(val, keyFn)
				if err != nil {
					return err
				}
				data[k] = decrypted
			}
		case map[string]any:
			if err := scanDecrypt(val, registry, keyFn, depth+1, maxDepth); err != nil {
				return err
			}
		case []any:
			// Only map elements are recursed into, matching scanEncrypt.
			for _, elem := range val {
				if m, ok := elem.(map[string]any); ok {
					if err := scanDecrypt(m, registry, keyFn, depth+1, maxDepth); err != nil {
						return err
					}
				}
			}
		}
	}
	return nil
}

// ScanAndMask returns a deep copy of data with protected fields masked (for logging).
// Does NOT modify the original map.
func ScanAndMask(data map[string]any, registry *Registry, maxDepth int) map[string]any {
	return scanMask(data, registry, 0, maxDepth)
}

// scanMask builds the masked copy.
// NOTE(review): beyond maxDepth nested maps (and map elements in slices) are
// assigned by reference, so those portions of the returned map alias the
// original rather than being copies.
func scanMask(data map[string]any, registry *Registry, depth, maxDepth int) map[string]any {
	result := make(map[string]any, len(data))
	for k, v := range data {
		switch val := v.(type) {
		case string:
			// Registered fields are masked per their configured behavior;
			// everything else passes through unchanged.
			if pf, ok := registry.GetField(k); ok {
				result[k] = MaskValue(val, pf.LogBehavior, pf.MaskPattern)
			} else {
				result[k] = val
			}
		case map[string]any:
			if depth < maxDepth {
				result[k] = scanMask(val, registry, depth+1, maxDepth)
			} else {
				result[k] = val
			}
		case []any:
			masked := make([]any, len(val))
			for i, elem := range val {
				if m, ok := elem.(map[string]any); ok && depth < maxDepth {
					masked[i] = scanMask(m, registry, depth+1, maxDepth)
				} else {
					masked[i] = elem
				}
			}
			result[k] = masked
		default:
			result[k] = v
		}
	}
	return result
}
diff --git a/pkg/tlsutil/tlsutil.go b/pkg/tlsutil/tlsutil.go new file mode 100644 index 00000000..9c49475b --- /dev/null +++ b/pkg/tlsutil/tlsutil.go @@ -0,0 +1,74 @@
package tlsutil

import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"os"
)

// TLSConfig is the common TLS configuration used across all transports.
type TLSConfig struct {
	Enabled bool `yaml:"enabled" json:"enabled"`
	CertFile string `yaml:"cert_file" json:"cert_file"`
	KeyFile string `yaml:"key_file" json:"key_file"`
	CAFile string `yaml:"ca_file" json:"ca_file"`
	ClientAuth string `yaml:"client_auth" json:"client_auth"` // require | request | none
	SkipVerify bool `yaml:"skip_verify" json:"skip_verify"` // for dev only
}

// AutocertConfig holds Let's Encrypt autocert configuration.
type AutocertConfig struct {
	Domains []string `yaml:"domains" json:"domains"`
	CacheDir string `yaml:"cache_dir" json:"cache_dir"`
	Email string `yaml:"email" json:"email"`
}

// LoadTLSConfig builds a *tls.Config from the YAML-friendly struct.
// It returns (nil, nil) when cfg.Enabled is false; callers must treat a nil
// config as "TLS disabled".
func LoadTLSConfig(cfg TLSConfig) (*tls.Config, error) {
	if !cfg.Enabled {
		return nil, nil
	}

	tlsCfg := &tls.Config{
		MinVersion: tls.VersionTLS12,
		InsecureSkipVerify: cfg.SkipVerify, //nolint:gosec // G402: intentional dev-only option
	}

	// Load server/client certificate keypair if provided
	if cfg.CertFile != "" && cfg.KeyFile != "" {
		cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
		if err != nil {
			return nil, fmt.Errorf("tlsutil: load key pair: %w", err)
		}
		tlsCfg.Certificates = []tls.Certificate{cert}
	}

	// Load CA certificate for client verification or custom root CA.
	// The same pool is installed as both RootCAs (peer verification when
	// acting as a client) and ClientCAs (client-cert verification when acting
	// as a server), since this one config is shared across transports.
	if cfg.CAFile != "" {
		caPEM, err := os.ReadFile(cfg.CAFile)
		if err != nil {
			return nil, fmt.Errorf("tlsutil: read CA file: %w", err)
		}
		pool := x509.NewCertPool()
		if !pool.AppendCertsFromPEM(caPEM) {
			return nil, fmt.Errorf("tlsutil: no valid certificates found in %s", cfg.CAFile)
		}
		tlsCfg.RootCAs = pool
		tlsCfg.ClientCAs = pool
	}

	// Configure client authentication policy
	switch cfg.ClientAuth {
	case "require":
		tlsCfg.ClientAuth = tls.RequireAndVerifyClientCert
	case "request":
		tlsCfg.ClientAuth = tls.RequestClientCert
	case "none", "":
		tlsCfg.ClientAuth = tls.NoClientCert
	default:
		return nil, fmt.Errorf("tlsutil: unknown client_auth %q (valid: require, request, none)", cfg.ClientAuth)
	}

	return tlsCfg, nil
}
diff --git a/pkg/tlsutil/tlsutil_test.go b/pkg/tlsutil/tlsutil_test.go new file mode 100644 index 00000000..b29aedda --- /dev/null +++ b/pkg/tlsutil/tlsutil_test.go @@ -0,0 +1,237 @@
package tlsutil_test

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"math/big"
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/GoCodeAlone/workflow/pkg/tlsutil"
)

// generateSelfSignedCert writes a self-signed cert+key pair to tmpdir and
// returns (certFile, keyFile).
func generateSelfSignedCert(t *testing.T, dir string) (string, string) {
	t.Helper()

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("generate key: %v", err)
	}

	// Minimal CA-capable certificate valid for the duration of the test run.
	template := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{CommonName: "test"},
		NotBefore: time.Now().Add(-time.Minute),
		NotAfter: time.Now().Add(time.Hour),
		KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
		IsCA: true,
		BasicConstraintsValid: true,
	}

	certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("create cert: %v", err)
	}

	certFile := filepath.Join(dir, "cert.pem")
	keyFile := filepath.Join(dir, "key.pem")

	cf, err := os.Create(certFile)
	if err != nil {
		t.Fatal(err)
	}
	defer cf.Close()
	if err := pem.Encode(cf, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
		t.Fatal(err)
	}

	kf, err := os.Create(keyFile)
	if err != nil {
		t.Fatal(err)
	}
	defer kf.Close()
	keyDER, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		t.Fatal(err)
	}
	if err := pem.Encode(kf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}); err != nil {
		t.Fatal(err)
	}

	return certFile, keyFile
}

// A disabled config must yield a nil *tls.Config and no error.
func TestLoadTLSConfig_Disabled(t *testing.T) {
	cfg := tlsutil.TLSConfig{Enabled: false}
	result, err := tlsutil.LoadTLSConfig(cfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result != nil {
		t.Fatal("expected nil tls.Config when disabled")
	}
}

// A valid keypair must load into exactly one tls.Certificate.
func TestLoadTLSConfig_ValidCert(t *testing.T) {
	dir := t.TempDir()
	certFile, keyFile := generateSelfSignedCert(t, dir)

	cfg := tlsutil.TLSConfig{
		Enabled: true,
		CertFile: certFile,
		KeyFile: keyFile,
	}

	result, err := tlsutil.LoadTLSConfig(cfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result == nil {
		t.Fatal("expected non-nil tls.Config")
	}
	if len(result.Certificates) != 1 {
		t.Fatalf("expected 1 certificate, got %d", len(result.Certificates))
	}
}

// The CA file (here the self-signed cert itself) must populate RootCAs.
func TestLoadTLSConfig_WithCA(t *testing.T) {
	dir := t.TempDir()
	certFile, keyFile := generateSelfSignedCert(t, dir)

	cfg := tlsutil.TLSConfig{
		Enabled: true,
		CertFile: certFile,
		KeyFile: keyFile,
		CAFile: certFile, // reuse the self-signed cert as CA
	}

	result, err := tlsutil.LoadTLSConfig(cfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result == nil {
		t.Fatal("expected non-nil tls.Config")
	}
	if result.RootCAs == nil {
		t.Fatal("expected non-nil RootCAs")
	}
}

// client_auth "require" maps to tls.RequireAndVerifyClientCert.
func TestLoadTLSConfig_ClientAuthRequire(t *testing.T) {
	dir := t.TempDir()
	certFile, keyFile := generateSelfSignedCert(t, dir)

	cfg := tlsutil.TLSConfig{
		Enabled: true,
		CertFile: certFile,
		KeyFile: keyFile,
		ClientAuth: "require",
	}

	result, err := tlsutil.LoadTLSConfig(cfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result.ClientAuth != tls.RequireAndVerifyClientCert {
		t.Errorf("expected RequireAndVerifyClientCert, got %v", result.ClientAuth)
	}
}

// client_auth "request" maps to tls.RequestClientCert.
func TestLoadTLSConfig_ClientAuthRequest(t *testing.T) {
	dir := t.TempDir()
	certFile, keyFile := generateSelfSignedCert(t, dir)

	cfg := tlsutil.TLSConfig{
		Enabled: true,
		CertFile: certFile,
		KeyFile: keyFile,
		ClientAuth: "request",
	}

	result, err := tlsutil.LoadTLSConfig(cfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result.ClientAuth != tls.RequestClientCert {
		t.Errorf("expected RequestClientCert, got %v", result.ClientAuth)
	}
}

// skip_verify must propagate to InsecureSkipVerify.
func TestLoadTLSConfig_SkipVerify(t *testing.T) {
	cfg := tlsutil.TLSConfig{
		Enabled: true,
		SkipVerify: true,
	}

	result, err := tlsutil.LoadTLSConfig(cfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !result.InsecureSkipVerify { //nolint:gosec // G402: test assertion
		t.Error("expected InsecureSkipVerify to be true")
	}
}

func TestLoadTLSConfig_InvalidCertFile(t
*testing.T) {
	cfg := tlsutil.TLSConfig{
		Enabled: true,
		CertFile: "/nonexistent/cert.pem",
		KeyFile: "/nonexistent/key.pem",
	}

	_, err := tlsutil.LoadTLSConfig(cfg)
	if err == nil {
		t.Fatal("expected error for nonexistent cert files")
	}
}

// A missing CA file must surface as an error.
func TestLoadTLSConfig_InvalidCAFile(t *testing.T) {
	cfg := tlsutil.TLSConfig{
		Enabled: true,
		CAFile: "/nonexistent/ca.pem",
	}

	_, err := tlsutil.LoadTLSConfig(cfg)
	if err == nil {
		t.Fatal("expected error for nonexistent CA file")
	}
}

// A CA file containing no parseable PEM certificates must be rejected.
func TestLoadTLSConfig_InvalidCAContent(t *testing.T) {
	dir := t.TempDir()
	badCA := filepath.Join(dir, "bad-ca.pem")
	if err := os.WriteFile(badCA, []byte("not a valid pem"), 0600); err != nil {
		t.Fatal(err)
	}

	cfg := tlsutil.TLSConfig{
		Enabled: true,
		CAFile: badCA,
	}

	_, err := tlsutil.LoadTLSConfig(cfg)
	if err == nil {
		t.Fatal("expected error for invalid CA content")
	}
}

// An unrecognized client_auth value must be rejected.
func TestLoadTLSConfig_InvalidClientAuth(t *testing.T) {
	cfg := tlsutil.TLSConfig{
		Enabled: true,
		ClientAuth: "invalid-value",
	}

	_, err := tlsutil.LoadTLSConfig(cfg)
	if err == nil {
		t.Fatal("expected error for invalid client_auth value")
	}
}
diff --git a/plugins/auth/plugin.go b/plugins/auth/plugin.go index 0dec02b1..6b87fb18 100644 --- a/plugins/auth/plugin.go +++ b/plugins/auth/plugin.go @@ -1,6 +1,7 @@ package auth import ( + "log" "time" "github.com/CrisisTextLine/modular" @@ -11,6 +12,17 @@ import ( "github.com/GoCodeAlone/workflow/schema" ) +// durationFromMap parses a duration string from a config map, returning the +// default value when the key is absent or unparseable.
+func durationFromMap(m map[string]any, key string, defaultVal time.Duration) time.Duration { + if s, ok := m[key].(string); ok && s != "" { + if d, err := time.ParseDuration(s); err == nil { + return d + } + } + return defaultVal +} + // Plugin provides authentication capabilities: auth.jwt, auth.user-store, // auth.oauth2, and auth.m2m modules plus the wiring hook that connects // AuthProviders to AuthMiddleware. @@ -38,12 +50,14 @@ func New() *Plugin { "auth.user-store", "auth.oauth2", "auth.m2m", + "auth.token-blacklist", + "security.field-protection", }, Capabilities: []plugin.CapabilityDecl{ {Name: "authentication", Role: "provider", Priority: 10}, {Name: "user-management", Role: "provider", Priority: 10}, }, - WiringHooks: []string{"auth-provider-wiring", "oauth2-jwt-wiring"}, + WiringHooks: []string{"auth-provider-wiring", "oauth2-jwt-wiring", "token-blacklist-wiring", "field-protection-wiring"}, }, }, } @@ -127,6 +141,20 @@ func (p *Plugin) ModuleFactories() map[string]plugin.ModuleFactory { // jwtAuth will be wired during the wiring hook. 
return module.NewOAuth2Module(name, providerCfgs, nil) }, + "auth.token-blacklist": func(name string, cfg map[string]any) modular.Module { + backend := stringFromMap(cfg, "backend") + redisURL := stringFromMap(cfg, "redis_url") + cleanupInterval := durationFromMap(cfg, "cleanup_interval", 5*time.Minute) + return module.NewTokenBlacklistModule(name, backend, redisURL, cleanupInterval) + }, + "security.field-protection": func(name string, cfg map[string]any) modular.Module { + mod, err := module.NewFieldProtectionModule(name, cfg) + if err != nil { + log.Printf("ERROR: field-protection module %q: %v", name, err) + return nil + } + return mod + }, "auth.m2m": func(name string, cfg map[string]any) modular.Module { secret := stringFromMap(cfg, "secret") tokenExpiry := time.Hour @@ -235,6 +263,52 @@ func (p *Plugin) WiringHooks() []plugin.WiringHook { return nil }, }, + { + Name: "token-blacklist-wiring", + Priority: 70, + Hook: func(app modular.Application, _ *config.WorkflowConfig) error { + var blacklist *module.TokenBlacklistModule + for _, svc := range app.SvcRegistry() { + if bl, ok := svc.(*module.TokenBlacklistModule); ok { + blacklist = bl + break + } + } + if blacklist == nil { + return nil + } + for _, svc := range app.SvcRegistry() { + if j, ok := svc.(*module.JWTAuthModule); ok { + j.SetTokenBlacklist(blacklist) + } + } + return nil + }, + }, + { + Name: "field-protection-wiring", + Priority: 50, + Hook: func(app modular.Application, _ *config.WorkflowConfig) error { + var mgr *module.ProtectedFieldManager + for _, svc := range app.SvcRegistry() { + if m, ok := svc.(*module.ProtectedFieldManager); ok { + mgr = m + break + } + } + if mgr == nil { + return nil + } + // Wire field protection to Kafka brokers for field-level encryption. 
+ for _, svc := range app.SvcRegistry() { + if kb, ok := svc.(*module.KafkaBroker); ok { + kb.SetFieldProtection(mgr) + } + } + log.Printf("field-protection: wired to %d registered fields", mgr.Registry.Len()) + return nil + }, + }, } } diff --git a/plugins/auth/plugin_test.go b/plugins/auth/plugin_test.go index 0bd49919..a031daf8 100644 --- a/plugins/auth/plugin_test.go +++ b/plugins/auth/plugin_test.go @@ -21,11 +21,11 @@ func TestPluginManifest(t *testing.T) { if m.Name != "auth" { t.Errorf("expected name %q, got %q", "auth", m.Name) } - if len(m.ModuleTypes) != 4 { - t.Errorf("expected 4 module types, got %d", len(m.ModuleTypes)) + if len(m.ModuleTypes) != 6 { + t.Errorf("expected 6 module types, got %d", len(m.ModuleTypes)) } - if len(m.WiringHooks) != 2 { - t.Errorf("expected 2 wiring hooks, got %d", len(m.WiringHooks)) + if len(m.WiringHooks) != 4 { + t.Errorf("expected 4 wiring hooks, got %d", len(m.WiringHooks)) } } @@ -47,10 +47,13 @@ func TestPluginCapabilities(t *testing.T) { } func TestModuleFactories(t *testing.T) { + // field-protection factory requires a master key via env var + t.Setenv("FIELD_ENCRYPTION_KEY", "test-master-key-32-bytes-long!!") + p := New() factories := p.ModuleFactories() - expectedTypes := []string{"auth.jwt", "auth.user-store", "auth.oauth2", "auth.m2m"} + expectedTypes := []string{"auth.jwt", "auth.user-store", "auth.oauth2", "auth.m2m", "auth.token-blacklist", "security.field-protection"} for _, typ := range expectedTypes { factory, ok := factories[typ] if !ok { @@ -83,8 +86,8 @@ func TestModuleFactoryJWTWithConfig(t *testing.T) { func TestWiringHooks(t *testing.T) { p := New() hooks := p.WiringHooks() - if len(hooks) != 2 { - t.Fatalf("expected 2 wiring hooks, got %d", len(hooks)) + if len(hooks) != 4 { + t.Fatalf("expected 4 wiring hooks, got %d", len(hooks)) } hookNames := map[string]bool{} for _, h := range hooks { @@ -93,7 +96,7 @@ func TestWiringHooks(t *testing.T) { t.Errorf("wiring hook %q function is nil", h.Name) } } 
- for _, expected := range []string{"auth-provider-wiring", "oauth2-jwt-wiring"} { + for _, expected := range []string{"auth-provider-wiring", "oauth2-jwt-wiring", "token-blacklist-wiring", "field-protection-wiring"} { if !hookNames[expected] { t.Errorf("missing wiring hook %q", expected) } diff --git a/plugins/pipelinesteps/plugin.go b/plugins/pipelinesteps/plugin.go index 52a46b69..5b9815ef 100644 --- a/plugins/pipelinesteps/plugin.go +++ b/plugins/pipelinesteps/plugin.go @@ -4,7 +4,7 @@ // validate_path_param, validate_pagination, validate_request_body, // foreach, webhook_verify, base64_decode, ui_scaffold, ui_scaffold_analyze, // dlq_send, dlq_replay, retry_with_backoff, circuit_breaker (wrapping), -// s3_upload, auth_validate. +// s3_upload, auth_validate, token_revoke, sandbox_exec. // It also provides the PipelineWorkflowHandler for composable pipelines. package pipelinesteps @@ -85,6 +85,9 @@ func New() *Plugin { "step.resilient_circuit_breaker", "step.s3_upload", "step.auth_validate", + "step.token_revoke", + "step.field_reencrypt", + "step.sandbox_exec", }, WorkflowTypes: []string{"pipeline"}, Capabilities: []plugin.CapabilityDecl{ @@ -148,7 +151,10 @@ func (p *Plugin) StepFactories() map[string]plugin.StepFactory { return p.concreteStepRegistry })), "step.s3_upload": wrapStepFactory(module.NewS3UploadStepFactory()), - "step.auth_validate": wrapStepFactory(module.NewAuthValidateStepFactory()), + "step.auth_validate": wrapStepFactory(module.NewAuthValidateStepFactory()), + "step.token_revoke": wrapStepFactory(module.NewTokenRevokeStepFactory()), + "step.field_reencrypt": wrapStepFactory(module.NewFieldReencryptStepFactory()), + "step.sandbox_exec": wrapStepFactory(module.NewSandboxExecStepFactory()), } } diff --git a/plugins/pipelinesteps/plugin_test.go b/plugins/pipelinesteps/plugin_test.go index a0e98c11..aa890748 100644 --- a/plugins/pipelinesteps/plugin_test.go +++ b/plugins/pipelinesteps/plugin_test.go @@ -62,7 +62,10 @@ func TestStepFactories(t 
*testing.T) { "step.resilient_circuit_breaker", "step.s3_upload", "step.auth_validate", + "step.token_revoke", "step.base64_decode", + "step.field_reencrypt", + "step.sandbox_exec", } for _, stepType := range expectedSteps { diff --git a/plugins/secrets/plugin.go b/plugins/secrets/plugin.go index 43b47633..a9920029 100644 --- a/plugins/secrets/plugin.go +++ b/plugins/secrets/plugin.go @@ -1,5 +1,6 @@ // Package secrets provides a plugin that registers secrets management modules: -// secrets.vault (HashiCorp Vault) and secrets.aws (AWS Secrets Manager). +// secrets.vault (HashiCorp Vault) and secrets.aws (AWS Secrets Manager), +// as well as the step.secret_rotate pipeline step type. package secrets import ( @@ -9,7 +10,7 @@ import ( "github.com/GoCodeAlone/workflow/plugin" ) -// Plugin registers secrets management module factories. +// Plugin registers secrets management module factories and step types. type Plugin struct { plugin.BaseEnginePlugin } @@ -30,6 +31,7 @@ func New() *Plugin { Description: "Secrets management modules (Vault, AWS Secrets Manager)", Tier: plugin.TierCore, ModuleTypes: []string{"secrets.vault", "secrets.aws"}, + StepTypes: []string{"step.secret_rotate"}, Capabilities: []plugin.CapabilityDecl{ {Name: "secrets-management", Role: "provider", Priority: 50}, }, @@ -85,3 +87,12 @@ func (p *Plugin) ModuleFactories() map[string]plugin.ModuleFactory { }, } } + +// StepFactories returns the step factories provided by this plugin. 
+func (p *Plugin) StepFactories() map[string]plugin.StepFactory { + return map[string]plugin.StepFactory{ + "step.secret_rotate": func(name string, cfg map[string]any, app modular.Application) (any, error) { + return module.NewSecretRotateStepFactory()(name, cfg, app) + }, + } +} diff --git a/sandbox/docker.go b/sandbox/docker.go index 9d5fda3c..6f2a4ad5 100644 --- a/sandbox/docker.go +++ b/sandbox/docker.go @@ -34,6 +34,38 @@ type SandboxConfig struct { CPULimit float64 `yaml:"cpu_limit"` Timeout time.Duration `yaml:"timeout"` NetworkMode string `yaml:"network_mode"` + + // Security hardening fields + SecurityOpts []string `yaml:"security_opts"` // e.g., ["seccomp=default.json"] + CapAdd []string `yaml:"cap_add"` // capabilities to add + CapDrop []string `yaml:"cap_drop"` // e.g., ["ALL"] + ReadOnlyRootfs bool `yaml:"read_only_rootfs"` + NoNewPrivileges bool `yaml:"no_new_privileges"` + User string `yaml:"user"` // e.g., "nobody:nogroup" + PidsLimit int64 `yaml:"pids_limit"` // max process count + Tmpfs map[string]string `yaml:"tmpfs"` // e.g., {"/tmp": "size=64m,noexec"} +} + +// DefaultSecureSandboxConfig returns a hardened SandboxConfig suitable for +// running untrusted workloads. It uses a minimal Wolfi-based image, drops all +// Linux capabilities, enables a read-only root filesystem, mounts /tmp as +// tmpfs with noexec, and disables network access. +func DefaultSecureSandboxConfig(image string) SandboxConfig { + if image == "" { + image = "cgr.dev/chainguard/wolfi-base:latest" + } + return SandboxConfig{ + Image: image, + MemoryLimit: 256 * 1024 * 1024, // 256MB + CPULimit: 0.5, + NetworkMode: "none", + CapDrop: []string{"ALL"}, + NoNewPrivileges: true, + ReadOnlyRootfs: true, + PidsLimit: 64, + Tmpfs: map[string]string{"/tmp": "size=64m,noexec"}, + Timeout: 5 * time.Minute, + } } // ExecResult holds the output from a command execution inside the sandbox. 
@@ -45,8 +77,9 @@ type ExecResult struct { // DockerSandbox wraps the Docker Engine SDK to execute commands in isolated containers. type DockerSandbox struct { - client *client.Client - config SandboxConfig + client *client.Client + config SandboxConfig + containerID string // set by CreateContainer, used by CopyIn/CopyOut/RemoveContainer } // NewDockerSandbox creates a new DockerSandbox with the given configuration. @@ -92,6 +125,9 @@ func (s *DockerSandbox) Exec(ctx context.Context, cmd []string) (*ExecResult, er Env: s.buildEnv(), WorkingDir: s.config.WorkDir, } + if s.config.User != "" { + containerConfig.User = s.config.User + } hostConfig := s.buildHostConfig() @@ -145,39 +181,51 @@ func (s *DockerSandbox) Exec(ctx context.Context, cmd []string) (*ExecResult, er }, nil } -// CopyIn copies a file from the host into a running or created container. +// CopyIn copies a file from the host into the active container. +// Call CreateContainer first to set the active container ID. func (s *DockerSandbox) CopyIn(ctx context.Context, srcPath, destPath string) error { - f, err := os.Open(srcPath) - if err != nil { - return fmt.Errorf("sandbox: failed to open source file: %w", err) + if s.containerID == "" { + return fmt.Errorf("sandbox: CopyIn requires an active container; call CreateContainer first") } - defer f.Close() + return s.copyToContainer(ctx, s.containerID, srcPath, destPath) +} - stat, err := f.Stat() +// CopyOut copies a file out of the active container. Returns a ReadCloser with the file contents. +// Call CreateContainer first to set the active container ID. 
+func (s *DockerSandbox) CopyOut(ctx context.Context, srcPath string) (io.ReadCloser, error) { + if s.containerID == "" { + return nil, fmt.Errorf("sandbox: CopyOut requires an active container; call CreateContainer first") + } + reader, _, err := s.client.CopyFromContainer(ctx, s.containerID, srcPath) if err != nil { - return fmt.Errorf("sandbox: failed to stat source file: %w", err) + return nil, fmt.Errorf("sandbox: CopyOut %q: %w", srcPath, err) } + return reader, nil +} - // Create a tar archive containing the file - tarReader, err := createTarFromFile(f, stat) +// CreateContainer creates and starts a container, storing its ID for use with CopyIn/CopyOut. +// Call RemoveContainer when done to clean up. +func (s *DockerSandbox) CreateContainer(ctx context.Context, cmd []string) error { + hostConfig := s.buildHostConfig() + resp, err := s.client.ContainerCreate(ctx, &container.Config{ + Image: s.config.Image, + Cmd: cmd, + }, hostConfig, nil, nil, "") if err != nil { - return fmt.Errorf("sandbox: failed to create tar archive: %w", err) + return fmt.Errorf("sandbox: create container: %w", err) } - - // We need a container to copy into. This method is intended to be used - // with a container that has been created but the caller manages its lifecycle. - // For the typical use case, the Exec method handles the full lifecycle. - // This is a lower-level utility for advanced usage. - _ = destPath - _ = tarReader - return fmt.Errorf("sandbox: CopyIn requires an active container ID; use Exec for typical workflows") + s.containerID = resp.ID + return nil } -// CopyOut copies a file out of a container. Returns a ReadCloser with the file contents. -func (s *DockerSandbox) CopyOut(ctx context.Context, srcPath string) (io.ReadCloser, error) { - // Similar to CopyIn, this requires an active container. 
- _ = srcPath - return nil, fmt.Errorf("sandbox: CopyOut requires an active container ID; use Exec for typical workflows") +// RemoveContainer stops and removes the active container. +func (s *DockerSandbox) RemoveContainer(ctx context.Context) error { + if s.containerID == "" { + return nil + } + id := s.containerID + s.containerID = "" + return s.client.ContainerRemove(ctx, id, container.RemoveOptions{Force: true}) } // ExecInContainer creates a container, copies files in, runs the command, and allows file extraction. @@ -200,6 +248,9 @@ func (s *DockerSandbox) ExecInContainer(ctx context.Context, cmd []string, copyI Env: s.buildEnv(), WorkingDir: s.config.WorkDir, } + if s.config.User != "" { + containerConfig.User = s.config.User + } hostConfig := s.buildHostConfig() @@ -320,6 +371,10 @@ func (s *DockerSandbox) buildHostConfig() *container.HostConfig { // Docker uses NanoCPUs (1 CPU = 1e9 NanoCPUs) hc.NanoCPUs = int64(s.config.CPULimit * 1e9) } + if s.config.PidsLimit > 0 { + limit := s.config.PidsLimit + hc.PidsLimit = &limit + } // Mounts if len(s.config.Mounts) > 0 { @@ -340,6 +395,30 @@ func (s *DockerSandbox) buildHostConfig() *container.HostConfig { hc.NetworkMode = container.NetworkMode(s.config.NetworkMode) } + // Security options + secOpts := make([]string, len(s.config.SecurityOpts)) + copy(secOpts, s.config.SecurityOpts) + if s.config.NoNewPrivileges { + secOpts = append(secOpts, "no-new-privileges:true") + } + if len(secOpts) > 0 { + hc.SecurityOpt = secOpts + } + + // Capabilities + if len(s.config.CapAdd) > 0 { + hc.CapAdd = s.config.CapAdd + } + if len(s.config.CapDrop) > 0 { + hc.CapDrop = s.config.CapDrop + } + + // Filesystem hardening + hc.ReadonlyRootfs = s.config.ReadOnlyRootfs + if len(s.config.Tmpfs) > 0 { + hc.Tmpfs = s.config.Tmpfs + } + return hc } diff --git a/sandbox/docker_test.go b/sandbox/docker_test.go index d317e427..722f4053 100644 --- a/sandbox/docker_test.go +++ b/sandbox/docker_test.go @@ -204,3 +204,129 @@ func 
TestBuildHostConfig_NoLimits(t *testing.T) { t.Fatalf("expected 0 mounts, got %d", len(hc.Mounts)) } } + +func TestBuildHostConfig_SecurityFields(t *testing.T) { + pidsLimit := int64(32) + sb := &DockerSandbox{ + config: SandboxConfig{ + SecurityOpts: []string{"seccomp=default.json"}, + CapAdd: []string{"NET_BIND_SERVICE"}, + CapDrop: []string{"ALL"}, + ReadOnlyRootfs: true, + NoNewPrivileges: true, + PidsLimit: pidsLimit, + Tmpfs: map[string]string{"/tmp": "size=32m,noexec"}, + }, + } + + hc := sb.buildHostConfig() + + // SecurityOpt should contain both seccomp and no-new-privileges + foundSeccomp := false + foundNoNewPriv := false + for _, opt := range hc.SecurityOpt { + if opt == "seccomp=default.json" { + foundSeccomp = true + } + if opt == "no-new-privileges:true" { + foundNoNewPriv = true + } + } + if !foundSeccomp { + t.Fatal("expected seccomp=default.json in SecurityOpt") + } + if !foundNoNewPriv { + t.Fatal("expected no-new-privileges:true in SecurityOpt") + } + + if len(hc.CapAdd) != 1 || hc.CapAdd[0] != "NET_BIND_SERVICE" { + t.Fatalf("unexpected CapAdd: %v", hc.CapAdd) + } + if len(hc.CapDrop) != 1 || hc.CapDrop[0] != "ALL" { + t.Fatalf("unexpected CapDrop: %v", hc.CapDrop) + } + if !hc.ReadonlyRootfs { + t.Fatal("expected ReadonlyRootfs true") + } + if hc.PidsLimit == nil || *hc.PidsLimit != pidsLimit { + t.Fatalf("expected PidsLimit %d, got %v", pidsLimit, hc.PidsLimit) + } + if hc.Tmpfs["/tmp"] != "size=32m,noexec" { + t.Fatalf("unexpected Tmpfs: %v", hc.Tmpfs) + } +} + +func TestBuildHostConfig_NoNewPrivilegesOnly(t *testing.T) { + sb := &DockerSandbox{ + config: SandboxConfig{ + NoNewPrivileges: true, + }, + } + + hc := sb.buildHostConfig() + + found := false + for _, opt := range hc.SecurityOpt { + if opt == "no-new-privileges:true" { + found = true + } + } + if !found { + t.Fatal("expected no-new-privileges:true in SecurityOpt") + } +} + +func TestBuildHostConfig_PidsLimitZeroNotSet(t *testing.T) { + sb := &DockerSandbox{ + config: 
SandboxConfig{PidsLimit: 0}, + } + + hc := sb.buildHostConfig() + + if hc.PidsLimit != nil { + t.Fatalf("expected nil PidsLimit when PidsLimit=0, got %v", hc.PidsLimit) + } +} + +func TestDefaultSecureSandboxConfig(t *testing.T) { + cfg := DefaultSecureSandboxConfig("alpine:3.19") + + if cfg.Image != "alpine:3.19" { + t.Fatalf("unexpected image: %s", cfg.Image) + } + if cfg.MemoryLimit != 256*1024*1024 { + t.Fatalf("unexpected MemoryLimit: %d", cfg.MemoryLimit) + } + if cfg.CPULimit != 0.5 { + t.Fatalf("unexpected CPULimit: %f", cfg.CPULimit) + } + if cfg.NetworkMode != "none" { + t.Fatalf("unexpected NetworkMode: %s", cfg.NetworkMode) + } + if len(cfg.CapDrop) != 1 || cfg.CapDrop[0] != "ALL" { + t.Fatalf("unexpected CapDrop: %v", cfg.CapDrop) + } + if !cfg.NoNewPrivileges { + t.Fatal("expected NoNewPrivileges true") + } + if !cfg.ReadOnlyRootfs { + t.Fatal("expected ReadOnlyRootfs true") + } + if cfg.PidsLimit != 64 { + t.Fatalf("unexpected PidsLimit: %d", cfg.PidsLimit) + } + if cfg.Tmpfs["/tmp"] != "size=64m,noexec" { + t.Fatalf("unexpected Tmpfs: %v", cfg.Tmpfs) + } + if cfg.Timeout != 5*time.Minute { + t.Fatalf("unexpected Timeout: %s", cfg.Timeout) + } +} + +func TestDefaultSecureSandboxConfig_DefaultImage(t *testing.T) { + cfg := DefaultSecureSandboxConfig("") + + if cfg.Image != "cgr.dev/chainguard/wolfi-base:latest" { + t.Fatalf("unexpected default image: %s", cfg.Image) + } +} diff --git a/secrets/rotation_test.go b/secrets/rotation_test.go new file mode 100644 index 00000000..d746ca35 --- /dev/null +++ b/secrets/rotation_test.go @@ -0,0 +1,343 @@ +package secrets + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" +) + +// Compile-time check: VaultProvider must implement RotationProvider. +var _ RotationProvider = (*VaultProvider)(nil) + +// versionedEntry stores all versions of a secret so GetPrevious can be tested. 
+type versionedEntry struct { + versions []map[string]interface{} // index 0 = version 1, index 1 = version 2, ... +} + +func (e *versionedEntry) current() map[string]interface{} { + if len(e.versions) == 0 { + return nil + } + return e.versions[len(e.versions)-1] +} + +func (e *versionedEntry) currentVersion() int { + return len(e.versions) +} + +func (e *versionedEntry) getVersion(v int) map[string]interface{} { + if v < 1 || v > len(e.versions) { + return nil + } + return e.versions[v-1] +} + +func (e *versionedEntry) put(data map[string]interface{}) { + e.versions = append(e.versions, data) +} + +// newVersionedVaultServer creates an httptest server that tracks all KV v2 versions. +// It supports ?version=N query params on GET for GetVersion requests. +func newVersionedVaultServer(t *testing.T) (*httptest.Server, map[string]*versionedEntry) { + t.Helper() + store := make(map[string]*versionedEntry) + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("X-Vault-Token") + if token == "" { + http.Error(w, `{"errors":["missing client token"]}`, http.StatusForbidden) + return + } + + path := r.URL.Path + switch { + case strings.Contains(path, "/data/"): + handleVersionedData(w, r, path, store) + default: + http.Error(w, `{"errors":["not found"]}`, http.StatusNotFound) + } + }) + + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + return server, store +} + +func handleVersionedData(w http.ResponseWriter, r *http.Request, path string, store map[string]*versionedEntry) { + parts := strings.SplitN(path, "/data/", 2) + if len(parts) < 2 { + http.Error(w, `{"errors":["invalid path"]}`, http.StatusBadRequest) + return + } + key := parts[1] + + switch r.Method { + case http.MethodGet: + entry, ok := store[key] + if !ok || len(entry.versions) == 0 { + http.Error(w, `{"errors":[]}`, http.StatusNotFound) + return + } + + // Check for ?version=N query param + versionParam := 
r.URL.Query().Get("version") + var data map[string]interface{} + var version int + + if versionParam != "" { + v, err := strconv.Atoi(versionParam) + if err != nil || v < 1 { + http.Error(w, `{"errors":["invalid version"]}`, http.StatusBadRequest) + return + } + data = entry.getVersion(v) + if data == nil { + http.Error(w, `{"errors":[]}`, http.StatusNotFound) + return + } + version = v + } else { + data = entry.current() + version = entry.currentVersion() + } + + resp := map[string]interface{}{ + "data": map[string]interface{}{ + "data": data, + "metadata": map[string]interface{}{ + "version": version, + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + + case http.MethodPost, http.MethodPut: + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, `{"errors":["read body failed"]}`, http.StatusBadRequest) + return + } + var payload struct { + Data map[string]interface{} `json:"data"` + } + if err := json.Unmarshal(body, &payload); err != nil { + http.Error(w, `{"errors":["invalid json"]}`, http.StatusBadRequest) + return + } + + if store[key] == nil { + store[key] = &versionedEntry{} + } + store[key].put(payload.Data) + + newVersion := store[key].currentVersion() + resp := map[string]interface{}{ + "data": map[string]interface{}{ + "version": newVersion, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + + default: + http.Error(w, `{"errors":["method not allowed"]}`, http.StatusMethodNotAllowed) + } +} + +func TestVaultProvider_Rotate(t *testing.T) { + server, store := newVersionedVaultServer(t) + + p, err := NewVaultProvider(VaultConfig{ + Address: server.URL, + Token: "test-token", + MountPath: "secret", + }) + if err != nil { + t.Fatalf("NewVaultProvider: %v", err) + } + + ctx := context.Background() + + // Rotate on a fresh key — should create version 1. 
+ newVal, err := p.Rotate(ctx, "myapp/api-key") + if err != nil { + t.Fatalf("Rotate: %v", err) + } + if newVal == "" { + t.Fatal("expected non-empty rotated value") + } + // 32 bytes hex-encoded = 64 chars + if len(newVal) != 64 { + t.Errorf("expected 64-char hex value, got %d chars: %q", len(newVal), newVal) + } + + entry, ok := store["myapp/api-key"] + if !ok { + t.Fatal("expected key to exist in store") + } + if entry.currentVersion() != 1 { + t.Errorf("expected version 1 after first rotate, got %d", entry.currentVersion()) + } + + // Rotate again — should create version 2. + newVal2, err := p.Rotate(ctx, "myapp/api-key") + if err != nil { + t.Fatalf("Rotate (second): %v", err) + } + if newVal2 == newVal { + t.Error("expected different value after second rotate") + } + if entry.currentVersion() != 2 { + t.Errorf("expected version 2 after second rotate, got %d", entry.currentVersion()) + } +} + +func TestVaultProvider_Rotate_EmptyKey(t *testing.T) { + server, _ := newVersionedVaultServer(t) + + p, err := NewVaultProvider(VaultConfig{ + Address: server.URL, + Token: "test-token", + }) + if err != nil { + t.Fatalf("NewVaultProvider: %v", err) + } + + _, err = p.Rotate(context.Background(), "") + if err != ErrInvalidKey { + t.Errorf("expected ErrInvalidKey, got %v", err) + } +} + +func TestVaultProvider_GetPrevious(t *testing.T) { + server, store := newVersionedVaultServer(t) + + // Pre-populate two versions directly in the store. + store["myapp/db-pass"] = &versionedEntry{ + versions: []map[string]interface{}{ + {"value": "old-secret"}, + {"value": "new-secret"}, + }, + } + + p, err := NewVaultProvider(VaultConfig{ + Address: server.URL, + Token: "test-token", + MountPath: "secret", + }) + if err != nil { + t.Fatalf("NewVaultProvider: %v", err) + } + + ctx := context.Background() + + // GetPrevious should return version 1 (old-secret). 
+ prev, err := p.GetPrevious(ctx, "myapp/db-pass#value") + if err != nil { + t.Fatalf("GetPrevious: %v", err) + } + if prev != "old-secret" { + t.Errorf("expected 'old-secret', got %q", prev) + } +} + +func TestVaultProvider_GetPrevious_NoHistory(t *testing.T) { + server, store := newVersionedVaultServer(t) + + // Only one version exists — GetPrevious should return ErrNotFound. + store["myapp/only-one"] = &versionedEntry{ + versions: []map[string]interface{}{ + {"value": "only"}, + }, + } + + p, err := NewVaultProvider(VaultConfig{ + Address: server.URL, + Token: "test-token", + }) + if err != nil { + t.Fatalf("NewVaultProvider: %v", err) + } + + _, err = p.GetPrevious(context.Background(), "myapp/only-one") + if err == nil { + t.Fatal("expected error when only one version exists") + } +} + +func TestVaultProvider_GetPrevious_EmptyKey(t *testing.T) { + server, _ := newVersionedVaultServer(t) + + p, err := NewVaultProvider(VaultConfig{ + Address: server.URL, + Token: "test-token", + }) + if err != nil { + t.Fatalf("NewVaultProvider: %v", err) + } + + _, err = p.GetPrevious(context.Background(), "") + if err != ErrInvalidKey { + t.Errorf("expected ErrInvalidKey, got %v", err) + } +} + +func TestVaultProvider_GetPrevious_NotFound(t *testing.T) { + server, _ := newVersionedVaultServer(t) + + p, err := NewVaultProvider(VaultConfig{ + Address: server.URL, + Token: "test-token", + }) + if err != nil { + t.Fatalf("NewVaultProvider: %v", err) + } + + _, err = p.GetPrevious(context.Background(), "nonexistent/key") + if err == nil { + t.Fatal("expected error for missing key") + } +} + +func TestVaultProvider_Rotate_ThenGetPrevious(t *testing.T) { + server, _ := newVersionedVaultServer(t) + + p, err := NewVaultProvider(VaultConfig{ + Address: server.URL, + Token: "test-token", + MountPath: "secret", + }) + if err != nil { + t.Fatalf("NewVaultProvider: %v", err) + } + + ctx := context.Background() + + // First rotation creates version 1. 
+ val1, err := p.Rotate(ctx, "svc/token") + if err != nil { + t.Fatalf("first Rotate: %v", err) + } + + // Second rotation creates version 2. + _, err = p.Rotate(ctx, "svc/token") + if err != nil { + t.Fatalf("second Rotate: %v", err) + } + + // GetPrevious should return the version 1 value (stored as "value" field). + prev, err := p.GetPrevious(ctx, "svc/token#value") + if err != nil { + t.Fatalf("GetPrevious: %v", err) + } + if prev != val1 { + t.Errorf("expected previous value %q, got %q", val1, prev) + } +} diff --git a/secrets/secrets.go b/secrets/secrets.go index 2b7f7eef..84532b78 100644 --- a/secrets/secrets.go +++ b/secrets/secrets.go @@ -34,6 +34,15 @@ type Provider interface { List(ctx context.Context) ([]string, error) } +// RotationProvider extends Provider with key rotation capabilities. +type RotationProvider interface { + Provider + // Rotate generates a new secret value and stores it, returning the new value. + Rotate(ctx context.Context, key string) (string, error) + // GetPrevious retrieves the previous version of a rotated secret (for grace periods). + GetPrevious(ctx context.Context, key string) (string, error) +} + // --- Environment Variable Provider --- // EnvProvider reads secrets from environment variables. diff --git a/secrets/vault_provider.go b/secrets/vault_provider.go index 1f7dbb3d..b0c3fa98 100644 --- a/secrets/vault_provider.go +++ b/secrets/vault_provider.go @@ -2,6 +2,8 @@ package secrets import ( "context" + "crypto/rand" + "encoding/hex" "encoding/json" "fmt" "strings" @@ -157,6 +159,85 @@ func (p *VaultProvider) List(ctx context.Context) ([]string, error) { return p.listRecursive(ctx, "") } +// Rotate generates a new random 32-byte hex-encoded secret and stores it at the given key, +// creating a new version in Vault KV v2. It returns the newly generated value. 
+func (p *VaultProvider) Rotate(ctx context.Context, key string) (string, error) { + if key == "" { + return "", ErrInvalidKey + } + + raw := make([]byte, 32) + if _, err := rand.Read(raw); err != nil { + return "", fmt.Errorf("secrets: failed to generate random secret: %w", err) + } + newValue := hex.EncodeToString(raw) + + path, _ := parseVaultKey(key) + kv := p.client.KVv2(p.config.MountPath) + if _, err := kv.Put(ctx, path, map[string]interface{}{ + "value": newValue, + }); err != nil { + return "", fmt.Errorf("secrets: vault rotate failed: %w", err) + } + + return newValue, nil +} + +// GetPrevious retrieves version N-1 of the secret at the given key from Vault KV v2. +// It reads the current version metadata to determine N, then fetches version N-1. +// Returns ErrNotFound if the secret has only one version or does not exist. +func (p *VaultProvider) GetPrevious(ctx context.Context, key string) (string, error) { + if key == "" { + return "", ErrInvalidKey + } + + path, field := parseVaultKey(key) + kv := p.client.KVv2(p.config.MountPath) + + // Get the current version to determine the previous version number. 
+ current, err := kv.Get(ctx, path) + if err != nil { + if isVaultNotFound(err) { + return "", fmt.Errorf("%w: vault returned not found for key %q", ErrNotFound, key) + } + return "", fmt.Errorf("secrets: vault get (for previous) failed: %w", err) + } + if current == nil || current.VersionMetadata == nil { + return "", fmt.Errorf("%w: no version metadata for key %q", ErrNotFound, key) + } + + currentVersion := current.VersionMetadata.Version + if currentVersion <= 1 { + return "", fmt.Errorf("%w: no previous version exists for key %q (current version is %d)", ErrNotFound, key, currentVersion) + } + + prevVersion := currentVersion - 1 + prev, err := kv.GetVersion(ctx, path, prevVersion) + if err != nil { + if isVaultNotFound(err) { + return "", fmt.Errorf("%w: previous version %d not found for key %q", ErrNotFound, prevVersion, key) + } + return "", fmt.Errorf("secrets: vault get version %d failed: %w", prevVersion, err) + } + if prev == nil || prev.Data == nil { + return "", fmt.Errorf("%w: no data in previous version %d for key %q", ErrNotFound, prevVersion, key) + } + + if field != "" { + val, ok := prev.Data[field] + if !ok { + return "", fmt.Errorf("%w: field %q not found in previous version of key %q", ErrNotFound, field, path) + } + return fmt.Sprintf("%v", val), nil + } + + data, err := json.Marshal(prev.Data) + if err != nil { + return "", fmt.Errorf("secrets: failed to marshal vault previous version data: %w", err) + } + return string(data), nil +} + // listRecursive walks the metadata tree and collects all leaf keys. func (p *VaultProvider) listRecursive(ctx context.Context, prefix string) ([]string, error) { // Construct metadata path: {mount}/metadata or {mount}/metadata/{prefix}