From 6aecfc6144a9b5c1e4195d1cf579d6967642151b Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sat, 28 Feb 2026 04:47:57 -0500 Subject: [PATCH 1/6] =?UTF-8?q?feat:=20security=20hardening=20=E2=80=94=20?= =?UTF-8?q?TLS,=20token=20blacklist,=20field-level=20encryption,=20sandbox?= =?UTF-8?q?,=20secret=20rotation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 1: TLS support for all transports - pkg/tlsutil: shared TLS config (manual, autocert, mTLS) - HTTP server: Let's Encrypt autocert + manual TLS + mTLS - Redis: TLS with client cert support - Kafka: SASL (PLAIN/SCRAM-SHA-256/512) + TLS - NATS: TLS via nats.Secure() - Database: explicit TLS fields (sslmode, ca_file) Phase 2: Token blacklist - auth.token-blacklist module (memory + redis backends) - step.token_revoke pipeline step - JTI generation + blacklist check in JWTAuthModule Phase 3: Field-level data protection - pkg/fieldcrypt: AES-256-GCM encryption, masking, HKDF key derivation - Tenant-isolated KeyRing with versioned keys - ProtectedFieldManager module (security.field-protection) - step.field_reencrypt for key rotation re-encryption - Backward compat: legacy enc:: prefix handled alongside new epf:v{n}: format Phase 4: Docker sandbox hardening - seccomp profiles, capability dropping, read-only rootfs, no-new-privileges - step.sandbox_exec with strict/standard/permissive security profiles - Default secure config with Wolfi base image (cgr.dev/chainguard/wolfi-base) Phase 5: Secret rotation - RotationProvider interface in secrets package - Vault provider Rotate() + GetPrevious() via versioned KV v2 - step.secret_rotate pipeline step Co-Authored-By: Claude Opus 4.6 --- go.mod | 3 + go.sum | 7 + module/auth_token_blacklist.go | 163 +++++++++ module/auth_token_blacklist_test.go | 89 +++++ module/cache_redis.go | 12 +- module/database.go | 44 ++- module/field_protection.go | 162 +++++++++ module/field_protection_test.go | 192 +++++++++++ module/http_server.go | 102 +++++- module/jwt_auth.go | 18 + module/kafka_broker.go | 54 +++ module/kafka_scram.go | 44 +++ module/nats_broker.go | 20 +- module/pipeline_step_field_reencrypt.go | 73 ++++ module/pipeline_step_sandbox_exec.go | 231 +++++++++++++ module/pipeline_step_sandbox_exec_test.go | 302 +++++++++++++++++ module/pipeline_step_secret_rotate.go | 83 +++++ module/pipeline_step_secret_rotate_test.go | 184 ++++++++++ module/pipeline_step_token_revoke.go | 95 ++++++ module/pipeline_step_token_revoke_test.go | 309 +++++++++++++++++ module/tls_config_test.go | 152 +++++++++ pkg/fieldcrypt/encrypt.go | 118 +++++++ pkg/fieldcrypt/fieldcrypt.go | 66 ++++ pkg/fieldcrypt/fieldcrypt_test.go | 377 +++++++++++++++++++++ pkg/fieldcrypt/keyring.go | 122 +++++++ pkg/fieldcrypt/mask.go | 122 +++++++ pkg/fieldcrypt/scanner.go | 117 +++++++ pkg/tlsutil/tlsutil.go | 74 ++++ pkg/tlsutil/tlsutil_test.go | 237 +++++++++++++ plugins/auth/plugin.go | 47 ++- plugins/auth/plugin_test.go | 16 +- plugins/pipelinesteps/plugin.go | 10 +- plugins/pipelinesteps/plugin_test.go | 3 + plugins/secrets/plugin.go | 15 +- sandbox/docker.go | 66 ++++ sandbox/docker_test.go | 126 +++++++ secrets/rotation_test.go | 343 +++++++++++++++++++ secrets/secrets.go | 9 + secrets/vault_provider.go | 81 +++++ 39 files changed, 4262 insertions(+), 26 deletions(-) create mode 100644 module/auth_token_blacklist.go create mode 100644 module/auth_token_blacklist_test.go create mode 100644 module/field_protection.go create mode 100644 module/field_protection_test.go create mode 100644 module/kafka_scram.go create mode 100644 module/pipeline_step_field_reencrypt.go create mode 100644 module/pipeline_step_sandbox_exec.go create mode 100644 module/pipeline_step_sandbox_exec_test.go create mode 100644 module/pipeline_step_secret_rotate.go create mode 100644 module/pipeline_step_secret_rotate_test.go create mode 100644 module/pipeline_step_token_revoke.go create mode 100644 module/pipeline_step_token_revoke_test.go create mode 100644 module/tls_config_test.go create mode 100644 pkg/fieldcrypt/encrypt.go create mode 100644 pkg/fieldcrypt/fieldcrypt.go create mode 100644 pkg/fieldcrypt/fieldcrypt_test.go create mode 100644 pkg/fieldcrypt/keyring.go create mode 100644 pkg/fieldcrypt/mask.go create mode 100644 pkg/fieldcrypt/scanner.go create mode 100644 pkg/tlsutil/tlsutil.go create mode 100644 pkg/tlsutil/tlsutil_test.go create mode 100644 secrets/rotation_test.go diff --git a/go.mod b/go.mod index 68ef4e93..db1134f3 100644 --- a/go.mod +++ b/go.mod @@ -207,6 +207,9 @@ require ( github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect github.com/tliron/commonlog v0.2.8 // indirect github.com/tliron/kutil v0.3.11 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.2.0 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect diff --git a/go.sum b/go.sum index add50e32..97979362 100644 --- a/go.sum +++ b/go.sum @@ -506,6 +506,12 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/wsxiaoys/terminal v0.0.0-20160513160801-0940f3fc43a0 h1:3UeQBvD0TFrlVjOeLOBz+CPAI8dnbqNSVwUwRrkp7vQ= github.com/wsxiaoys/terminal v0.0.0-20160513160801-0940f3fc43a0/go.mod h1:IXCdmsXIht47RaVFLEdVnh1t+pgYtTAhQGj73kz+2DM= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs= +github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= @@ -606,6 +612,7 @@ golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= diff --git a/module/auth_token_blacklist.go b/module/auth_token_blacklist.go new file mode 100644 index 00000000..1da3ceaf --- /dev/null +++ b/module/auth_token_blacklist.go @@ -0,0 +1,163 @@ +package module + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/CrisisTextLine/modular" + "github.com/redis/go-redis/v9" +) + +// TokenBlacklist is the interface for checking and adding revoked JWT IDs. +type TokenBlacklist interface { + Add(jti string, expiresAt time.Time) + IsBlacklisted(jti string) bool +} + +// TokenBlacklistModule maintains a set of revoked JWT IDs (JTIs). +// It supports two backends: "memory" (default) and "redis". +type TokenBlacklistModule struct { + name string + backend string + redisURL string + cleanupInterval time.Duration + + // memory backend + entries sync.Map // jti (string) -> expiry (time.Time) + + // redis backend + redisClient *redis.Client + + logger modular.Logger + stopCh chan struct{} +} + +// NewTokenBlacklistModule creates a new TokenBlacklistModule. +func NewTokenBlacklistModule(name, backend, redisURL string, cleanupInterval time.Duration) *TokenBlacklistModule { + if backend == "" { + backend = "memory" + } + if cleanupInterval <= 0 { + cleanupInterval = 5 * time.Minute + } + return &TokenBlacklistModule{ + name: name, + backend: backend, + redisURL: redisURL, + cleanupInterval: cleanupInterval, + stopCh: make(chan struct{}), + } +} + +// Name returns the module name. +func (m *TokenBlacklistModule) Name() string { return m.name } + +// Init initializes the module. +func (m *TokenBlacklistModule) Init(app modular.Application) error { + m.logger = app.Logger() + return nil +} + +// Start connects to Redis (if configured) and starts the cleanup goroutine. +func (m *TokenBlacklistModule) Start(ctx context.Context) error { + if m.backend == "redis" { + if m.redisURL == "" { + return fmt.Errorf("auth.token-blacklist %q: redis_url is required for redis backend", m.name) + } + opts, err := redis.ParseURL(m.redisURL) + if err != nil { + return fmt.Errorf("auth.token-blacklist %q: invalid redis_url: %w", m.name, err) + } + m.redisClient = redis.NewClient(opts) + if err := m.redisClient.Ping(ctx).Err(); err != nil { + _ = m.redisClient.Close() + m.redisClient = nil + return fmt.Errorf("auth.token-blacklist %q: redis ping failed: %w", m.name, err) + } + m.logger.Info("token blacklist started", "name", m.name, "backend", "redis") + return nil + } + + // memory backend: start cleanup goroutine + go m.runCleanup() + m.logger.Info("token blacklist started", "name", m.name, "backend", "memory") + return nil +} + +// Stop shuts down the module. +func (m *TokenBlacklistModule) Stop(_ context.Context) error { + select { + case <-m.stopCh: + // already closed + default: + close(m.stopCh) + } + if m.redisClient != nil { + return m.redisClient.Close() + } + return nil +} + +// Add marks a JTI as revoked until expiresAt. +func (m *TokenBlacklistModule) Add(jti string, expiresAt time.Time) { + if m.backend == "redis" && m.redisClient != nil { + ttl := time.Until(expiresAt) + if ttl <= 0 { + return // already expired, nothing to blacklist + } + _ = m.redisClient.Set(context.Background(), m.redisKey(jti), "1", ttl).Err() + return + } + m.entries.Store(jti, expiresAt) +} + +// IsBlacklisted returns true if the JTI is revoked and has not yet expired. +func (m *TokenBlacklistModule) IsBlacklisted(jti string) bool { + if m.backend == "redis" && m.redisClient != nil { + n, err := m.redisClient.Exists(context.Background(), m.redisKey(jti)).Result() + return err == nil && n > 0 + } + val, ok := m.entries.Load(jti) + if !ok { + return false + } + expiry, ok := val.(time.Time) + return ok && time.Now().Before(expiry) +} + +func (m *TokenBlacklistModule) redisKey(jti string) string { + return "blacklist:" + jti +} + +func (m *TokenBlacklistModule) runCleanup() { + ticker := time.NewTicker(m.cleanupInterval) + defer ticker.Stop() + for { + select { + case <-m.stopCh: + return + case <-ticker.C: + now := time.Now() + m.entries.Range(func(key, value any) bool { + if expiry, ok := value.(time.Time); ok && now.After(expiry) { + m.entries.Delete(key) + } + return true + }) + } + } +} + +// ProvidesServices registers this module as a service. +func (m *TokenBlacklistModule) ProvidesServices() []modular.ServiceProvider { + return []modular.ServiceProvider{ + {Name: m.name, Description: "JWT token blacklist", Instance: m}, + } +} + +// RequiresServices returns service dependencies (none). +func (m *TokenBlacklistModule) RequiresServices() []modular.ServiceDependency { + return nil +} diff --git a/module/auth_token_blacklist_test.go b/module/auth_token_blacklist_test.go new file mode 100644 index 00000000..bb32985d --- /dev/null +++ b/module/auth_token_blacklist_test.go @@ -0,0 +1,89 @@ +package module + +import ( + "context" + "testing" + "time" +) + +func TestTokenBlacklistMemory_AddAndCheck(t *testing.T) { + bl := NewTokenBlacklistModule("test-bl", "memory", "", time.Minute) + + if bl.IsBlacklisted("jti-1") { + t.Fatal("expected jti-1 to not be blacklisted initially") + } + + bl.Add("jti-1", time.Now().Add(time.Hour)) + if !bl.IsBlacklisted("jti-1") { + t.Fatal("expected jti-1 to be blacklisted after Add") + } + + if bl.IsBlacklisted("jti-unknown") { + t.Fatal("expected jti-unknown to not be blacklisted") + } +} + +func TestTokenBlacklistMemory_ExpiredEntry(t *testing.T) { + bl := NewTokenBlacklistModule("test-bl", "memory", "", time.Minute) + + // Add a JTI that has already expired. + bl.Add("jti-expired", time.Now().Add(-time.Second)) + if bl.IsBlacklisted("jti-expired") { + t.Fatal("expected already-expired JTI to not be blacklisted") + } +} + +func TestTokenBlacklistMemory_Cleanup(t *testing.T) { + bl := NewTokenBlacklistModule("test-bl", "memory", "", 50*time.Millisecond) + + app := NewMockApplication() + if err := bl.Init(app); err != nil { + t.Fatalf("Init: %v", err) + } + if err := bl.Start(context.Background()); err != nil { + t.Fatalf("Start: %v", err) + } + defer func() { _ = bl.Stop(context.Background()) }() + + // Add a JTI that expires in 10ms. + bl.Add("jti-cleanup", time.Now().Add(10*time.Millisecond)) + + // Give it time to expire and be cleaned up by the cleanup goroutine. + time.Sleep(300 * time.Millisecond) + + if bl.IsBlacklisted("jti-cleanup") { + t.Fatal("expected cleaned-up JTI to no longer be blacklisted") + } +} + +func TestTokenBlacklistModule_StopIdempotent(t *testing.T) { + bl := NewTokenBlacklistModule("test-bl", "memory", "", time.Minute) + app := NewMockApplication() + if err := bl.Init(app); err != nil { + t.Fatalf("Init: %v", err) + } + if err := bl.Start(context.Background()); err != nil { + t.Fatalf("Start: %v", err) + } + // Calling Stop twice must not panic. + if err := bl.Stop(context.Background()); err != nil { + t.Fatalf("first Stop: %v", err) + } + if err := bl.Stop(context.Background()); err != nil { + t.Fatalf("second Stop: %v", err) + } +} + +func TestTokenBlacklistModule_ProvidesServices(t *testing.T) { + bl := NewTokenBlacklistModule("my-bl", "memory", "", time.Minute) + svcs := bl.ProvidesServices() + if len(svcs) != 1 { + t.Fatalf("expected 1 service, got %d", len(svcs)) + } + if svcs[0].Name != "my-bl" { + t.Fatalf("expected service name 'my-bl', got %q", svcs[0].Name) + } + if _, ok := svcs[0].Instance.(TokenBlacklist); !ok { + t.Fatal("expected service instance to implement TokenBlacklist") + } +} diff --git a/module/cache_redis.go b/module/cache_redis.go index 9c754770..8d750d7e 100644 --- a/module/cache_redis.go +++ b/module/cache_redis.go @@ -7,6 +7,7 @@ import ( "time" "github.com/CrisisTextLine/modular" + "github.com/GoCodeAlone/workflow/pkg/tlsutil" "github.com/redis/go-redis/v9" ) @@ -30,10 +31,11 @@ type RedisClient interface { // RedisCacheConfig holds configuration for the cache.redis module. type RedisCacheConfig struct { Address string - Password string //nolint:gosec // G117: config struct field, not a hardcoded secret + Password string //nolint:gosec // G117: config struct field, not a hardcoded secret DB int Prefix string DefaultTTL time.Duration + TLS tlsutil.TLSConfig `yaml:"tls" json:"tls"` } // RedisCache is a module that connects to a Redis instance and exposes @@ -87,6 +89,14 @@ func (r *RedisCache) Start(ctx context.Context) error { opts.Password = r.cfg.Password } + if r.cfg.TLS.Enabled { + tlsCfg, err := tlsutil.LoadTLSConfig(r.cfg.TLS) + if err != nil { + return fmt.Errorf("cache.redis %q: TLS config: %w", r.name, err) + } + opts.TLSConfig = tlsCfg + } + r.client = redis.NewClient(opts) if err := r.client.Ping(ctx).Err(); err != nil { diff --git a/module/database.go b/module/database.go index 8b30917e..e07b1119 100644 --- a/module/database.go +++ b/module/database.go @@ -25,14 +25,22 @@ func validateIdentifier(name string) error { return nil } +// DatabaseTLSConfig holds TLS settings for database connections. +type DatabaseTLSConfig struct { + // Mode controls SSL behaviour: disable | require | verify-ca | verify-full (PostgreSQL naming). + Mode string `json:"mode" yaml:"mode"` + CAFile string `json:"ca_file" yaml:"ca_file"` +} + // DatabaseConfig holds configuration for the workflow database module type DatabaseConfig struct { - Driver string `json:"driver" yaml:"driver"` - DSN string `json:"dsn" yaml:"dsn"` - MaxOpenConns int `json:"maxOpenConns" yaml:"maxOpenConns"` - MaxIdleConns int `json:"maxIdleConns" yaml:"maxIdleConns"` - ConnMaxLifetime time.Duration `json:"connMaxLifetime" yaml:"connMaxLifetime"` - MigrationsDir string `json:"migrationsDir" yaml:"migrationsDir"` + Driver string `json:"driver" yaml:"driver"` + DSN string `json:"dsn" yaml:"dsn"` + MaxOpenConns int `json:"maxOpenConns" yaml:"maxOpenConns"` + MaxIdleConns int `json:"maxIdleConns" yaml:"maxIdleConns"` + ConnMaxLifetime time.Duration `json:"connMaxLifetime" yaml:"connMaxLifetime"` + MigrationsDir string `json:"migrationsDir" yaml:"migrationsDir"` + TLS DatabaseTLSConfig `json:"tls" yaml:"tls"` } // QueryResult represents the result of a query @@ -85,6 +93,28 @@ func (w *WorkflowDatabase) RequiresServices() []modular.ServiceDependency { return nil } +// buildDSN returns the DSN with TLS parameters appended for supported drivers. +func (w *WorkflowDatabase) buildDSN() string { + dsn := w.config.DSN + mode := w.config.TLS.Mode + if mode == "" || mode == "disable" { + return dsn + } + + switch w.config.Driver { + case "postgres", "pgx", "pgx/v5": + sep := "?" + if strings.ContainsRune(dsn, '?') { + sep = "&" + } + dsn += sep + "sslmode=" + mode + if w.config.TLS.CAFile != "" { + dsn += "&sslrootcert=" + w.config.TLS.CAFile + } + } + return dsn +} + // Open opens the database connection using config func (w *WorkflowDatabase) Open() (*sql.DB, error) { w.mu.Lock() @@ -94,7 +124,7 @@ func (w *WorkflowDatabase) Open() (*sql.DB, error) { return w.db, nil } - db, err := sql.Open(w.config.Driver, w.config.DSN) + db, err := sql.Open(w.config.Driver, w.buildDSN()) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } diff --git a/module/field_protection.go b/module/field_protection.go new file mode 100644 index 00000000..8da17141 --- /dev/null +++ b/module/field_protection.go @@ -0,0 +1,162 @@ +package module + +import ( + "context" + "log" + "os" + + "github.com/CrisisTextLine/modular" + "github.com/GoCodeAlone/workflow/pkg/fieldcrypt" +) + +// ProtectedFieldManager bundles the registry and key ring for field protection. +type ProtectedFieldManager struct { + Registry *fieldcrypt.Registry + KeyRing fieldcrypt.KeyRing + TenantIsolation bool + ScanDepth int + ScanArrays bool + defaultTenantID string + // rawMasterKey is retained for legacy enc:: decryption (version 0). + rawMasterKey []byte +} + +// resolveTenant returns the effective tenant ID, applying isolation policy. +func (m *ProtectedFieldManager) resolveTenant(tenantID string) string { + if !m.TenantIsolation { + return m.defaultTenantID + } + if tenantID == "" { + return m.defaultTenantID + } + return tenantID +} + +// EncryptMap encrypts protected fields in the data map in-place. +func (m *ProtectedFieldManager) EncryptMap(ctx context.Context, tenantID string, data map[string]any) error { + tid := m.resolveTenant(tenantID) + return fieldcrypt.ScanAndEncrypt(data, m.Registry, func() ([]byte, int, error) { + return m.KeyRing.CurrentKey(ctx, tid) + }, m.ScanDepth) +} + +// DecryptMap decrypts protected fields in the data map in-place. +// Version 0 is the legacy enc:: format and uses the raw master key. +func (m *ProtectedFieldManager) DecryptMap(ctx context.Context, tenantID string, data map[string]any) error { + tid := m.resolveTenant(tenantID) + return fieldcrypt.ScanAndDecrypt(data, m.Registry, func(version int) ([]byte, error) { + if version == 0 { + // Legacy enc:: values were encrypted with sha256(masterKey). + // encrypt.go's decryptLegacy calls keyFn(0) expecting the raw key + // bytes, then SHA256-hashes them before use. + return m.rawMasterKey, nil + } + return m.KeyRing.KeyByVersion(ctx, tid, version) + }, m.ScanDepth) +} + +// MaskMap returns a deep copy of data with protected fields masked for logging. +func (m *ProtectedFieldManager) MaskMap(data map[string]any) map[string]any { + return fieldcrypt.ScanAndMask(data, m.Registry, m.ScanDepth) +} + +// FieldProtectionModule implements modular.Module for security.field-protection. +type FieldProtectionModule struct { + name string + manager *ProtectedFieldManager +} + +// NewFieldProtectionModule parses config and creates a FieldProtectionModule. +func NewFieldProtectionModule(name string, cfg map[string]any) (*FieldProtectionModule, error) { + // Parse protected fields. + var fields []fieldcrypt.ProtectedField + if raw, ok := cfg["protected_fields"].([]any); ok { + for _, item := range raw { + if m, ok := item.(map[string]any); ok { + pf := fieldcrypt.ProtectedField{} + if v, ok := m["name"].(string); ok { + pf.Name = v + } + if v, ok := m["classification"].(string); ok { + pf.Classification = fieldcrypt.FieldClassification(v) + } + if v, ok := m["encryption"].(bool); ok { + pf.Encryption = v + } + if v, ok := m["log_behavior"].(string); ok { + pf.LogBehavior = fieldcrypt.LogBehavior(v) + } + if v, ok := m["mask_pattern"].(string); ok { + pf.MaskPattern = v + } + fields = append(fields, pf) + } + } + } + + // Resolve master key. + masterKeyStr, _ := cfg["master_key"].(string) + if masterKeyStr == "" { + masterKeyStr = os.Getenv("FIELD_ENCRYPTION_KEY") + } + + var masterKey []byte + if masterKeyStr != "" { + masterKey = []byte(masterKeyStr) + } else { + log.Println("WARNING: field-protection module using zero key — set master_key or FIELD_ENCRYPTION_KEY") + masterKey = make([]byte, 32) + } + + scanDepth := 10 + if v, ok := cfg["scan_depth"].(int); ok && v > 0 { + scanDepth = v + } + + scanArrays := true + if v, ok := cfg["scan_arrays"].(bool); ok { + scanArrays = v + } + + tenantIsolation := false + if v, ok := cfg["tenant_isolation"].(bool); ok { + tenantIsolation = v + } + + defaultTenantID := "default" + + mgr := &ProtectedFieldManager{ + Registry: fieldcrypt.NewRegistry(fields), + KeyRing: fieldcrypt.NewLocalKeyRing(masterKey), + TenantIsolation: tenantIsolation, + ScanDepth: scanDepth, + ScanArrays: scanArrays, + defaultTenantID: defaultTenantID, + rawMasterKey: masterKey, + } + + return &FieldProtectionModule{name: name, manager: mgr}, nil +} + +// Name returns the module name. +func (m *FieldProtectionModule) Name() string { return m.name } + +// Init initializes the module. +func (m *FieldProtectionModule) Init(_ modular.Application) error { return nil } + +// ProvidesServices returns the services provided by this module. +func (m *FieldProtectionModule) ProvidesServices() []modular.ServiceProvider { + return []modular.ServiceProvider{ + {Name: m.name, Description: "Protected field manager for encryption/masking", Instance: m.manager}, + } +} + +// RequiresServices returns service dependencies (none). +func (m *FieldProtectionModule) RequiresServices() []modular.ServiceDependency { + return nil +} + +// Manager returns the ProtectedFieldManager. +func (m *FieldProtectionModule) Manager() *ProtectedFieldManager { + return m.manager +} diff --git a/module/field_protection_test.go b/module/field_protection_test.go new file mode 100644 index 00000000..599015ff --- /dev/null +++ b/module/field_protection_test.go @@ -0,0 +1,192 @@ +package module + +import ( + "context" + "testing" + + "github.com/GoCodeAlone/workflow/pkg/fieldcrypt" +) + +func TestFieldProtectionModuleConfigParsing(t *testing.T) { + cfg := map[string]any{ + "master_key": "test-master-key-that-is-long-en", + "tenant_isolation": true, + "scan_depth": 5, + "scan_arrays": false, + "protected_fields": []any{ + map[string]any{ + "name": "ssn", + "classification": "pii", + "encryption": true, + "log_behavior": "redact", + }, + map[string]any{ + "name": "email", + "classification": "pii", + "encryption": true, + "log_behavior": "mask", + }, + }, + } + + mod, err := NewFieldProtectionModule("test-fp", cfg) + if err != nil { + t.Fatalf("NewFieldProtectionModule: %v", err) + } + if mod.Name() != "test-fp" { + t.Errorf("name = %q", mod.Name()) + } + + mgr := mod.Manager() + if mgr == nil { + t.Fatal("manager is nil") + } + if !mgr.TenantIsolation { + t.Error("tenant_isolation should be true") + } + if mgr.ScanDepth != 5 { + t.Errorf("scan_depth = %d", mgr.ScanDepth) + } + if mgr.ScanArrays { + t.Error("scan_arrays should be false") + } + if !mgr.Registry.IsProtected("ssn") { + t.Error("ssn should be protected") + } + if !mgr.Registry.IsProtected("email") { + t.Error("email should be protected") + } + if mgr.Registry.IsProtected("name") { + t.Error("name should not be protected") + } +} + +func TestFieldProtectionEncryptDecryptRoundTrip(t *testing.T) { + cfg := map[string]any{ + "master_key": "test-key-for-encrypt-decrypt-rt!", + "protected_fields": []any{ + map[string]any{ + "name": "ssn", + "encryption": true, + }, + map[string]any{ + "name": "email", + "encryption": true, + }, + }, + } + + mod, err := NewFieldProtectionModule("fp", cfg) + if err != nil { + t.Fatal(err) + } + mgr := mod.Manager() + ctx := context.Background() + + data := map[string]any{ + "ssn": "123-45-6789", + "email": "user@example.com", + "name": "John Doe", + } + + if err := mgr.EncryptMap(ctx, "tenant1", data); err != nil { + t.Fatalf("EncryptMap: %v", err) + } + + if !fieldcrypt.IsEncrypted(data["ssn"].(string)) { + t.Error("ssn should be encrypted") + } + if !fieldcrypt.IsEncrypted(data["email"].(string)) { + t.Error("email should be encrypted") + } + if data["name"] != "John Doe" { + t.Error("name should not be modified") + } + + if err := mgr.DecryptMap(ctx, "tenant1", data); err != nil { + t.Fatalf("DecryptMap: %v", err) + } + + if data["ssn"] != "123-45-6789" { + t.Errorf("ssn = %q", data["ssn"]) + } + if data["email"] != "user@example.com" { + t.Errorf("email = %q", data["email"]) + } +} + +func TestFieldProtectionMaskMap(t *testing.T) { + cfg := map[string]any{ + "master_key": "mask-test-key-32-chars-exactly!!", + "protected_fields": []any{ + map[string]any{ + "name": "ssn", + "log_behavior": "redact", + }, + map[string]any{ + "name": "email", + "log_behavior": "mask", + }, + map[string]any{ + "name": "phone", + "log_behavior": "hash", + }, + }, + } + + mod, err := NewFieldProtectionModule("fp", cfg) + if err != nil { + t.Fatal(err) + } + mgr := mod.Manager() + + data := map[string]any{ + "ssn": "123-45-6789", + "email": "user@example.com", + "phone": "555-123-4567", + "name": "John Doe", + } + + masked := mgr.MaskMap(data) + + if masked["ssn"] != "[REDACTED]" { + t.Errorf("ssn mask = %q", masked["ssn"]) + } + if masked["email"] == "user@example.com" { + t.Error("email should be masked") + } + if masked["phone"] == "555-123-4567" { + t.Error("phone should be hashed") + } + if masked["name"] != "John Doe" { + t.Error("name should be unchanged") + } + + // Original data should be unmodified. + if data["ssn"] != "123-45-6789" { + t.Error("original ssn was modified") + } +} + +func TestFieldProtectionProvidesServices(t *testing.T) { + cfg := map[string]any{ + "master_key": "svc-test-key-32-chars-exactly!!", + "protected_fields": []any{}, + } + + mod, err := NewFieldProtectionModule("my-fp", cfg) + if err != nil { + t.Fatal(err) + } + + svcs := mod.ProvidesServices() + if len(svcs) != 1 { + t.Fatalf("expected 1 service, got %d", len(svcs)) + } + if svcs[0].Name != "my-fp" { + t.Errorf("service name = %q", svcs[0].Name) + } + if _, ok := svcs[0].Instance.(*ProtectedFieldManager); !ok { + t.Error("service instance should be *ProtectedFieldManager") + } +} diff --git a/module/http_server.go b/module/http_server.go index a93d79ee..f562becc 100644 --- a/module/http_server.go +++ b/module/http_server.go @@ -2,14 +2,26 @@ package module import ( "context" + "crypto/tls" "errors" "fmt" "net/http" "time" "github.com/CrisisTextLine/modular" + "github.com/GoCodeAlone/workflow/pkg/tlsutil" + "golang.org/x/crypto/acme/autocert" ) +// HTTPServerTLSConfig holds TLS configuration for the HTTP server. +type HTTPServerTLSConfig struct { + Mode string `yaml:"mode" json:"mode"` // manual | autocert | disabled + Manual tlsutil.TLSConfig `yaml:"manual" json:"manual"` + Autocert tlsutil.AutocertConfig `yaml:"autocert" json:"autocert"` + ClientCAFile string `yaml:"client_ca_file" json:"client_ca_file"` + ClientAuth string `yaml:"client_auth" json:"client_auth"` // require | request | none +} + // StandardHTTPServer implements the HTTPServer interface and modular.Module interfaces type StandardHTTPServer struct { name string @@ -20,6 +32,7 @@ type StandardHTTPServer struct { readTimeout time.Duration writeTimeout time.Duration idleTimeout time.Duration + tlsCfg HTTPServerTLSConfig } // NewStandardHTTPServer creates a new HTTP server with the given name and address @@ -38,6 +51,11 @@ func (s *StandardHTTPServer) SetTimeouts(read, write, idle time.Duration) { s.idleTimeout = idle } +// SetTLSConfig configures TLS for the HTTP server. +func (s *StandardHTTPServer) SetTLSConfig(cfg HTTPServerTLSConfig) { + s.tlsCfg = cfg +} + // Name returns the unique identifier for this module func (s *StandardHTTPServer) Name() string { return s.name @@ -87,14 +105,90 @@ func (s *StandardHTTPServer) Start(ctx context.Context) error { IdleTimeout: timeoutOrDefault(s.idleTimeout, 120*time.Second), } - // Start the server in a goroutine + switch s.tlsCfg.Mode { + case "autocert": + return s.startAutocert(ctx) + case "manual": + return s.startManualTLS(ctx) + default: + // Plain HTTP + go func() { + if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.logger.Error("HTTP server error", "error", err) + } + }() + s.logger.Info("HTTP server started", "address", s.address) + return nil + } +} + +// startManualTLS starts the server with manually configured TLS certificates. +func (s *StandardHTTPServer) startManualTLS(ctx context.Context) error { + manualCfg := s.tlsCfg.Manual + manualCfg.Enabled = true + + // Overlay mTLS settings from the top-level fields when set + if s.tlsCfg.ClientCAFile != "" { + manualCfg.CAFile = s.tlsCfg.ClientCAFile + } + if s.tlsCfg.ClientAuth != "" { + manualCfg.ClientAuth = s.tlsCfg.ClientAuth + } + + tlsConfig, err := tlsutil.LoadTLSConfig(manualCfg) + if err != nil { + return fmt.Errorf("http server TLS config: %w", err) + } + s.server.TLSConfig = tlsConfig + go func() { - if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - s.logger.Error("HTTP server error", "error", err) + if err := s.server.ListenAndServeTLS(manualCfg.CertFile, manualCfg.KeyFile); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.logger.Error("HTTPS server error", "error", err) } }() + s.logger.Info("HTTPS server started (manual TLS)", "address", s.address) + return nil +} - s.logger.Info("HTTP server started", "address", s.address) +// startAutocert starts the server using Let's Encrypt via autocert. +func (s *StandardHTTPServer) startAutocert(ctx context.Context) error { + ac := s.tlsCfg.Autocert + if len(ac.Domains) == 0 { + return fmt.Errorf("http server autocert: at least one domain is required") + } + + m := &autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(ac.Domains...), + Email: ac.Email, + } + if ac.CacheDir != "" { + m.Cache = autocert.DirCache(ac.CacheDir) + } + + s.server.TLSConfig = &tls.Config{ + GetCertificate: m.GetCertificate, + MinVersion: tls.VersionTLS12, + } + + // ACME HTTP-01 challenge listener on :80 + go func() { + httpSrv := &http.Server{ + Addr: ":80", + Handler: m.HTTPHandler(nil), + ReadHeaderTimeout: 10 * time.Second, + } + if err := httpSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.logger.Error("autocert HTTP-01 listener error", "error", err) + } + }() + + go func() { + if err := s.server.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.logger.Error("HTTPS server error (autocert)", "error", err) + } + }() + s.logger.Info("HTTPS server started (autocert)", "address", s.address, "domains", ac.Domains) return nil } diff --git a/module/jwt_auth.go b/module/jwt_auth.go index f047e0d1..d6762266 100644 --- a/module/jwt_auth.go +++ b/module/jwt_auth.go @@ -13,6 +13,7 @@ import ( "github.com/CrisisTextLine/modular" "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" "golang.org/x/crypto/bcrypt" ) @@ -44,6 +45,7 @@ type JWTAuthModule struct { persistence *PersistenceStore // optional write-through backend userStore *UserStore // optional external user store (from auth.user-store module) allowRegistration bool // when true, any visitor may self-register + tokenBlacklist TokenBlacklist // optional revocation check (wired by auth plugin) } // NewJWTAuthModule creates a new JWT auth module @@ -84,6 +86,12 @@ func (j *JWTAuthModule) SetAllowRegistration(allow bool) { j.allowRegistration = allow } +// SetTokenBlacklist wires a TokenBlacklist to this module so that revoked +// tokens are rejected during Authenticate. +func (j *JWTAuthModule) SetTokenBlacklist(bl TokenBlacklist) { + j.tokenBlacklist = bl +} + // Name returns the module name func (j *JWTAuthModule) Name() string { return j.name @@ -139,6 +147,15 @@ func (j *JWTAuthModule) Authenticate(tokenStr string) (bool, map[string]any, err return false, nil, nil } + // Check token revocation if a blacklist is wired. + if j.tokenBlacklist != nil { + if jti, ok := claims["jti"].(string); ok && jti != "" { + if j.tokenBlacklist.IsBlacklisted(jti) { + return false, nil, nil + } + } + } + result := make(map[string]any) maps.Copy(result, claims) return true, result, nil @@ -446,6 +463,7 @@ func (j *JWTAuthModule) generateToken(user *User) (string, error) { "email": user.Email, "name": user.Name, "iss": j.issuer, + "jti": uuid.NewString(), "iat": time.Now().Unix(), "exp": time.Now().Add(j.tokenExpiry).Unix(), } diff --git a/module/kafka_broker.go b/module/kafka_broker.go index e27177d0..b20d2c7a 100644 --- a/module/kafka_broker.go +++ b/module/kafka_broker.go @@ -7,8 +7,22 @@ import ( "github.com/CrisisTextLine/modular" "github.com/IBM/sarama" + "github.com/GoCodeAlone/workflow/pkg/tlsutil" ) +// KafkaSASLConfig holds SASL authentication configuration for Kafka. +type KafkaSASLConfig struct { + Mechanism string `yaml:"mechanism" json:"mechanism"` // PLAIN | SCRAM-SHA-256 | SCRAM-SHA-512 + Username string `yaml:"username" json:"username"` + Password string `yaml:"password" json:"password"` //nolint:gosec // G117: config struct field +} + +// KafkaTLSConfig holds TLS and SASL configuration for the Kafka broker. +type KafkaTLSConfig struct { + tlsutil.TLSConfig `yaml:",inline" json:",inline"` + SASL KafkaSASLConfig `yaml:"sasl" json:"sasl"` +} + // KafkaBroker implements the MessageBroker interface using Apache Kafka via Sarama. type KafkaBroker struct { name string @@ -25,6 +39,7 @@ type KafkaBroker struct { healthy bool healthMsg string encryptor *FieldEncryptor + tlsCfg KafkaTLSConfig } // NewKafkaBroker creates a new Kafka message broker. @@ -93,6 +108,13 @@ func (b *KafkaBroker) SetGroupID(groupID string) { b.groupID = groupID } +// SetTLSConfig sets the TLS and SASL configuration for the Kafka broker. +func (b *KafkaBroker) SetTLSConfig(cfg KafkaTLSConfig) { + b.mu.Lock() + defer b.mu.Unlock() + b.tlsCfg = cfg +} + // HealthStatus implements the HealthCheckable interface. func (b *KafkaBroker) HealthStatus() HealthCheckResult { b.mu.RLock() @@ -146,6 +168,38 @@ func (b *KafkaBroker) Start(ctx context.Context) error { config.Consumer.Group.Rebalance.GroupStrategies = []sarama.BalanceStrategy{sarama.NewBalanceStrategyRoundRobin()} config.Consumer.Offsets.Initial = sarama.OffsetNewest + // Apply TLS configuration + if b.tlsCfg.Enabled { + tlsCfg, tlsErr := tlsutil.LoadTLSConfig(b.tlsCfg.TLSConfig) + if tlsErr != nil { + return fmt.Errorf("kafka broker %q: TLS config: %w", b.name, tlsErr) + } + config.Net.TLS.Enable = true + config.Net.TLS.Config = tlsCfg + } + + // Apply SASL configuration + sasl := b.tlsCfg.SASL + if sasl.Username != "" { + config.Net.SASL.Enable = true + config.Net.SASL.User = sasl.Username + config.Net.SASL.Password = sasl.Password + switch sasl.Mechanism { + case "SCRAM-SHA-256": + config.Net.SASL.Mechanism = sarama.SASLTypeSCRAMSHA256 + config.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient { + return &xDGSCRAMClient{HashGeneratorFcn: SHA256} + } + case "SCRAM-SHA-512": + config.Net.SASL.Mechanism = sarama.SASLTypeSCRAMSHA512 + config.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient { + return &xDGSCRAMClient{HashGeneratorFcn: SHA512} + } + default: // PLAIN + config.Net.SASL.Mechanism = sarama.SASLTypePlaintext + } + } + // Create sync producer producer, err := sarama.NewSyncProducer(b.brokers, config) if err != nil { diff --git a/module/kafka_scram.go b/module/kafka_scram.go new file mode 100644 index 00000000..34c005fa --- /dev/null +++ b/module/kafka_scram.go @@ -0,0 +1,44 @@ +package module + +import ( + "crypto/sha256" + "crypto/sha512" + "hash" + + "github.com/xdg-go/scram" +) + +// SHA256 and SHA512 are hash generator functions used by the SCRAM client. +var ( + SHA256 scram.HashGeneratorFcn = sha256.New + SHA512 scram.HashGeneratorFcn = sha512.New +) + +// xDGSCRAMClient implements sarama.SCRAMClient using the xdg-go/scram package. +type xDGSCRAMClient struct { + *scram.Client + *scram.ClientConversation + scram.HashGeneratorFcn +} + +func (x *xDGSCRAMClient) Begin(userName, password, authzID string) error { + client, err := x.HashGeneratorFcn.NewClient(userName, password, authzID) + if err != nil { + return err + } + x.Client = client + x.ClientConversation = client.NewConversation() + return nil +} + +func (x *xDGSCRAMClient) Step(challenge string) (string, error) { + return x.ClientConversation.Step(challenge) +} + +func (x *xDGSCRAMClient) Done() bool { + return x.ClientConversation.Done() +} + +// ensure hash functions satisfy the interface (compile-time check) +var _ func() hash.Hash = sha256.New +var _ func() hash.Hash = sha512.New diff --git a/module/nats_broker.go b/module/nats_broker.go index 1ef24451..ddb7eaea 100644 --- a/module/nats_broker.go +++ b/module/nats_broker.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/CrisisTextLine/modular" + "github.com/GoCodeAlone/workflow/pkg/tlsutil" "github.com/nats-io/nats.go" ) @@ -20,6 +21,7 @@ type NATSBroker struct { producer *natsProducer consumer *natsConsumer logger modular.Logger + tlsCfg tlsutil.TLSConfig } // NewNATSBroker creates a new NATS message broker. @@ -80,6 +82,13 @@ func (b *NATSBroker) SetURL(url string) { b.url = url } +// SetTLSConfig configures TLS for the NATS broker connection. +func (b *NATSBroker) SetTLSConfig(cfg tlsutil.TLSConfig) { + b.mu.Lock() + defer b.mu.Unlock() + b.tlsCfg = cfg +} + // Producer returns the message producer interface. func (b *NATSBroker) Producer() MessageProducer { return b.producer @@ -100,7 +109,16 @@ func (b *NATSBroker) Start(ctx context.Context) error { b.mu.Lock() defer b.mu.Unlock() - conn, err := nats.Connect(b.url) + var opts []nats.Option + if b.tlsCfg.Enabled { + tlsCfg, tlsErr := tlsutil.LoadTLSConfig(b.tlsCfg) + if tlsErr != nil { + return fmt.Errorf("nats broker %q: TLS config: %w", b.name, tlsErr) + } + opts = append(opts, nats.Secure(tlsCfg)) + } + + conn, err := nats.Connect(b.url, opts...) if err != nil { return fmt.Errorf("failed to connect to NATS at %s: %w", b.url, err) } diff --git a/module/pipeline_step_field_reencrypt.go b/module/pipeline_step_field_reencrypt.go new file mode 100644 index 00000000..27d3173e --- /dev/null +++ b/module/pipeline_step_field_reencrypt.go @@ -0,0 +1,73 @@ +package module + +import ( + "context" + "fmt" + + "github.com/CrisisTextLine/modular" +) + +// FieldReencryptStep re-encrypts pipeline context data with the latest key version. +type FieldReencryptStep struct { + name string + module string // field-protection module name to look up + tenantID string // template expression for tenant ID + tmpl *TemplateEngine + app modular.Application +} + +// NewFieldReencryptStepFactory returns a StepFactory for step.field_reencrypt. +func NewFieldReencryptStepFactory() StepFactory { + return func(name string, config map[string]any, app modular.Application) (PipelineStep, error) { + moduleName, _ := config["module"].(string) + if moduleName == "" { + return nil, fmt.Errorf("field_reencrypt step %q: 'module' is required", name) + } + tenantID, _ := config["tenant_id"].(string) + + return &FieldReencryptStep{ + name: name, + module: moduleName, + tenantID: tenantID, + tmpl: NewTemplateEngine(), + app: app, + }, nil + } +} + +// Name returns the step name. +func (s *FieldReencryptStep) Name() string { return s.name } + +// Execute re-encrypts data by decrypting with the old key and encrypting with the current key. +func (s *FieldReencryptStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) { + // Look up the ProtectedFieldManager from the service registry. + var manager *ProtectedFieldManager + if err := s.app.GetService(s.module, &manager); err != nil { + return nil, fmt.Errorf("field_reencrypt step %q: module %q not found: %w", s.name, s.module, err) + } + + // Resolve tenant ID from template expression. + tenantID := s.tenantID + if tenantID != "" { + resolved, err := s.tmpl.Resolve(tenantID, pc) + if err == nil && resolved != "" { + tenantID = resolved + } + } + + // Decrypt with old key version, then re-encrypt with current key. + data := pc.Current + if data == nil { + return &StepResult{Output: map[string]any{"reencrypted": false, "reason": "no data"}}, nil + } + + if err := manager.DecryptMap(ctx, tenantID, data); err != nil { + return nil, fmt.Errorf("field_reencrypt step %q: decrypt: %w", s.name, err) + } + + if err := manager.EncryptMap(ctx, tenantID, data); err != nil { + return nil, fmt.Errorf("field_reencrypt step %q: encrypt: %w", s.name, err) + } + + return &StepResult{Output: map[string]any{"reencrypted": true}}, nil +} diff --git a/module/pipeline_step_sandbox_exec.go b/module/pipeline_step_sandbox_exec.go new file mode 100644 index 00000000..d63a4a88 --- /dev/null +++ b/module/pipeline_step_sandbox_exec.go @@ -0,0 +1,231 @@ +package module + +import ( + "context" + "fmt" + "time" + + "github.com/CrisisTextLine/modular" + "github.com/GoCodeAlone/workflow/sandbox" +) + +const defaultSandboxImage = "cgr.dev/chainguard/wolfi-base:latest" + +// SandboxExecStep runs a command in a hardened Docker sandbox container. +type SandboxExecStep struct { + name string + image string + command []string + securityProfile string + memoryLimit int64 + cpuLimit float64 + timeout time.Duration + network string + env map[string]string + mounts []sandbox.Mount + failOnError bool +} + +// NewSandboxExecStepFactory returns a StepFactory for step.sandbox_exec. +func NewSandboxExecStepFactory() StepFactory { + return func(name string, cfg map[string]any, _ modular.Application) (PipelineStep, error) { + step := &SandboxExecStep{ + name: name, + image: defaultSandboxImage, + securityProfile: "strict", + failOnError: true, + } + + if img, ok := cfg["image"].(string); ok && img != "" { + step.image = img + } + + // command + switch v := cfg["command"].(type) { + case []any: + for i, c := range v { + s, ok := c.(string) + if !ok { + return nil, fmt.Errorf("sandbox_exec step %q: command[%d] must be a string", name, i) + } + step.command = append(step.command, s) + } + case []string: + step.command = v + case nil: + // allowed — step may be used without a command for future use + default: + return nil, fmt.Errorf("sandbox_exec step %q: 'command' must be a list of strings", name) + } + + if profile, ok := cfg["security_profile"].(string); ok && profile != "" { + switch profile { + case "strict", "standard", "permissive": + step.securityProfile = profile + default: + return nil, fmt.Errorf("sandbox_exec step %q: security_profile must be strict, standard, or permissive", name) + } + } + + if ms, ok := cfg["memory_limit"].(string); ok && ms != "" { + limit, err := parseMemoryLimit(ms) + if err != nil { + return nil, fmt.Errorf("sandbox_exec step %q: invalid memory_limit: %w", name, err) + } + step.memoryLimit = limit + } + + if cpu, ok := cfg["cpu_limit"].(float64); ok { + step.cpuLimit = cpu + } + + if ts, ok := cfg["timeout"].(string); ok && ts != "" { + d, err := time.ParseDuration(ts) + if err != nil { + return nil, fmt.Errorf("sandbox_exec step %q: invalid timeout %q: %w", name, ts, err) + } + step.timeout = d + } + + if net, ok := cfg["network"].(string); ok && net != "" { + step.network = net + } + + if envRaw, ok := cfg["env"].(map[string]any); ok { + step.env = make(map[string]string, len(envRaw)) + for k, v := range envRaw { + step.env[k] = fmt.Sprintf("%v", v) + } + } + + if mountsRaw, ok := cfg["mounts"].([]any); ok { + for i, m := range mountsRaw { + mmap, ok := m.(map[string]any) + if !ok { + return nil, fmt.Errorf("sandbox_exec step %q: mounts[%d] must be a map", name, i) + } + src, _ := mmap["source"].(string) + tgt, _ := mmap["target"].(string) + ro, _ := mmap["read_only"].(bool) + step.mounts = append(step.mounts, sandbox.Mount{Source: src, Target: tgt, ReadOnly: ro}) + } + } + + if foe, ok := cfg["fail_on_error"].(bool); ok { + step.failOnError = foe + } + + return step, nil + } +} + +// Name returns the step name. +func (s *SandboxExecStep) Name() string { return s.name } + +// Execute runs the configured command in a Docker sandbox. +func (s *SandboxExecStep) Execute(ctx context.Context, _ *PipelineContext) (*StepResult, error) { + sbCfg := s.buildSandboxConfig() + + sb, err := sandbox.NewDockerSandbox(sbCfg) + if err != nil { + return nil, fmt.Errorf("sandbox_exec step %q: failed to create sandbox: %w", s.name, err) + } + defer sb.Close() + + result, err := sb.Exec(ctx, s.command) + if err != nil { + return nil, fmt.Errorf("sandbox_exec step %q: execution failed: %w", s.name, err) + } + + output := map[string]any{ + "exit_code": result.ExitCode, + "stdout": result.Stdout, + "stderr": result.Stderr, + } + + if result.ExitCode != 0 && s.failOnError { + return &StepResult{Output: output, Stop: true}, nil + } + + return &StepResult{Output: output}, nil +} + +// buildSandboxConfig constructs a SandboxConfig based on the security profile +// and any explicit overrides provided in the step config. +func (s *SandboxExecStep) buildSandboxConfig() sandbox.SandboxConfig { + var cfg sandbox.SandboxConfig + + switch s.securityProfile { + case "permissive": + cfg = sandbox.SandboxConfig{ + Image: s.image, + NetworkMode: "bridge", + } + case "standard": + cfg = sandbox.SandboxConfig{ + Image: s.image, + MemoryLimit: 256 * 1024 * 1024, + CPULimit: 0.5, + NetworkMode: "bridge", + CapDrop: []string{"NET_ADMIN", "SYS_ADMIN", "SYS_PTRACE", "SETUID", "SETGID"}, + CapAdd: []string{"NET_BIND_SERVICE"}, + NoNewPrivileges: true, + PidsLimit: 64, + Timeout: 5 * time.Minute, + } + default: // "strict" + cfg = sandbox.DefaultSecureSandboxConfig(s.image) + } + + // Apply explicit overrides + if s.memoryLimit > 0 { + cfg.MemoryLimit = s.memoryLimit + } + if s.cpuLimit > 0 { + cfg.CPULimit = s.cpuLimit + } + if s.timeout > 0 { + cfg.Timeout = s.timeout + } + if s.network != "" { + cfg.NetworkMode = s.network + } + if len(s.env) > 0 { + cfg.Env = s.env + } + if len(s.mounts) > 0 { + cfg.Mounts = s.mounts + } + + return cfg +} + +// parseMemoryLimit parses a human-readable memory string (e.g., "128m", "1g") into bytes. +func parseMemoryLimit(s string) (int64, error) { + if len(s) == 0 { + return 0, fmt.Errorf("empty memory limit") + } + last := s[len(s)-1] + var multiplier int64 = 1 + numStr := s + switch last { + case 'k', 'K': + multiplier = 1024 + numStr = s[:len(s)-1] + case 'm', 'M': + multiplier = 1024 * 1024 + numStr = s[:len(s)-1] + case 'g', 'G': + multiplier = 1024 * 1024 * 1024 + numStr = s[:len(s)-1] + case 'b', 'B': + numStr = s[:len(s)-1] + } + + var n int64 + _, err := fmt.Sscanf(numStr, "%d", &n) + if err != nil { + return 0, fmt.Errorf("invalid memory limit %q", s) + } + return n * multiplier, nil +} diff --git a/module/pipeline_step_sandbox_exec_test.go b/module/pipeline_step_sandbox_exec_test.go new file mode 100644 index 00000000..6a9f4e78 --- /dev/null +++ b/module/pipeline_step_sandbox_exec_test.go @@ -0,0 +1,302 @@ +package module + +import ( + "testing" + "time" +) + +func TestNewSandboxExecStepFactory_Defaults(t *testing.T) { + factory := NewSandboxExecStepFactory() + step, err := factory("test-step", map[string]any{ + "command": []any{"echo", "hello"}, + }, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + s := step.(*SandboxExecStep) + if s.name != "test-step" { + t.Fatalf("unexpected name: %s", s.name) + } + if s.image != defaultSandboxImage { + t.Fatalf("expected default image, got %s", s.image) + } + if s.securityProfile != "strict" { + t.Fatalf("expected strict profile, got %s", s.securityProfile) + } + if !s.failOnError { + t.Fatal("expected failOnError true by default") + } + if len(s.command) != 2 || s.command[0] != "echo" { + t.Fatalf("unexpected command: %v", s.command) + } +} + +func TestNewSandboxExecStepFactory_CustomImage(t *testing.T) { + factory := NewSandboxExecStepFactory() + step, err := factory("s", map[string]any{ + "image": "alpine:3.19", + "command": []any{"ls"}, + }, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + s := step.(*SandboxExecStep) + if s.image != "alpine:3.19" { + t.Fatalf("unexpected image: %s", s.image) + } +} + +func TestNewSandboxExecStepFactory_SecurityProfiles(t *testing.T) { + factory := NewSandboxExecStepFactory() + + for _, profile := range []string{"strict", "standard", "permissive"} { + step, err := factory("s", map[string]any{ + "security_profile": profile, + "command": []any{"ls"}, + }, nil) + if err != nil { + t.Fatalf("unexpected error for profile %q: %v", profile, err) + } + s := step.(*SandboxExecStep) + if s.securityProfile != profile { + t.Fatalf("expected profile %q, got %q", profile, s.securityProfile) + } + } +} + +func TestNewSandboxExecStepFactory_InvalidProfile(t *testing.T) { + factory := NewSandboxExecStepFactory() + _, err := factory("s", map[string]any{ + "security_profile": "unknown", + "command": []any{"ls"}, + }, nil) + if err == nil { + t.Fatal("expected error for invalid security_profile") + } +} + +func TestNewSandboxExecStepFactory_MemoryLimit(t *testing.T) { + factory := NewSandboxExecStepFactory() + tests := []struct { + input string + expected int64 + }{ + {"128m", 128 * 1024 * 1024}, + {"256M", 256 * 1024 * 1024}, + {"1g", 1024 * 1024 * 1024}, + {"512k", 512 * 1024}, + } + for _, tt := range tests { + step, err := factory("s", map[string]any{ + "command": []any{"ls"}, + "memory_limit": tt.input, + }, nil) + if err != nil { + t.Fatalf("input %q: unexpected error: %v", tt.input, err) + } + s := step.(*SandboxExecStep) + if s.memoryLimit != tt.expected { + t.Fatalf("input %q: expected %d, got %d", tt.input, tt.expected, s.memoryLimit) + } + } +} + +func TestNewSandboxExecStepFactory_Timeout(t *testing.T) { + factory := NewSandboxExecStepFactory() + step, err := factory("s", map[string]any{ + "command": []any{"ls"}, + "timeout": "30s", + }, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + s := step.(*SandboxExecStep) + if s.timeout != 30*time.Second { + t.Fatalf("expected 30s, got %s", s.timeout) + } +} + +func TestNewSandboxExecStepFactory_InvalidTimeout(t *testing.T) { + factory := NewSandboxExecStepFactory() + _, err := factory("s", map[string]any{ + "command": []any{"ls"}, + "timeout": "not-a-duration", + }, nil) + if err == nil { + t.Fatal("expected error for invalid timeout") + } +} + +func TestNewSandboxExecStepFactory_FailOnError(t *testing.T) { + factory := NewSandboxExecStepFactory() + step, err := factory("s", map[string]any{ + "command": []any{"ls"}, + "fail_on_error": false, + }, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + s := step.(*SandboxExecStep) + if s.failOnError { + t.Fatal("expected failOnError false") + } +} + +func TestNewSandboxExecStepFactory_EnvAndNetwork(t *testing.T) { + factory := NewSandboxExecStepFactory() + step, err := factory("s", map[string]any{ + "command": []any{"env"}, + "env": map[string]any{"FOO": "bar", "NUM": 42}, + "network": "bridge", + }, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + s := step.(*SandboxExecStep) + if s.env["FOO"] != "bar" { + t.Fatalf("unexpected FOO: %s", s.env["FOO"]) + } + if s.env["NUM"] != "42" { + t.Fatalf("unexpected NUM: %s", s.env["NUM"]) + } + if s.network != "bridge" { + t.Fatalf("unexpected network: %s", s.network) + } +} + +func TestSandboxExecStep_Name(t *testing.T) { + s := &SandboxExecStep{name: "my-step"} + if s.Name() != "my-step" { + t.Fatalf("unexpected name: %s", s.Name()) + } +} + +func TestSandboxExecStep_BuildSandboxConfig_Strict(t *testing.T) { + s := &SandboxExecStep{ + image: "alpine:3.19", + securityProfile: "strict", + } + cfg := s.buildSandboxConfig() + + if cfg.NetworkMode != "none" { + t.Fatalf("strict: expected network none, got %s", cfg.NetworkMode) + } + if len(cfg.CapDrop) != 1 || cfg.CapDrop[0] != "ALL" { + t.Fatalf("strict: expected CapDrop ALL, got %v", cfg.CapDrop) + } + if !cfg.NoNewPrivileges { + t.Fatal("strict: expected NoNewPrivileges true") + } + if !cfg.ReadOnlyRootfs { + t.Fatal("strict: expected ReadOnlyRootfs true") + } + if cfg.PidsLimit != 64 { + t.Fatalf("strict: expected PidsLimit 64, got %d", cfg.PidsLimit) + } +} + +func TestSandboxExecStep_BuildSandboxConfig_Standard(t *testing.T) { + s := &SandboxExecStep{ + image: "alpine:3.19", + securityProfile: "standard", + } + cfg := s.buildSandboxConfig() + + if cfg.NetworkMode != "bridge" { + t.Fatalf("standard: expected network bridge, got %s", cfg.NetworkMode) + } + if len(cfg.CapAdd) == 0 { + t.Fatal("standard: expected NET_BIND_SERVICE in CapAdd") + } + if !cfg.NoNewPrivileges { + t.Fatal("standard: expected NoNewPrivileges true") + } +} + +func TestSandboxExecStep_BuildSandboxConfig_Permissive(t *testing.T) { + s := &SandboxExecStep{ + image: "alpine:3.19", + securityProfile: "permissive", + } + cfg := s.buildSandboxConfig() + + if cfg.NetworkMode != "bridge" { + t.Fatalf("permissive: expected network bridge, got %s", cfg.NetworkMode) + } + if len(cfg.CapDrop) > 0 { + t.Fatalf("permissive: expected no CapDrop, got %v", cfg.CapDrop) + } + if cfg.ReadOnlyRootfs { + t.Fatal("permissive: expected ReadOnlyRootfs false") + } +} + +func TestSandboxExecStep_BuildSandboxConfig_Overrides(t *testing.T) { + s := &SandboxExecStep{ + image: "alpine:3.19", + securityProfile: "strict", + memoryLimit: 512 * 1024 * 1024, + cpuLimit: 2.0, + timeout: 10 * time.Second, + network: "bridge", + } + cfg := s.buildSandboxConfig() + + if cfg.MemoryLimit != 512*1024*1024 { + t.Fatalf("unexpected MemoryLimit: %d", cfg.MemoryLimit) + } + if cfg.CPULimit != 2.0 { + t.Fatalf("unexpected CPULimit: %f", cfg.CPULimit) + } + if cfg.Timeout != 10*time.Second { + t.Fatalf("unexpected Timeout: %s", cfg.Timeout) + } + if cfg.NetworkMode != "bridge" { + t.Fatalf("unexpected NetworkMode: %s", cfg.NetworkMode) + } +} + +func TestParseMemoryLimit(t *testing.T) { + tests := []struct { + input string + expected int64 + wantErr bool + }{ + {"128m", 128 * 1024 * 1024, false}, + {"256M", 256 * 1024 * 1024, false}, + {"1g", 1024 * 1024 * 1024, false}, + {"2G", 2 * 1024 * 1024 * 1024, false}, + {"512k", 512 * 1024, false}, + {"1024K", 1024 * 1024, false}, + {"1024", 1024, false}, + {"1024b", 1024, false}, + {"", 0, true}, + {"abc", 0, true}, + } + + for _, tt := range tests { + got, err := parseMemoryLimit(tt.input) + if tt.wantErr { + if err == nil { + t.Fatalf("input %q: expected error, got nil", tt.input) + } + continue + } + if err != nil { + t.Fatalf("input %q: unexpected error: %v", tt.input, err) + } + if got != tt.expected { + t.Fatalf("input %q: expected %d, got %d", tt.input, tt.expected, got) + } + } +} + +func TestNewSandboxExecStepFactory_InvalidCommandType(t *testing.T) { + factory := NewSandboxExecStepFactory() + _, err := factory("s", map[string]any{ + "command": "should-be-a-list", + }, nil) + if err == nil { + t.Fatal("expected error for non-list command") + } +} diff --git a/module/pipeline_step_secret_rotate.go b/module/pipeline_step_secret_rotate.go new file mode 100644 index 00000000..473c89cd --- /dev/null +++ b/module/pipeline_step_secret_rotate.go @@ -0,0 +1,83 @@ +package module + +import ( + "context" + "fmt" + + "github.com/CrisisTextLine/modular" + "github.com/GoCodeAlone/workflow/secrets" +) + +// SecretRotateStep rotates a secret in a RotationProvider and returns the new value. +type SecretRotateStep struct { + name string + provider string // service name of the secrets RotationProvider module + key string // the secret key to rotate + notifyModule string // optional module name to notify of rotation + app modular.Application +} + +// NewSecretRotateStepFactory returns a StepFactory for step.secret_rotate. +func NewSecretRotateStepFactory() StepFactory { + return func(name string, config map[string]any, app modular.Application) (PipelineStep, error) { + provider, _ := config["provider"].(string) + if provider == "" { + return nil, fmt.Errorf("secret_rotate step %q: 'provider' is required", name) + } + + key, _ := config["key"].(string) + if key == "" { + return nil, fmt.Errorf("secret_rotate step %q: 'key' is required", name) + } + + notifyModule, _ := config["notify_module"].(string) + + return &SecretRotateStep{ + name: name, + provider: provider, + key: key, + notifyModule: notifyModule, + app: app, + }, nil + } +} + +// Name returns the step name. +func (s *SecretRotateStep) Name() string { return s.name } + +// Execute rotates the secret by calling RotationProvider.Rotate and returns output +// indicating the rotation was successful. +func (s *SecretRotateStep) Execute(ctx context.Context, _ *PipelineContext) (*StepResult, error) { + if s.app == nil { + return nil, fmt.Errorf("secret_rotate step %q: no application context", s.name) + } + + rp, err := s.resolveProvider() + if err != nil { + return nil, err + } + + if _, err := rp.Rotate(ctx, s.key); err != nil { + return nil, fmt.Errorf("secret_rotate step %q: rotate failed: %w", s.name, err) + } + + return &StepResult{Output: map[string]any{ + "rotated": true, + "key": s.key, + "provider": s.provider, + }}, nil +} + +// resolveProvider looks up the secrets provider from the service registry and +// asserts it implements secrets.RotationProvider. +func (s *SecretRotateStep) resolveProvider() (secrets.RotationProvider, error) { + svc, ok := s.app.SvcRegistry()[s.provider] + if !ok { + return nil, fmt.Errorf("secret_rotate step %q: provider service %q not found", s.name, s.provider) + } + rp, ok := svc.(secrets.RotationProvider) + if !ok { + return nil, fmt.Errorf("secret_rotate step %q: service %q does not implement secrets.RotationProvider", s.name, s.provider) + } + return rp, nil +} diff --git a/module/pipeline_step_secret_rotate_test.go b/module/pipeline_step_secret_rotate_test.go new file mode 100644 index 00000000..11b7c63d --- /dev/null +++ b/module/pipeline_step_secret_rotate_test.go @@ -0,0 +1,184 @@ +package module + +import ( + "context" + "errors" + "testing" + + "github.com/GoCodeAlone/workflow/secrets" +) + +// mockRotationProvider is a mock secrets.RotationProvider for testing. +type mockRotationProvider struct { + rotateVal string + rotateErr error + prevVal string + prevErr error +} + +func (m *mockRotationProvider) Name() string { return "mock" } +func (m *mockRotationProvider) Get(_ context.Context, _ string) (string, error) { + return "", nil +} +func (m *mockRotationProvider) Set(_ context.Context, _, _ string) error { return nil } +func (m *mockRotationProvider) Delete(_ context.Context, _ string) error { return nil } +func (m *mockRotationProvider) List(_ context.Context) ([]string, error) { return nil, nil } +func (m *mockRotationProvider) Rotate(_ context.Context, _ string) (string, error) { + return m.rotateVal, m.rotateErr +} +func (m *mockRotationProvider) GetPrevious(_ context.Context, _ string) (string, error) { + return m.prevVal, m.prevErr +} + +// Compile-time check that mockRotationProvider satisfies secrets.RotationProvider. +var _ secrets.RotationProvider = (*mockRotationProvider)(nil) + +// ---- factory validation tests ---- + +func TestSecretRotateStep_MissingProvider(t *testing.T) { + factory := NewSecretRotateStepFactory() + _, err := factory("rotate-step", map[string]any{ + "key": "myapp/db-pass", + }, nil) + if err == nil { + t.Fatal("expected error when 'provider' is missing") + } +} + +func TestSecretRotateStep_MissingKey(t *testing.T) { + factory := NewSecretRotateStepFactory() + _, err := factory("rotate-step", map[string]any{ + "provider": "vault", + }, nil) + if err == nil { + t.Fatal("expected error when 'key' is missing") + } +} + +func TestSecretRotateStep_ValidConfig(t *testing.T) { + factory := NewSecretRotateStepFactory() + step, err := factory("rotate-step", map[string]any{ + "provider": "vault", + "key": "myapp/db-pass", + "notify_module": "slack-notifier", + }, nil) + if err != nil { + t.Fatalf("unexpected factory error: %v", err) + } + if step.Name() != "rotate-step" { + t.Errorf("expected name 'rotate-step', got %q", step.Name()) + } +} + +// ---- Execute tests ---- + +func TestSecretRotateStep_Execute_Success(t *testing.T) { + mock := &mockRotationProvider{rotateVal: "new-secret-abc123"} + app := NewMockApplication() + app.Services["vault"] = mock + + factory := NewSecretRotateStepFactory() + step, err := factory("rotate-step", map[string]any{ + "provider": "vault", + "key": "myapp/db-pass", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("Execute error: %v", err) + } + + if result.Output["rotated"] != true { + t.Errorf("expected rotated=true, got %v", result.Output["rotated"]) + } + if result.Output["key"] != "myapp/db-pass" { + t.Errorf("expected key='myapp/db-pass', got %v", result.Output["key"]) + } + if result.Output["provider"] != "vault" { + t.Errorf("expected provider='vault', got %v", result.Output["provider"]) + } +} + +func TestSecretRotateStep_Execute_RotateError(t *testing.T) { + mock := &mockRotationProvider{rotateErr: errors.New("vault unavailable")} + app := NewMockApplication() + app.Services["vault"] = mock + + factory := NewSecretRotateStepFactory() + step, err := factory("rotate-step", map[string]any{ + "provider": "vault", + "key": "myapp/db-pass", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error from Rotate failure") + } +} + +func TestSecretRotateStep_Execute_ProviderNotFound(t *testing.T) { + app := NewMockApplication() + // No services registered. + + factory := NewSecretRotateStepFactory() + step, err := factory("rotate-step", map[string]any{ + "provider": "vault", + "key": "myapp/db-pass", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for missing provider service") + } +} + +func TestSecretRotateStep_Execute_ProviderWrongType(t *testing.T) { + app := NewMockApplication() + // Register something that doesn't implement RotationProvider. + app.Services["vault"] = "not-a-rotation-provider" + + factory := NewSecretRotateStepFactory() + step, err := factory("rotate-step", map[string]any{ + "provider": "vault", + "key": "myapp/db-pass", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error when service does not implement RotationProvider") + } +} + +func TestSecretRotateStep_Execute_NoApp(t *testing.T) { + factory := NewSecretRotateStepFactory() + // Pass nil app at factory time is allowed; error comes at Execute time. + step, err := factory("rotate-step", map[string]any{ + "provider": "vault", + "key": "myapp/db-pass", + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error when app is nil") + } +} diff --git a/module/pipeline_step_token_revoke.go b/module/pipeline_step_token_revoke.go new file mode 100644 index 00000000..c3737943 --- /dev/null +++ b/module/pipeline_step_token_revoke.go @@ -0,0 +1,95 @@ +package module + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/CrisisTextLine/modular" + "github.com/golang-jwt/jwt/v5" +) + +// TokenRevokeStep extracts a JWT from the pipeline context, reads its jti and +// exp claims without signature validation, and adds the JTI to the configured +// token blacklist module. +type TokenRevokeStep struct { + name string + blacklistModule string // service name of the TokenBlacklist module + tokenSource string // dot-path to the token in pipeline context + app modular.Application +} + +// NewTokenRevokeStepFactory returns a StepFactory for step.token_revoke. +func NewTokenRevokeStepFactory() StepFactory { + return func(name string, config map[string]any, app modular.Application) (PipelineStep, error) { + blacklistModule, _ := config["blacklist_module"].(string) + if blacklistModule == "" { + return nil, fmt.Errorf("token_revoke step %q: 'blacklist_module' is required", name) + } + + tokenSource, _ := config["token_source"].(string) + if tokenSource == "" { + return nil, fmt.Errorf("token_revoke step %q: 'token_source' is required", name) + } + + return &TokenRevokeStep{ + name: name, + blacklistModule: blacklistModule, + tokenSource: tokenSource, + app: app, + }, nil + } +} + +// Name returns the step name. +func (s *TokenRevokeStep) Name() string { return s.name } + +// Execute revokes the JWT by extracting its JTI and adding it to the blacklist. +func (s *TokenRevokeStep) Execute(_ context.Context, pc *PipelineContext) (*StepResult, error) { + // 1. Extract token string from pipeline context. + rawToken := resolveBodyFrom(s.tokenSource, pc) + tokenStr, _ := rawToken.(string) + if tokenStr == "" { + return &StepResult{Output: map[string]any{"revoked": false, "error": "missing token"}}, nil + } + + // 2. Strip "Bearer " prefix if present. + tokenStr = strings.TrimPrefix(tokenStr, "Bearer ") + + // 3. Parse claims without signature validation (token is being revoked, not authenticated). + parser := jwt.NewParser() + token, _, err := parser.ParseUnverified(tokenStr, jwt.MapClaims{}) + if err != nil { + return &StepResult{Output: map[string]any{"revoked": false, "error": "invalid token format"}}, nil + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return &StepResult{Output: map[string]any{"revoked": false, "error": "invalid claims"}}, nil + } + + jti, _ := claims["jti"].(string) + if jti == "" { + return &StepResult{Output: map[string]any{"revoked": false, "error": "token has no jti claim"}}, nil + } + + // 4. Determine token expiry from exp claim. + var expiresAt time.Time + switch exp := claims["exp"].(type) { + case float64: + expiresAt = time.Unix(int64(exp), 0) + default: + expiresAt = time.Now().Add(24 * time.Hour) // safe fallback + } + + // 5. Resolve the blacklist module and add the JTI. + var blacklist TokenBlacklist + if err := s.app.GetService(s.blacklistModule, &blacklist); err != nil { + return nil, fmt.Errorf("token_revoke step %q: blacklist module %q not found: %w", s.name, s.blacklistModule, err) + } + + blacklist.Add(jti, expiresAt) + + return &StepResult{Output: map[string]any{"revoked": true, "jti": jti}}, nil +} diff --git a/module/pipeline_step_token_revoke_test.go b/module/pipeline_step_token_revoke_test.go new file mode 100644 index 00000000..74650dd5 --- /dev/null +++ b/module/pipeline_step_token_revoke_test.go @@ -0,0 +1,309 @@ +package module + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// makeTestJWT creates a signed JWT with the given claims using HS256. +func makeTestJWT(t *testing.T, secret string, claims jwt.MapClaims) string { + t.Helper() + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + s, err := tok.SignedString([]byte(secret)) + if err != nil { + t.Fatalf("makeTestJWT: %v", err) + } + return s +} + +// mockBlacklist is a simple in-memory TokenBlacklist for testing. +type mockBlacklist struct { + entries map[string]time.Time +} + +func newMockBlacklist() *mockBlacklist { + return &mockBlacklist{entries: make(map[string]time.Time)} +} + +func (m *mockBlacklist) Add(jti string, expiresAt time.Time) { + m.entries[jti] = expiresAt +} + +func (m *mockBlacklist) IsBlacklisted(jti string) bool { + exp, ok := m.entries[jti] + return ok && time.Now().Before(exp) +} + +func newTokenRevokeApp(blName string, bl TokenBlacklist) *MockApplication { + app := NewMockApplication() + app.Services[blName] = bl + return app +} + +func TestTokenRevokeStep_RevokesToken(t *testing.T) { + bl := newMockBlacklist() + app := newTokenRevokeApp("my-blacklist", bl) + + factory := NewTokenRevokeStepFactory() + step, err := factory("revoke", map[string]any{ + "blacklist_module": "my-blacklist", + "token_source": "steps.parse.authorization", + }, app) + if err != nil { + t.Fatalf("factory: %v", err) + } + + tokenStr := makeTestJWT(t, "aaaabbbbccccddddeeeeffffgggghhhh", jwt.MapClaims{ + "jti": "test-jti-123", + "sub": "user-1", + "exp": float64(time.Now().Add(time.Hour).Unix()), + }) + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("parse", map[string]any{ + "authorization": "Bearer " + tokenStr, + }) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Output["revoked"] != true { + t.Errorf("expected revoked=true, got %v", result.Output["revoked"]) + } + if result.Output["jti"] != "test-jti-123" { + t.Errorf("expected jti=test-jti-123, got %v", result.Output["jti"]) + } + if !bl.IsBlacklisted("test-jti-123") { + t.Error("expected test-jti-123 to be in blacklist after revoke") + } +} + +func TestTokenRevokeStep_NoBearerPrefix(t *testing.T) { + bl := newMockBlacklist() + app := newTokenRevokeApp("bl", bl) + + factory := NewTokenRevokeStepFactory() + step, err := factory("revoke", map[string]any{ + "blacklist_module": "bl", + "token_source": "steps.parse.token", + }, app) + if err != nil { + t.Fatalf("factory: %v", err) + } + + // Token without Bearer prefix. + tokenStr := makeTestJWT(t, "aaaabbbbccccddddeeeeffffgggghhhh", jwt.MapClaims{ + "jti": "bare-jti", + "exp": float64(time.Now().Add(time.Hour).Unix()), + }) + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("parse", map[string]any{"token": tokenStr}) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Output["revoked"] != true { + t.Errorf("expected revoked=true, got %v", result.Output["revoked"]) + } + if !bl.IsBlacklisted("bare-jti") { + t.Error("expected bare-jti to be blacklisted") + } +} + +func TestTokenRevokeStep_MissingToken(t *testing.T) { + factory := NewTokenRevokeStepFactory() + app := newTokenRevokeApp("bl", newMockBlacklist()) + step, err := factory("revoke", map[string]any{ + "blacklist_module": "bl", + "token_source": "steps.parse.token", + }, app) + if err != nil { + t.Fatalf("factory: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Output["revoked"] != false { + t.Errorf("expected revoked=false for missing token, got %v", result.Output["revoked"]) + } +} + +func TestTokenRevokeStep_NoJTIClaim(t *testing.T) { + bl := newMockBlacklist() + app := newTokenRevokeApp("bl", bl) + + factory := NewTokenRevokeStepFactory() + step, err := factory("revoke", map[string]any{ + "blacklist_module": "bl", + "token_source": "steps.parse.token", + }, app) + if err != nil { + t.Fatalf("factory: %v", err) + } + + // Token without jti claim. + tokenStr := makeTestJWT(t, "aaaabbbbccccddddeeeeffffgggghhhh", jwt.MapClaims{ + "sub": "user-1", + "exp": float64(time.Now().Add(time.Hour).Unix()), + }) + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("parse", map[string]any{"token": tokenStr}) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Output["revoked"] != false { + t.Errorf("expected revoked=false for token without jti, got %v", result.Output["revoked"]) + } + if result.Output["error"] != "token has no jti claim" { + t.Errorf("expected 'token has no jti claim' error, got %v", result.Output["error"]) + } +} + +func TestTokenRevokeStep_FactoryMissingBlacklistModule(t *testing.T) { + factory := NewTokenRevokeStepFactory() + _, err := factory("revoke", map[string]any{"token_source": "token"}, nil) + if err == nil { + t.Fatal("expected error for missing blacklist_module") + } +} + +func TestTokenRevokeStep_FactoryMissingTokenSource(t *testing.T) { + factory := NewTokenRevokeStepFactory() + _, err := factory("revoke", map[string]any{"blacklist_module": "bl"}, nil) + if err == nil { + t.Fatal("expected error for missing token_source") + } +} + +func TestTokenRevokeStep_BlacklistModuleNotFound(t *testing.T) { + app := NewMockApplication() // no services registered + factory := NewTokenRevokeStepFactory() + step, err := factory("revoke", map[string]any{ + "blacklist_module": "missing-bl", + "token_source": "steps.parse.token", + }, app) + if err != nil { + t.Fatalf("factory: %v", err) + } + + tokenStr := makeTestJWT(t, "aaaabbbbccccddddeeeeffffgggghhhh", jwt.MapClaims{ + "jti": "some-jti", + "exp": float64(time.Now().Add(time.Hour).Unix()), + }) + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("parse", map[string]any{"token": tokenStr}) + + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error when blacklist module is not registered") + } +} + +// --- Integration test --- + +// TestBlacklistedTokenFailsAuth tests the full flow: issue token, revoke it, +// then verify that JWTAuthModule rejects it. +func TestBlacklistedTokenFailsAuth(t *testing.T) { + secret := "a-very-long-secret-key-for-testing-purposes-1234" + jwtMod := NewJWTAuthModule("auth", secret, time.Hour, "test") + + app := NewMockApplication() + if err := jwtMod.Init(app); err != nil { + t.Fatalf("JWTAuthModule.Init: %v", err) + } + + // Issue a token. + user := &User{ID: "1", Email: "test@example.com", Name: "Test"} + tokenStr, err := jwtMod.generateToken(user) + if err != nil { + t.Fatalf("generateToken: %v", err) + } + + // Token should be valid before revocation. + valid, claims, err := jwtMod.Authenticate(tokenStr) + if err != nil { + t.Fatalf("Authenticate: %v", err) + } + if !valid { + t.Fatal("expected token to be valid before revocation") + } + + jti, ok := claims["jti"].(string) + if !ok || jti == "" { + t.Fatal("expected jti claim in token") + } + + // Wire a blacklist and revoke the token. + bl := NewTokenBlacklistModule("bl", "memory", "", time.Minute) + jwtMod.SetTokenBlacklist(bl) + bl.Add(jti, time.Now().Add(time.Hour)) + + // Token should now be rejected. + valid, _, err = jwtMod.Authenticate(tokenStr) + if err != nil { + t.Fatalf("Authenticate after revocation: %v", err) + } + if valid { + t.Fatal("expected token to be rejected after revocation") + } +} + +// TestNonRevokedTokenStillValid ensures that revoking one token does not +// affect other valid tokens. +func TestNonRevokedTokenStillValid(t *testing.T) { + secret := "a-very-long-secret-key-for-testing-purposes-5678" + jwtMod := NewJWTAuthModule("auth", secret, time.Hour, "test") + app := NewMockApplication() + if err := jwtMod.Init(app); err != nil { + t.Fatalf("JWTAuthModule.Init: %v", err) + } + + bl := NewTokenBlacklistModule("bl", "memory", "", time.Minute) + jwtMod.SetTokenBlacklist(bl) + + user := &User{ID: "2", Email: "other@example.com", Name: "Other"} + tokenStr, err := jwtMod.generateToken(user) + if err != nil { + t.Fatalf("generateToken: %v", err) + } + + // Revoke a *different* JTI. + bl.Add("some-other-jti", time.Now().Add(time.Hour)) + + valid, _, err := jwtMod.Authenticate(tokenStr) + if err != nil { + t.Fatalf("Authenticate: %v", err) + } + if !valid { + t.Fatal("expected un-revoked token to remain valid") + } +} + +// Compile-time check: TokenBlacklistModule satisfies TokenBlacklist. +var _ TokenBlacklist = (*TokenBlacklistModule)(nil) + +// Compile-time check: mockBlacklist satisfies TokenBlacklist. +var _ TokenBlacklist = (*mockBlacklist)(nil) + +// fakeTokenRevokeBlacklistError implements TokenBlacklist but always errors on GetService. +type alwaysErrorApp struct { + *MockApplication + serviceErr error +} + +func (a *alwaysErrorApp) GetService(name string, out any) error { + return fmt.Errorf("forced error: %w", a.serviceErr) +} diff --git a/module/tls_config_test.go b/module/tls_config_test.go new file mode 100644 index 00000000..899fc480 --- /dev/null +++ b/module/tls_config_test.go @@ -0,0 +1,152 @@ +package module + +import ( + "testing" + + "github.com/GoCodeAlone/workflow/pkg/tlsutil" +) + +func TestRedisCacheConfig_TLSField(t *testing.T) { + cfg := RedisCacheConfig{ + Address: "localhost:6380", + TLS: tlsutil.TLSConfig{ + Enabled: true, + SkipVerify: true, + }, + } + if !cfg.TLS.Enabled { + t.Error("expected TLS.Enabled to be true") + } +} + +func TestKafkaBroker_SetTLSConfig(t *testing.T) { + broker := NewKafkaBroker("test-kafka") + broker.SetTLSConfig(KafkaTLSConfig{ + TLSConfig: tlsutil.TLSConfig{ + Enabled: true, + SkipVerify: true, + }, + SASL: KafkaSASLConfig{ + Mechanism: "PLAIN", + Username: "user", + Password: "pass", + }, + }) + + broker.mu.RLock() + defer broker.mu.RUnlock() + if !broker.tlsCfg.Enabled { + t.Error("expected TLS enabled") + } + if broker.tlsCfg.SASL.Mechanism != "PLAIN" { + t.Errorf("expected PLAIN, got %s", broker.tlsCfg.SASL.Mechanism) + } +} + +func TestNATSBroker_SetTLSConfig(t *testing.T) { + broker := NewNATSBroker("test-nats") + broker.SetTLSConfig(tlsutil.TLSConfig{ + Enabled: true, + SkipVerify: true, + }) + + broker.mu.RLock() + defer broker.mu.RUnlock() + if !broker.tlsCfg.Enabled { + t.Error("expected TLS enabled") + } +} + +func TestHTTPServer_SetTLSConfig(t *testing.T) { + srv := NewStandardHTTPServer("test", ":8443") + srv.SetTLSConfig(HTTPServerTLSConfig{ + Mode: "manual", + Manual: tlsutil.TLSConfig{ + CertFile: "/tmp/cert.pem", + KeyFile: "/tmp/key.pem", + }, + }) + if srv.tlsCfg.Mode != "manual" { + t.Errorf("expected mode 'manual', got %q", srv.tlsCfg.Mode) + } +} + +func TestDatabaseConfig_TLSField(t *testing.T) { + cfg := DatabaseConfig{ + Driver: "postgres", + DSN: "postgres://localhost:5432/mydb", + TLS: DatabaseTLSConfig{ + Mode: "verify-full", + CAFile: "/etc/ssl/ca.pem", + }, + } + + db := NewWorkflowDatabase("test-db", cfg) + dsn := db.buildDSN() + + if dsn == cfg.DSN { + t.Error("expected DSN to be modified with TLS parameters") + } + if !contains(dsn, "sslmode=verify-full") { + t.Errorf("expected sslmode in DSN, got %s", dsn) + } + if !contains(dsn, "sslrootcert=/etc/ssl/ca.pem") { + t.Errorf("expected sslrootcert in DSN, got %s", dsn) + } +} + +func TestDatabaseConfig_TLSDisabled(t *testing.T) { + cfg := DatabaseConfig{ + Driver: "postgres", + DSN: "postgres://localhost:5432/mydb", + TLS: DatabaseTLSConfig{Mode: "disable"}, + } + + db := NewWorkflowDatabase("test-db", cfg) + dsn := db.buildDSN() + if dsn != cfg.DSN { + t.Errorf("expected unchanged DSN when TLS disabled, got %s", dsn) + } +} + +func TestDatabaseConfig_TLSDefault(t *testing.T) { + cfg := DatabaseConfig{ + Driver: "postgres", + DSN: "postgres://localhost:5432/mydb", + // TLS field zero value: Mode="" + } + + db := NewWorkflowDatabase("test-db", cfg) + dsn := db.buildDSN() + if dsn != cfg.DSN { + t.Errorf("expected unchanged DSN when TLS not configured, got %s", dsn) + } +} + +func TestDatabaseConfig_TLS_ExistingQueryString(t *testing.T) { + cfg := DatabaseConfig{ + Driver: "postgres", + DSN: "postgres://localhost:5432/mydb?connect_timeout=10", + TLS: DatabaseTLSConfig{Mode: "require"}, + } + + db := NewWorkflowDatabase("test-db", cfg) + dsn := db.buildDSN() + if !contains(dsn, "&sslmode=require") { + t.Errorf("expected & separator when query string exists, got %s", dsn) + } +} + +// contains is a simple substring check helper. +func contains(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(s) > 0 && searchSubstring(s, sub)) +} + +func searchSubstring(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/pkg/fieldcrypt/encrypt.go b/pkg/fieldcrypt/encrypt.go new file mode 100644 index 00000000..a7bb6e9d --- /dev/null +++ b/pkg/fieldcrypt/encrypt.go @@ -0,0 +1,118 @@ +package fieldcrypt + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + "io" + "strconv" + "strings" +) + +// Prefix is the marker for encrypted protected field values. +const Prefix = "epf:" + +// legacyPrefix is the old encryption prefix from the original FieldEncryptor. +const legacyPrefix = "enc::" + +// Encrypt encrypts plaintext with AES-256-GCM, returning "epf:v{version}:{base64(nonce+ciphertext)}". +func Encrypt(plaintext string, key []byte, version int) (string, error) { + block, err := aes.NewCipher(key) + if err != nil { + return "", fmt.Errorf("fieldcrypt: create cipher: %w", err) + } + aead, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("fieldcrypt: create GCM: %w", err) + } + nonce := make([]byte, aead.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", fmt.Errorf("fieldcrypt: generate nonce: %w", err) + } + ciphertext := aead.Seal(nonce, nonce, []byte(plaintext), nil) + encoded := base64.StdEncoding.EncodeToString(ciphertext) + return fmt.Sprintf("%sv%d:%s", Prefix, version, encoded), nil +} + +// Decrypt decrypts an epf:-prefixed value. It also handles legacy "enc::" prefix values. +// keyFn is called with the version number extracted from the prefix. +// For legacy enc:: values, keyFn(0) is called to obtain the raw master key, +// which is then SHA256-hashed to match the original FieldEncryptor behavior. +func Decrypt(ciphertext string, keyFn func(version int) ([]byte, error)) (string, error) { + if strings.HasPrefix(ciphertext, legacyPrefix) { + return decryptLegacy(ciphertext, keyFn) + } + if !strings.HasPrefix(ciphertext, Prefix) { + return "", fmt.Errorf("fieldcrypt: not an encrypted value") + } + + // Parse "epf:v{version}:{base64}" + rest := strings.TrimPrefix(ciphertext, Prefix) + idx := strings.Index(rest, ":") + if idx < 0 || !strings.HasPrefix(rest, "v") { + return "", fmt.Errorf("fieldcrypt: invalid format") + } + version, err := strconv.Atoi(rest[1:idx]) + if err != nil { + return "", fmt.Errorf("fieldcrypt: invalid version: %w", err) + } + encoded := rest[idx+1:] + + key, err := keyFn(version) + if err != nil { + return "", fmt.Errorf("fieldcrypt: key lookup: %w", err) + } + + raw, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return "", fmt.Errorf("fieldcrypt: base64 decode: %w", err) + } + + return decryptAESGCM(raw, key) +} + +// IsEncrypted checks if a value has the epf: prefix or legacy enc:: prefix. +func IsEncrypted(value string) bool { + return strings.HasPrefix(value, Prefix) || strings.HasPrefix(value, legacyPrefix) +} + +// decryptLegacy handles the old "enc::" prefix format. +func decryptLegacy(ciphertext string, keyFn func(version int) ([]byte, error)) (string, error) { + raw := strings.TrimPrefix(ciphertext, legacyPrefix) + decoded, err := base64.StdEncoding.DecodeString(raw) + if err != nil { + return "", fmt.Errorf("fieldcrypt: legacy base64 decode: %w", err) + } + // keyFn(0) returns the raw master key; hash it with SHA256 to match original behavior. + masterKey, err := keyFn(0) + if err != nil { + return "", fmt.Errorf("fieldcrypt: legacy key lookup: %w", err) + } + hash := sha256.Sum256(masterKey) + return decryptAESGCM(decoded, hash[:]) +} + +// decryptAESGCM decrypts raw bytes (nonce + ciphertext) with the given AES-256-GCM key. +func decryptAESGCM(raw, key []byte) (string, error) { + block, err := aes.NewCipher(key) + if err != nil { + return "", fmt.Errorf("fieldcrypt: create cipher: %w", err) + } + aead, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("fieldcrypt: create GCM: %w", err) + } + nonceSize := aead.NonceSize() + if len(raw) < nonceSize { + return "", fmt.Errorf("fieldcrypt: ciphertext too short") + } + nonce, ct := raw[:nonceSize], raw[nonceSize:] + plaintext, err := aead.Open(nil, nonce, ct, nil) + if err != nil { + return "", fmt.Errorf("fieldcrypt: decryption failed: %w", err) + } + return string(plaintext), nil +} diff --git a/pkg/fieldcrypt/fieldcrypt.go b/pkg/fieldcrypt/fieldcrypt.go new file mode 100644 index 00000000..9265ef94 --- /dev/null +++ b/pkg/fieldcrypt/fieldcrypt.go @@ -0,0 +1,66 @@ +package fieldcrypt + +// FieldClassification defines the sensitivity level. +type FieldClassification string + +const ( + ClassPII FieldClassification = "pii" + ClassPHI FieldClassification = "phi" +) + +// LogBehavior defines how a field appears in logs. +type LogBehavior string + +const ( + LogMask LogBehavior = "mask" + LogRedact LogBehavior = "redact" + LogHash LogBehavior = "hash" + LogAllow LogBehavior = "allow" +) + +// ProtectedField defines a field that requires encryption/masking. +type ProtectedField struct { + Name string `yaml:"name"` + Classification FieldClassification `yaml:"classification"` + Encryption bool `yaml:"encryption"` + LogBehavior LogBehavior `yaml:"log_behavior"` + MaskPattern string `yaml:"mask_pattern"` +} + +// Registry holds the set of protected fields for lookup. +type Registry struct { + fields map[string]ProtectedField +} + +// NewRegistry creates a Registry from a slice of ProtectedField definitions. +func NewRegistry(fields []ProtectedField) *Registry { + m := make(map[string]ProtectedField, len(fields)) + for _, f := range fields { + m[f.Name] = f + } + return &Registry{fields: m} +} + +// IsProtected returns true if the given field name is in the registry. +func (r *Registry) IsProtected(fieldName string) bool { + _, ok := r.fields[fieldName] + return ok +} + +// GetField returns the ProtectedField definition for the given name. +func (r *Registry) GetField(fieldName string) (*ProtectedField, bool) { + f, ok := r.fields[fieldName] + if !ok { + return nil, false + } + return &f, true +} + +// ProtectedFields returns all registered protected fields. +func (r *Registry) ProtectedFields() []ProtectedField { + out := make([]ProtectedField, 0, len(r.fields)) + for _, f := range r.fields { + out = append(out, f) + } + return out +} diff --git a/pkg/fieldcrypt/fieldcrypt_test.go b/pkg/fieldcrypt/fieldcrypt_test.go new file mode 100644 index 00000000..d40152ba --- /dev/null +++ b/pkg/fieldcrypt/fieldcrypt_test.go @@ -0,0 +1,377 @@ +package fieldcrypt + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "io" + "strings" + "testing" +) + +func TestEncryptDecryptRoundTrip(t *testing.T) { + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + plaintext := "sensitive data here" + encrypted, err := Encrypt(plaintext, key, 1) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + + if !strings.HasPrefix(encrypted, "epf:v1:") { + t.Fatalf("expected epf:v1: prefix, got %q", encrypted) + } + + decrypted, err := Decrypt(encrypted, func(version int) ([]byte, error) { + if version != 1 { + t.Fatalf("expected version 1, got %d", version) + } + return key, nil + }) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if decrypted != plaintext { + t.Fatalf("expected %q, got %q", plaintext, decrypted) + } +} + +func TestEncryptDecryptVersion(t *testing.T) { + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + encrypted, err := Encrypt("hello", key, 42) + if err != nil { + t.Fatal(err) + } + if !strings.HasPrefix(encrypted, "epf:v42:") { + t.Fatalf("expected epf:v42: prefix, got %q", encrypted) + } + + decrypted, err := Decrypt(encrypted, func(version int) ([]byte, error) { + if version != 42 { + t.Fatalf("expected version 42, got %d", version) + } + return key, nil + }) + if err != nil { + t.Fatal(err) + } + if decrypted != "hello" { + t.Fatalf("expected hello, got %q", decrypted) + } +} + +func TestIsEncrypted(t *testing.T) { + tests := []struct { + value string + want bool + }{ + {"epf:v1:abc", true}, + {"enc::abc", true}, + {"plaintext", false}, + {"", false}, + } + for _, tt := range tests { + if got := IsEncrypted(tt.value); got != tt.want { + t.Errorf("IsEncrypted(%q) = %v, want %v", tt.value, got, tt.want) + } + } +} + +func TestLegacyEncDecrypt(t *testing.T) { + // Simulate the old FieldEncryptor: SHA256(masterKey) -> AES-256-GCM. + masterKey := []byte("my-secret-key") + hash := sha256.Sum256(masterKey) + + block, err := aes.NewCipher(hash[:]) + if err != nil { + t.Fatal(err) + } + aead, err := cipher.NewGCM(block) + if err != nil { + t.Fatal(err) + } + nonce := make([]byte, aead.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + t.Fatal(err) + } + ct := aead.Seal(nonce, nonce, []byte("legacy secret"), nil) + legacyEncoded := "enc::" + base64.StdEncoding.EncodeToString(ct) + + decrypted, err := Decrypt(legacyEncoded, func(version int) ([]byte, error) { + if version != 0 { + t.Fatalf("expected version 0 for legacy, got %d", version) + } + return masterKey, nil + }) + if err != nil { + t.Fatalf("Decrypt legacy: %v", err) + } + if decrypted != "legacy secret" { + t.Fatalf("expected 'legacy secret', got %q", decrypted) + } +} + +func TestMaskEmail(t *testing.T) { + got := MaskEmail("john@example.com") + if got != "j***@e***.com" { + t.Errorf("MaskEmail = %q, want %q", got, "j***@e***.com") + } +} + +func TestMaskPhone(t *testing.T) { + got := MaskPhone("555-123-4567") + if got != "***-***-4567" { + t.Errorf("MaskPhone = %q, want %q", got, "***-***-4567") + } +} + +func TestHashValue(t *testing.T) { + h := HashValue("test") + if len(h) != 64 { + t.Errorf("expected 64 char hex, got %d chars", len(h)) + } +} + +func TestRedactValue(t *testing.T) { + if got := RedactValue(); got != "[REDACTED]" { + t.Errorf("RedactValue = %q", got) + } +} + +func TestMaskValueBehaviors(t *testing.T) { + if MaskValue("secret", LogRedact, "") != "[REDACTED]" { + t.Error("LogRedact failed") + } + if MaskValue("secret", LogAllow, "") != "secret" { + t.Error("LogAllow failed") + } + h := MaskValue("secret", LogHash, "") + if len(h) != 64 { + t.Error("LogHash failed") + } +} + +func TestScanAndEncryptDecrypt(t *testing.T) { + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + registry := NewRegistry([]ProtectedField{ + {Name: "ssn", Classification: ClassPII, Encryption: true, LogBehavior: LogRedact}, + {Name: "email", Classification: ClassPII, Encryption: true, LogBehavior: LogMask}, + }) + + data := map[string]any{ + "ssn": "123-45-6789", + "email": "test@example.com", + "name": "John", + "nested": map[string]any{ + "ssn": "987-65-4321", + }, + "items": []any{ + map[string]any{"email": "a@b.com"}, + }, + } + + err := ScanAndEncrypt(data, registry, func() ([]byte, int, error) { + return key, 1, nil + }, 10) + if err != nil { + t.Fatalf("ScanAndEncrypt: %v", err) + } + + // Verify fields are encrypted. + if !IsEncrypted(data["ssn"].(string)) { + t.Error("ssn should be encrypted") + } + if !IsEncrypted(data["email"].(string)) { + t.Error("email should be encrypted") + } + if data["name"] != "John" { + t.Error("name should not be modified") + } + nested := data["nested"].(map[string]any) + if !IsEncrypted(nested["ssn"].(string)) { + t.Error("nested ssn should be encrypted") + } + items := data["items"].([]any) + item := items[0].(map[string]any) + if !IsEncrypted(item["email"].(string)) { + t.Error("array item email should be encrypted") + } + + // Now decrypt. + err = ScanAndDecrypt(data, registry, func(version int) ([]byte, error) { + return key, nil + }, 10) + if err != nil { + t.Fatalf("ScanAndDecrypt: %v", err) + } + + if data["ssn"] != "123-45-6789" { + t.Errorf("ssn = %q, want 123-45-6789", data["ssn"]) + } + if data["email"] != "test@example.com" { + t.Errorf("email = %q, want test@example.com", data["email"]) + } + if nested["ssn"] != "987-65-4321" { + t.Errorf("nested ssn = %q", nested["ssn"]) + } + item = items[0].(map[string]any) + if item["email"] != "a@b.com" { + t.Errorf("array item email = %q", item["email"]) + } +} + +func TestScanAndMask(t *testing.T) { + registry := NewRegistry([]ProtectedField{ + {Name: "ssn", Classification: ClassPII, LogBehavior: LogRedact}, + {Name: "email", Classification: ClassPII, LogBehavior: LogMask}, + }) + + data := map[string]any{ + "ssn": "123-45-6789", + "email": "test@example.com", + "name": "John", + } + + masked := ScanAndMask(data, registry, 10) + + if masked["ssn"] != "[REDACTED]" { + t.Errorf("ssn mask = %q", masked["ssn"]) + } + if masked["email"] == "test@example.com" { + t.Error("email should be masked") + } + if masked["name"] != "John" { + t.Error("name should be unchanged") + } + + // Original should be unmodified. + if data["ssn"] != "123-45-6789" { + t.Error("original ssn was modified") + } +} + +func TestKeyRingTenantIsolation(t *testing.T) { + masterKey := make([]byte, 32) + if _, err := rand.Read(masterKey); err != nil { + t.Fatal(err) + } + + kr := NewLocalKeyRing(masterKey) + ctx := context.Background() + + keyA, verA, err := kr.CurrentKey(ctx, "tenant-a") + if err != nil { + t.Fatal(err) + } + if verA != 1 { + t.Fatalf("expected version 1, got %d", verA) + } + + keyB, verB, err := kr.CurrentKey(ctx, "tenant-b") + if err != nil { + t.Fatal(err) + } + if verB != 1 { + t.Fatalf("expected version 1, got %d", verB) + } + + // Keys should be different for different tenants. + if string(keyA) == string(keyB) { + t.Error("tenant keys should differ") + } +} + +func TestKeyRingRotation(t *testing.T) { + masterKey := make([]byte, 32) + if _, err := rand.Read(masterKey); err != nil { + t.Fatal(err) + } + + kr := NewLocalKeyRing(masterKey) + ctx := context.Background() + + key1, ver1, err := kr.CurrentKey(ctx, "t1") + if err != nil { + t.Fatal(err) + } + + key2, ver2, err := kr.Rotate(ctx, "t1") + if err != nil { + t.Fatal(err) + } + if ver2 != ver1+1 { + t.Fatalf("expected version %d, got %d", ver1+1, ver2) + } + if string(key1) == string(key2) { + t.Error("rotated key should differ") + } + + // Old key should still be retrievable. + oldKey, err := kr.KeyByVersion(ctx, "t1", ver1) + if err != nil { + t.Fatal(err) + } + if string(oldKey) != string(key1) { + t.Error("old key mismatch") + } + + // Current should return new key. + curKey, curVer, err := kr.CurrentKey(ctx, "t1") + if err != nil { + t.Fatal(err) + } + if curVer != ver2 { + t.Fatalf("current version = %d, want %d", curVer, ver2) + } + if string(curKey) != string(key2) { + t.Error("current key mismatch") + } +} + +func TestKeyRingVersionLookup(t *testing.T) { + masterKey := []byte("deterministic-master-key-for-test") + kr := NewLocalKeyRing(masterKey) + ctx := context.Background() + + // Create v1 by calling CurrentKey. + _, _, err := kr.CurrentKey(ctx, "t") + if err != nil { + t.Fatal(err) + } + + // Rotate to v2. + _, _, err = kr.Rotate(ctx, "t") + if err != nil { + t.Fatal(err) + } + + // Lookup v1 by version. + k1, err := kr.KeyByVersion(ctx, "t", 1) + if err != nil { + t.Fatal(err) + } + + // Lookup v2 by version. + k2, err := kr.KeyByVersion(ctx, "t", 2) + if err != nil { + t.Fatal(err) + } + + if string(k1) == string(k2) { + t.Error("v1 and v2 keys should differ") + } +} diff --git a/pkg/fieldcrypt/keyring.go b/pkg/fieldcrypt/keyring.go new file mode 100644 index 00000000..6a69f193 --- /dev/null +++ b/pkg/fieldcrypt/keyring.go @@ -0,0 +1,122 @@ +package fieldcrypt + +import ( + "context" + "crypto/sha256" + "fmt" + "io" + "sync" + + "golang.org/x/crypto/hkdf" +) + +// KeyRing manages versioned, tenant-isolated encryption keys. +type KeyRing interface { + CurrentKey(ctx context.Context, tenantID string) (key []byte, version int, err error) + KeyByVersion(ctx context.Context, tenantID string, version int) ([]byte, error) + Rotate(ctx context.Context, tenantID string) (key []byte, version int, err error) +} + +// LocalKeyRing stores keys in memory, keyed by tenant. +// Keys are derived from a master key using HKDF. +type LocalKeyRing struct { + masterKey []byte + mu sync.RWMutex + tenantVersions map[string]int // tenantID -> current version number + tenantKeys map[string][]byte // "tenantID:version" -> derived key +} + +// NewLocalKeyRing creates a new LocalKeyRing from a master key. +func NewLocalKeyRing(masterKey []byte) *LocalKeyRing { + return &LocalKeyRing{ + masterKey: masterKey, + tenantVersions: make(map[string]int), + tenantKeys: make(map[string][]byte), + } +} + +// CurrentKey returns the current key version for a tenant. +// If no key exists yet, generates version 1. +func (k *LocalKeyRing) CurrentKey(_ context.Context, tenantID string) ([]byte, int, error) { + k.mu.RLock() + ver, ok := k.tenantVersions[tenantID] + if ok { + cacheKey := fmt.Sprintf("%s:%d", tenantID, ver) + key := k.tenantKeys[cacheKey] + k.mu.RUnlock() + return key, ver, nil + } + k.mu.RUnlock() + + // No key yet; create version 1. + k.mu.Lock() + defer k.mu.Unlock() + + // Double-check after acquiring write lock. + if ver, ok := k.tenantVersions[tenantID]; ok { + cacheKey := fmt.Sprintf("%s:%d", tenantID, ver) + return k.tenantKeys[cacheKey], ver, nil + } + + key, err := k.deriveKey(tenantID, 1) + if err != nil { + return nil, 0, err + } + k.tenantVersions[tenantID] = 1 + k.tenantKeys[fmt.Sprintf("%s:%d", tenantID, 1)] = key + return key, 1, nil +} + +// KeyByVersion returns the key for a specific tenant+version. +func (k *LocalKeyRing) KeyByVersion(_ context.Context, tenantID string, version int) ([]byte, error) { + cacheKey := fmt.Sprintf("%s:%d", tenantID, version) + + k.mu.RLock() + if key, ok := k.tenantKeys[cacheKey]; ok { + k.mu.RUnlock() + return key, nil + } + k.mu.RUnlock() + + k.mu.Lock() + defer k.mu.Unlock() + + // Double-check. + if key, ok := k.tenantKeys[cacheKey]; ok { + return key, nil + } + + key, err := k.deriveKey(tenantID, version) + if err != nil { + return nil, err + } + k.tenantKeys[cacheKey] = key + return key, nil +} + +// Rotate increments the version and derives a new key for the tenant. +func (k *LocalKeyRing) Rotate(_ context.Context, tenantID string) ([]byte, int, error) { + k.mu.Lock() + defer k.mu.Unlock() + + ver := k.tenantVersions[tenantID] + 1 + key, err := k.deriveKey(tenantID, ver) + if err != nil { + return nil, 0, err + } + k.tenantVersions[tenantID] = ver + k.tenantKeys[fmt.Sprintf("%s:%d", tenantID, ver)] = key + return key, ver, nil +} + +// deriveKey uses HKDF-SHA256 to derive a 32-byte key. +// Info = "fieldcrypt:" + tenantID + ":v" + version. +func (k *LocalKeyRing) deriveKey(tenantID string, version int) ([]byte, error) { + info := fmt.Sprintf("fieldcrypt:%s:v%d", tenantID, version) + hkdfReader := hkdf.New(sha256.New, k.masterKey, nil, []byte(info)) + derived := make([]byte, 32) + if _, err := io.ReadFull(hkdfReader, derived); err != nil { + return nil, fmt.Errorf("fieldcrypt: key derivation failed: %w", err) + } + return derived, nil +} diff --git a/pkg/fieldcrypt/mask.go b/pkg/fieldcrypt/mask.go new file mode 100644 index 00000000..ca4aa5ca --- /dev/null +++ b/pkg/fieldcrypt/mask.go @@ -0,0 +1,122 @@ +package fieldcrypt + +import ( + "crypto/sha256" + "fmt" + "strings" +) + +// MaskValue applies masking based on LogBehavior and optional pattern. +func MaskValue(value string, behavior LogBehavior, pattern string) string { + switch behavior { + case LogRedact: + return RedactValue() + case LogHash: + return HashValue(value) + case LogAllow: + return value + case LogMask: + if pattern != "" { + return applyPattern(value, pattern) + } + // Auto-detect: try email first, then phone-like, then generic mask. + if strings.Contains(value, "@") { + return MaskEmail(value) + } + if looksLikePhone(value) { + return MaskPhone(value) + } + return genericMask(value) + default: + return RedactValue() + } +} + +// MaskEmail masks an email: "j***@e***.com". +func MaskEmail(email string) string { + at := strings.LastIndex(email, "@") + if at < 0 { + return genericMask(email) + } + local := email[:at] + domain := email[at+1:] + + maskedLocal := maskPart(local) + // Mask domain but keep TLD. + dot := strings.LastIndex(domain, ".") + if dot > 0 { + maskedDomain := maskPart(domain[:dot]) + domain[dot:] + return maskedLocal + "@" + maskedDomain + } + return maskedLocal + "@" + maskPart(domain) +} + +// MaskPhone masks all but last 4 digits: "***-***-1234". +func MaskPhone(phone string) string { + digits := extractDigits(phone) + if len(digits) <= 4 { + return "****" + } + last4 := string(digits[len(digits)-4:]) + return "***-***-" + last4 +} + +// HashValue returns SHA256 hex of the value. +func HashValue(value string) string { + h := sha256.Sum256([]byte(value)) + return fmt.Sprintf("%x", h) +} + +// RedactValue returns "[REDACTED]". +func RedactValue() string { + return "[REDACTED]" +} + +// maskPart keeps first char and replaces rest with ***. +func maskPart(s string) string { + if len(s) == 0 { + return "" + } + return string(s[0]) + "***" +} + +func genericMask(s string) string { + if len(s) <= 1 { + return "***" + } + return string(s[0]) + "***" +} + +func looksLikePhone(s string) bool { + digits := extractDigits(s) + return len(digits) >= 7 && len(digits) <= 15 +} + +func extractDigits(s string) []byte { + var digits []byte + for i := 0; i < len(s); i++ { + if s[i] >= '0' && s[i] <= '9' { + digits = append(digits, s[i]) + } + } + return digits +} + +func applyPattern(value, pattern string) string { + digits := extractDigits(value) + di := 0 + var result []byte + for i := 0; i < len(pattern); i++ { + if pattern[i] == '#' { + if di < len(digits) { + result = append(result, digits[di]) + di++ + } else { + result = append(result, '#') + } + } else { + result = append(result, pattern[i]) + } + } + return string(result) +} diff --git a/pkg/fieldcrypt/scanner.go b/pkg/fieldcrypt/scanner.go new file mode 100644 index 00000000..7c8413e4 --- /dev/null +++ b/pkg/fieldcrypt/scanner.go @@ -0,0 +1,117 @@ +package fieldcrypt + +// ScanAndEncrypt recursively scans a map, encrypting protected fields that have Encryption=true. +// maxDepth limits recursion depth. +func ScanAndEncrypt(data map[string]any, registry *Registry, keyFn func() ([]byte, int, error), maxDepth int) error { + return scanEncrypt(data, registry, keyFn, 0, maxDepth) +} + +func scanEncrypt(data map[string]any, registry *Registry, keyFn func() ([]byte, int, error), depth, maxDepth int) error { + if depth >= maxDepth { + return nil + } + for k, v := range data { + switch val := v.(type) { + case string: + if pf, ok := registry.GetField(k); ok && pf.Encryption && !IsEncrypted(val) { + key, version, err := keyFn() + if err != nil { + return err + } + encrypted, err := Encrypt(val, key, version) + if err != nil { + return err + } + data[k] = encrypted + } + case map[string]any: + if err := scanEncrypt(val, registry, keyFn, depth+1, maxDepth); err != nil { + return err + } + case []any: + for _, elem := range val { + if m, ok := elem.(map[string]any); ok { + if err := scanEncrypt(m, registry, keyFn, depth+1, maxDepth); err != nil { + return err + } + } + } + } + } + return nil +} + +// ScanAndDecrypt recursively scans a map, decrypting epf:-prefixed (and enc::-prefixed) protected fields. +func ScanAndDecrypt(data map[string]any, registry *Registry, keyFn func(version int) ([]byte, error), maxDepth int) error { + return scanDecrypt(data, registry, keyFn, 0, maxDepth) +} + +func scanDecrypt(data map[string]any, registry *Registry, keyFn func(version int) ([]byte, error), depth, maxDepth int) error { + if depth >= maxDepth { + return nil + } + for k, v := range data { + switch val := v.(type) { + case string: + if registry.IsProtected(k) && IsEncrypted(val) { + decrypted, err := Decrypt(val, keyFn) + if err != nil { + return err + } + data[k] = decrypted + } + case map[string]any: + if err := scanDecrypt(val, registry, keyFn, depth+1, maxDepth); err != nil { + return err + } + case []any: + for _, elem := range val { + if m, ok := elem.(map[string]any); ok { + if err := scanDecrypt(m, registry, keyFn, depth+1, maxDepth); err != nil { + return err + } + } + } + } + } + return nil +} + +// ScanAndMask returns a deep copy of data with protected fields masked (for logging). +// Does NOT modify the original map. +func ScanAndMask(data map[string]any, registry *Registry, maxDepth int) map[string]any { + return scanMask(data, registry, 0, maxDepth) +} + +func scanMask(data map[string]any, registry *Registry, depth, maxDepth int) map[string]any { + result := make(map[string]any, len(data)) + for k, v := range data { + switch val := v.(type) { + case string: + if pf, ok := registry.GetField(k); ok { + result[k] = MaskValue(val, pf.LogBehavior, pf.MaskPattern) + } else { + result[k] = val + } + case map[string]any: + if depth < maxDepth { + result[k] = scanMask(val, registry, depth+1, maxDepth) + } else { + result[k] = val + } + case []any: + masked := make([]any, len(val)) + for i, elem := range val { + if m, ok := elem.(map[string]any); ok && depth < maxDepth { + masked[i] = scanMask(m, registry, depth+1, maxDepth) + } else { + masked[i] = elem + } + } + result[k] = masked + default: + result[k] = v + } + } + return result +} diff --git a/pkg/tlsutil/tlsutil.go b/pkg/tlsutil/tlsutil.go new file mode 100644 index 00000000..9c49475b --- /dev/null +++ b/pkg/tlsutil/tlsutil.go @@ -0,0 +1,74 @@ +package tlsutil + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" +) + +// TLSConfig is the common TLS configuration used across all transports. +type TLSConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + CertFile string `yaml:"cert_file" json:"cert_file"` + KeyFile string `yaml:"key_file" json:"key_file"` + CAFile string `yaml:"ca_file" json:"ca_file"` + ClientAuth string `yaml:"client_auth" json:"client_auth"` // require | request | none + SkipVerify bool `yaml:"skip_verify" json:"skip_verify"` // for dev only +} + +// AutocertConfig holds Let's Encrypt autocert configuration. +type AutocertConfig struct { + Domains []string `yaml:"domains" json:"domains"` + CacheDir string `yaml:"cache_dir" json:"cache_dir"` + Email string `yaml:"email" json:"email"` +} + +// LoadTLSConfig builds a *tls.Config from the YAML-friendly struct. +func LoadTLSConfig(cfg TLSConfig) (*tls.Config, error) { + if !cfg.Enabled { + return nil, nil + } + + tlsCfg := &tls.Config{ + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: cfg.SkipVerify, //nolint:gosec // G402: intentional dev-only option + } + + // Load server/client certificate keypair if provided + if cfg.CertFile != "" && cfg.KeyFile != "" { + cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile) + if err != nil { + return nil, fmt.Errorf("tlsutil: load key pair: %w", err) + } + tlsCfg.Certificates = []tls.Certificate{cert} + } + + // Load CA certificate for client verification or custom root CA + if cfg.CAFile != "" { + caPEM, err := os.ReadFile(cfg.CAFile) + if err != nil { + return nil, fmt.Errorf("tlsutil: read CA file: %w", err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caPEM) { + return nil, fmt.Errorf("tlsutil: no valid certificates found in %s", cfg.CAFile) + } + tlsCfg.RootCAs = pool + tlsCfg.ClientCAs = pool + } + + // Configure client authentication policy + switch cfg.ClientAuth { + case "require": + tlsCfg.ClientAuth = tls.RequireAndVerifyClientCert + case "request": + tlsCfg.ClientAuth = tls.RequestClientCert + case "none", "": + tlsCfg.ClientAuth = tls.NoClientCert + default: + return nil, fmt.Errorf("tlsutil: unknown client_auth %q (valid: require, request, none)", cfg.ClientAuth) + } + + return tlsCfg, nil +} diff --git a/pkg/tlsutil/tlsutil_test.go b/pkg/tlsutil/tlsutil_test.go new file mode 100644 index 00000000..b29aedda --- /dev/null +++ b/pkg/tlsutil/tlsutil_test.go @@ -0,0 +1,237 @@ +package tlsutil_test + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "testing" + "time" + + "github.com/GoCodeAlone/workflow/pkg/tlsutil" +) + +// generateSelfSignedCert writes a self-signed cert+key pair to tmpdir and +// returns (certFile, keyFile). +func generateSelfSignedCert(t *testing.T, dir string) (string, string) { + t.Helper() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("generate key: %v", err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "test"}, + NotBefore: time.Now().Add(-time.Minute), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, + IsCA: true, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("create cert: %v", err) + } + + certFile := filepath.Join(dir, "cert.pem") + keyFile := filepath.Join(dir, "key.pem") + + cf, err := os.Create(certFile) + if err != nil { + t.Fatal(err) + } + defer cf.Close() + if err := pem.Encode(cf, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil { + t.Fatal(err) + } + + kf, err := os.Create(keyFile) + if err != nil { + t.Fatal(err) + } + defer kf.Close() + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + t.Fatal(err) + } + if err := pem.Encode(kf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}); err != nil { + t.Fatal(err) + } + + return certFile, keyFile +} + +func TestLoadTLSConfig_Disabled(t *testing.T) { + cfg := tlsutil.TLSConfig{Enabled: false} + result, err := tlsutil.LoadTLSConfig(cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != nil { + t.Fatal("expected nil tls.Config when disabled") + } +} + +func TestLoadTLSConfig_ValidCert(t *testing.T) { + dir := t.TempDir() + certFile, keyFile := generateSelfSignedCert(t, dir) + + cfg := tlsutil.TLSConfig{ + Enabled: true, + CertFile: certFile, + KeyFile: keyFile, + } + + result, err := tlsutil.LoadTLSConfig(cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result == nil { + t.Fatal("expected non-nil tls.Config") + } + if len(result.Certificates) != 1 { + t.Fatalf("expected 1 certificate, got %d", len(result.Certificates)) + } +} + +func TestLoadTLSConfig_WithCA(t *testing.T) { + dir := t.TempDir() + certFile, keyFile := generateSelfSignedCert(t, dir) + + cfg := tlsutil.TLSConfig{ + Enabled: true, + CertFile: certFile, + KeyFile: keyFile, + CAFile: certFile, // reuse the self-signed cert as CA + } + + result, err := tlsutil.LoadTLSConfig(cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result == nil { + t.Fatal("expected non-nil tls.Config") + } + if result.RootCAs == nil { + t.Fatal("expected non-nil RootCAs") + } +} + +func TestLoadTLSConfig_ClientAuthRequire(t *testing.T) { + dir := t.TempDir() + certFile, keyFile := generateSelfSignedCert(t, dir) + + cfg := tlsutil.TLSConfig{ + Enabled: true, + CertFile: certFile, + KeyFile: keyFile, + ClientAuth: "require", + } + + result, err := tlsutil.LoadTLSConfig(cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.ClientAuth != tls.RequireAndVerifyClientCert { + t.Errorf("expected RequireAndVerifyClientCert, got %v", result.ClientAuth) + } +} + +func TestLoadTLSConfig_ClientAuthRequest(t *testing.T) { + dir := t.TempDir() + certFile, keyFile := generateSelfSignedCert(t, dir) + + cfg := tlsutil.TLSConfig{ + Enabled: true, + CertFile: certFile, + KeyFile: keyFile, + ClientAuth: "request", + } + + result, err := tlsutil.LoadTLSConfig(cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.ClientAuth != tls.RequestClientCert { + t.Errorf("expected RequestClientCert, got %v", result.ClientAuth) + } +} + +func TestLoadTLSConfig_SkipVerify(t *testing.T) { + cfg := tlsutil.TLSConfig{ + Enabled: true, + SkipVerify: true, + } + + result, err := tlsutil.LoadTLSConfig(cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.InsecureSkipVerify { //nolint:gosec // G402: test assertion + t.Error("expected InsecureSkipVerify to be true") + } +} + +func TestLoadTLSConfig_InvalidCertFile(t *testing.T) { + cfg := tlsutil.TLSConfig{ + Enabled: true, + CertFile: "/nonexistent/cert.pem", + KeyFile: "/nonexistent/key.pem", + } + + _, err := tlsutil.LoadTLSConfig(cfg) + if err == nil { + t.Fatal("expected error for nonexistent cert files") + } +} + +func TestLoadTLSConfig_InvalidCAFile(t *testing.T) { + cfg := tlsutil.TLSConfig{ + Enabled: true, + CAFile: "/nonexistent/ca.pem", + } + + _, err := tlsutil.LoadTLSConfig(cfg) + if err == nil { + t.Fatal("expected error for nonexistent CA file") + } +} + +func TestLoadTLSConfig_InvalidCAContent(t *testing.T) { + dir := t.TempDir() + badCA := filepath.Join(dir, "bad-ca.pem") + if err := os.WriteFile(badCA, []byte("not a valid pem"), 0600); err != nil { + t.Fatal(err) + } + + cfg := tlsutil.TLSConfig{ + Enabled: true, + CAFile: badCA, + } + + _, err := tlsutil.LoadTLSConfig(cfg) + if err == nil { + t.Fatal("expected error for invalid CA content") + } +} + +func TestLoadTLSConfig_InvalidClientAuth(t *testing.T) { + cfg := tlsutil.TLSConfig{ + Enabled: true, + ClientAuth: "invalid-value", + } + + _, err := tlsutil.LoadTLSConfig(cfg) + if err == nil { + t.Fatal("expected error for invalid client_auth value") + } +} diff --git a/plugins/auth/plugin.go b/plugins/auth/plugin.go index 0dec02b1..3e69b721 100644 --- a/plugins/auth/plugin.go +++ b/plugins/auth/plugin.go @@ -11,6 +11,17 @@ import ( "github.com/GoCodeAlone/workflow/schema" ) +// durationFromMap parses a duration string from a config map, returning the +// default value when the key is absent or unparseable. +func durationFromMap(m map[string]any, key string, defaultVal time.Duration) time.Duration { + if s, ok := m[key].(string); ok && s != "" { + if d, err := time.ParseDuration(s); err == nil { + return d + } + } + return defaultVal +} + // Plugin provides authentication capabilities: auth.jwt, auth.user-store, // auth.oauth2, and auth.m2m modules plus the wiring hook that connects // AuthProviders to AuthMiddleware. @@ -38,12 +49,14 @@ func New() *Plugin { "auth.user-store", "auth.oauth2", "auth.m2m", + "auth.token-blacklist", + "security.field-protection", }, Capabilities: []plugin.CapabilityDecl{ {Name: "authentication", Role: "provider", Priority: 10}, {Name: "user-management", Role: "provider", Priority: 10}, }, - WiringHooks: []string{"auth-provider-wiring", "oauth2-jwt-wiring"}, + WiringHooks: []string{"auth-provider-wiring", "oauth2-jwt-wiring", "token-blacklist-wiring"}, }, }, } @@ -127,6 +140,16 @@ func (p *Plugin) ModuleFactories() map[string]plugin.ModuleFactory { // jwtAuth will be wired during the wiring hook. return module.NewOAuth2Module(name, providerCfgs, nil) }, + "auth.token-blacklist": func(name string, cfg map[string]any) modular.Module { + backend := stringFromMap(cfg, "backend") + redisURL := stringFromMap(cfg, "redis_url") + cleanupInterval := durationFromMap(cfg, "cleanup_interval", 5*time.Minute) + return module.NewTokenBlacklistModule(name, backend, redisURL, cleanupInterval) + }, + "security.field-protection": func(name string, cfg map[string]any) modular.Module { + mod, _ := module.NewFieldProtectionModule(name, cfg) + return mod + }, "auth.m2m": func(name string, cfg map[string]any) modular.Module { secret := stringFromMap(cfg, "secret") tokenExpiry := time.Hour @@ -235,6 +258,28 @@ func (p *Plugin) WiringHooks() []plugin.WiringHook { return nil }, }, + { + Name: "token-blacklist-wiring", + Priority: 70, + Hook: func(app modular.Application, _ *config.WorkflowConfig) error { + var blacklist *module.TokenBlacklistModule + for _, svc := range app.SvcRegistry() { + if bl, ok := svc.(*module.TokenBlacklistModule); ok { + blacklist = bl + break + } + } + if blacklist == nil { + return nil + } + for _, svc := range app.SvcRegistry() { + if j, ok := svc.(*module.JWTAuthModule); ok { + j.SetTokenBlacklist(blacklist) + } + } + return nil + }, + }, } } diff --git a/plugins/auth/plugin_test.go b/plugins/auth/plugin_test.go index 0bd49919..456530b3 100644 --- a/plugins/auth/plugin_test.go +++ b/plugins/auth/plugin_test.go @@ -21,11 +21,11 @@ func TestPluginManifest(t *testing.T) { if m.Name != "auth" { t.Errorf("expected name %q, got %q", "auth", m.Name) } - if len(m.ModuleTypes) != 4 { - t.Errorf("expected 4 module types, got %d", len(m.ModuleTypes)) + if len(m.ModuleTypes) != 6 { + t.Errorf("expected 6 module types, got %d", len(m.ModuleTypes)) } - if len(m.WiringHooks) != 2 { - t.Errorf("expected 2 wiring hooks, got %d", len(m.WiringHooks)) + if len(m.WiringHooks) != 3 { + t.Errorf("expected 3 wiring hooks, got %d", len(m.WiringHooks)) } } @@ -50,7 +50,7 @@ func TestModuleFactories(t *testing.T) { p := New() factories := p.ModuleFactories() - expectedTypes := []string{"auth.jwt", "auth.user-store", "auth.oauth2", "auth.m2m"} + expectedTypes := []string{"auth.jwt", "auth.user-store", "auth.oauth2", "auth.m2m", "auth.token-blacklist", "security.field-protection"} for _, typ := range expectedTypes { factory, ok := factories[typ] if !ok { @@ -83,8 +83,8 @@ func TestModuleFactoryJWTWithConfig(t *testing.T) { func TestWiringHooks(t *testing.T) { p := New() hooks := p.WiringHooks() - if len(hooks) != 2 { - t.Fatalf("expected 2 wiring hooks, got %d", len(hooks)) + if len(hooks) != 3 { + t.Fatalf("expected 3 wiring hooks, got %d", len(hooks)) } hookNames := map[string]bool{} for _, h := range hooks { @@ -93,7 +93,7 @@ func TestWiringHooks(t *testing.T) { t.Errorf("wiring hook %q function is nil", h.Name) } } - for _, expected := range []string{"auth-provider-wiring", "oauth2-jwt-wiring"} { + for _, expected := range []string{"auth-provider-wiring", "oauth2-jwt-wiring", "token-blacklist-wiring"} { if !hookNames[expected] { t.Errorf("missing wiring hook %q", expected) } diff --git a/plugins/pipelinesteps/plugin.go b/plugins/pipelinesteps/plugin.go index 52a46b69..5b9815ef 100644 --- a/plugins/pipelinesteps/plugin.go +++ b/plugins/pipelinesteps/plugin.go @@ -4,7 +4,7 @@ // validate_path_param, validate_pagination, validate_request_body, // foreach, webhook_verify, base64_decode, ui_scaffold, ui_scaffold_analyze, // dlq_send, dlq_replay, retry_with_backoff, circuit_breaker (wrapping), -// s3_upload, auth_validate. +// s3_upload, auth_validate, token_revoke, sandbox_exec. // It also provides the PipelineWorkflowHandler for composable pipelines. package pipelinesteps @@ -85,6 +85,9 @@ func New() *Plugin { "step.resilient_circuit_breaker", "step.s3_upload", "step.auth_validate", + "step.token_revoke", + "step.field_reencrypt", + "step.sandbox_exec", }, WorkflowTypes: []string{"pipeline"}, Capabilities: []plugin.CapabilityDecl{ @@ -148,7 +151,10 @@ func (p *Plugin) StepFactories() map[string]plugin.StepFactory { return p.concreteStepRegistry })), "step.s3_upload": wrapStepFactory(module.NewS3UploadStepFactory()), - "step.auth_validate": wrapStepFactory(module.NewAuthValidateStepFactory()), + "step.auth_validate": wrapStepFactory(module.NewAuthValidateStepFactory()), + "step.token_revoke": wrapStepFactory(module.NewTokenRevokeStepFactory()), + "step.field_reencrypt": wrapStepFactory(module.NewFieldReencryptStepFactory()), + "step.sandbox_exec": wrapStepFactory(module.NewSandboxExecStepFactory()), } } diff --git a/plugins/pipelinesteps/plugin_test.go b/plugins/pipelinesteps/plugin_test.go index a0e98c11..aa890748 100644 --- a/plugins/pipelinesteps/plugin_test.go +++ b/plugins/pipelinesteps/plugin_test.go @@ -62,7 +62,10 @@ func TestStepFactories(t *testing.T) { "step.resilient_circuit_breaker", "step.s3_upload", "step.auth_validate", + "step.token_revoke", "step.base64_decode", + "step.field_reencrypt", + "step.sandbox_exec", } for _, stepType := range expectedSteps { diff --git a/plugins/secrets/plugin.go b/plugins/secrets/plugin.go index 43b47633..a9920029 100644 --- a/plugins/secrets/plugin.go +++ b/plugins/secrets/plugin.go @@ -1,5 +1,6 @@ // Package secrets provides a plugin that registers secrets management modules: -// secrets.vault (HashiCorp Vault) and secrets.aws (AWS Secrets Manager). +// secrets.vault (HashiCorp Vault) and secrets.aws (AWS Secrets Manager), +// as well as the step.secret_rotate pipeline step type. package secrets import ( @@ -9,7 +10,7 @@ import ( "github.com/GoCodeAlone/workflow/plugin" ) -// Plugin registers secrets management module factories. +// Plugin registers secrets management module factories and step types. type Plugin struct { plugin.BaseEnginePlugin } @@ -30,6 +31,7 @@ func New() *Plugin { Description: "Secrets management modules (Vault, AWS Secrets Manager)", Tier: plugin.TierCore, ModuleTypes: []string{"secrets.vault", "secrets.aws"}, + StepTypes: []string{"step.secret_rotate"}, Capabilities: []plugin.CapabilityDecl{ {Name: "secrets-management", Role: "provider", Priority: 50}, }, @@ -85,3 +87,12 @@ func (p *Plugin) ModuleFactories() map[string]plugin.ModuleFactory { }, } } + +// StepFactories returns the step factories provided by this plugin. +func (p *Plugin) StepFactories() map[string]plugin.StepFactory { + return map[string]plugin.StepFactory{ + "step.secret_rotate": func(name string, cfg map[string]any, app modular.Application) (any, error) { + return module.NewSecretRotateStepFactory()(name, cfg, app) + }, + } +} diff --git a/sandbox/docker.go b/sandbox/docker.go index 9d5fda3c..496997f9 100644 --- a/sandbox/docker.go +++ b/sandbox/docker.go @@ -34,6 +34,38 @@ type SandboxConfig struct { CPULimit float64 `yaml:"cpu_limit"` Timeout time.Duration `yaml:"timeout"` NetworkMode string `yaml:"network_mode"` + + // Security hardening fields + SecurityOpts []string `yaml:"security_opts"` // e.g., ["seccomp=default.json"] + CapAdd []string `yaml:"cap_add"` // capabilities to add + CapDrop []string `yaml:"cap_drop"` // e.g., ["ALL"] + ReadOnlyRootfs bool `yaml:"read_only_rootfs"` + NoNewPrivileges bool `yaml:"no_new_privileges"` + User string `yaml:"user"` // e.g., "nobody:nogroup" + PidsLimit int64 `yaml:"pids_limit"` // max process count + Tmpfs map[string]string `yaml:"tmpfs"` // e.g., {"/tmp": "size=64m,noexec"} +} + +// DefaultSecureSandboxConfig returns a hardened SandboxConfig suitable for +// running untrusted workloads. It uses a minimal Wolfi-based image, drops all +// Linux capabilities, enables a read-only root filesystem, mounts /tmp as +// tmpfs with noexec, and disables network access. +func DefaultSecureSandboxConfig(image string) SandboxConfig { + if image == "" { + image = "cgr.dev/chainguard/wolfi-base:latest" + } + return SandboxConfig{ + Image: image, + MemoryLimit: 256 * 1024 * 1024, // 256MB + CPULimit: 0.5, + NetworkMode: "none", + CapDrop: []string{"ALL"}, + NoNewPrivileges: true, + ReadOnlyRootfs: true, + PidsLimit: 64, + Tmpfs: map[string]string{"/tmp": "size=64m,noexec"}, + Timeout: 5 * time.Minute, + } } // ExecResult holds the output from a command execution inside the sandbox. @@ -92,6 +124,9 @@ func (s *DockerSandbox) Exec(ctx context.Context, cmd []string) (*ExecResult, er Env: s.buildEnv(), WorkingDir: s.config.WorkDir, } + if s.config.User != "" { + containerConfig.User = s.config.User + } hostConfig := s.buildHostConfig() @@ -200,6 +235,9 @@ func (s *DockerSandbox) ExecInContainer(ctx context.Context, cmd []string, copyI Env: s.buildEnv(), WorkingDir: s.config.WorkDir, } + if s.config.User != "" { + containerConfig.User = s.config.User + } hostConfig := s.buildHostConfig() @@ -320,6 +358,10 @@ func (s *DockerSandbox) buildHostConfig() *container.HostConfig { // Docker uses NanoCPUs (1 CPU = 1e9 NanoCPUs) hc.NanoCPUs = int64(s.config.CPULimit * 1e9) } + if s.config.PidsLimit > 0 { + limit := s.config.PidsLimit + hc.PidsLimit = &limit + } // Mounts if len(s.config.Mounts) > 0 { @@ -340,6 +382,30 @@ func (s *DockerSandbox) buildHostConfig() *container.HostConfig { hc.NetworkMode = container.NetworkMode(s.config.NetworkMode) } + // Security options + secOpts := make([]string, len(s.config.SecurityOpts)) + copy(secOpts, s.config.SecurityOpts) + if s.config.NoNewPrivileges { + secOpts = append(secOpts, "no-new-privileges:true") + } + if len(secOpts) > 0 { + hc.SecurityOpt = secOpts + } + + // Capabilities + if len(s.config.CapAdd) > 0 { + hc.CapAdd = s.config.CapAdd + } + if len(s.config.CapDrop) > 0 { + hc.CapDrop = s.config.CapDrop + } + + // Filesystem hardening + hc.ReadonlyRootfs = s.config.ReadOnlyRootfs + if len(s.config.Tmpfs) > 0 { + hc.Tmpfs = s.config.Tmpfs + } + return hc } diff --git a/sandbox/docker_test.go b/sandbox/docker_test.go index d317e427..722f4053 100644 --- a/sandbox/docker_test.go +++ b/sandbox/docker_test.go @@ -204,3 +204,129 @@ func TestBuildHostConfig_NoLimits(t *testing.T) { t.Fatalf("expected 0 mounts, got %d", len(hc.Mounts)) } } + +func TestBuildHostConfig_SecurityFields(t *testing.T) { + pidsLimit := int64(32) + sb := &DockerSandbox{ + config: SandboxConfig{ + SecurityOpts: []string{"seccomp=default.json"}, + CapAdd: []string{"NET_BIND_SERVICE"}, + CapDrop: []string{"ALL"}, + ReadOnlyRootfs: true, + NoNewPrivileges: true, + PidsLimit: pidsLimit, + Tmpfs: map[string]string{"/tmp": "size=32m,noexec"}, + }, + } + + hc := sb.buildHostConfig() + + // SecurityOpt should contain both seccomp and no-new-privileges + foundSeccomp := false + foundNoNewPriv := false + for _, opt := range hc.SecurityOpt { + if opt == "seccomp=default.json" { + foundSeccomp = true + } + if opt == "no-new-privileges:true" { + foundNoNewPriv = true + } + } + if !foundSeccomp { + t.Fatal("expected seccomp=default.json in SecurityOpt") + } + if !foundNoNewPriv { + t.Fatal("expected no-new-privileges:true in SecurityOpt") + } + + if len(hc.CapAdd) != 1 || hc.CapAdd[0] != "NET_BIND_SERVICE" { + t.Fatalf("unexpected CapAdd: %v", hc.CapAdd) + } + if len(hc.CapDrop) != 1 || hc.CapDrop[0] != "ALL" { + t.Fatalf("unexpected CapDrop: %v", hc.CapDrop) + } + if !hc.ReadonlyRootfs { + t.Fatal("expected ReadonlyRootfs true") + } + if hc.PidsLimit == nil || *hc.PidsLimit != pidsLimit { + t.Fatalf("expected PidsLimit %d, got %v", pidsLimit, hc.PidsLimit) + } + if hc.Tmpfs["/tmp"] != "size=32m,noexec" { + t.Fatalf("unexpected Tmpfs: %v", hc.Tmpfs) + } +} + +func TestBuildHostConfig_NoNewPrivilegesOnly(t *testing.T) { + sb := &DockerSandbox{ + config: SandboxConfig{ + NoNewPrivileges: true, + }, + } + + hc := sb.buildHostConfig() + + found := false + for _, opt := range hc.SecurityOpt { + if opt == "no-new-privileges:true" { + found = true + } + } + if !found { + t.Fatal("expected no-new-privileges:true in SecurityOpt") + } +} + +func TestBuildHostConfig_PidsLimitZeroNotSet(t *testing.T) { + sb := &DockerSandbox{ + config: SandboxConfig{PidsLimit: 0}, + } + + hc := sb.buildHostConfig() + + if hc.PidsLimit != nil { + t.Fatalf("expected nil PidsLimit when PidsLimit=0, got %v", hc.PidsLimit) + } +} + +func TestDefaultSecureSandboxConfig(t *testing.T) { + cfg := DefaultSecureSandboxConfig("alpine:3.19") + + if cfg.Image != "alpine:3.19" { + t.Fatalf("unexpected image: %s", cfg.Image) + } + if cfg.MemoryLimit != 256*1024*1024 { + t.Fatalf("unexpected MemoryLimit: %d", cfg.MemoryLimit) + } + if cfg.CPULimit != 0.5 { + t.Fatalf("unexpected CPULimit: %f", cfg.CPULimit) + } + if cfg.NetworkMode != "none" { + t.Fatalf("unexpected NetworkMode: %s", cfg.NetworkMode) + } + if len(cfg.CapDrop) != 1 || cfg.CapDrop[0] != "ALL" { + t.Fatalf("unexpected CapDrop: %v", cfg.CapDrop) + } + if !cfg.NoNewPrivileges { + t.Fatal("expected NoNewPrivileges true") + } + if !cfg.ReadOnlyRootfs { + t.Fatal("expected ReadOnlyRootfs true") + } + if cfg.PidsLimit != 64 { + t.Fatalf("unexpected PidsLimit: %d", cfg.PidsLimit) + } + if cfg.Tmpfs["/tmp"] != "size=64m,noexec" { + t.Fatalf("unexpected Tmpfs: %v", cfg.Tmpfs) + } + if cfg.Timeout != 5*time.Minute { + t.Fatalf("unexpected Timeout: %s", cfg.Timeout) + } +} + +func TestDefaultSecureSandboxConfig_DefaultImage(t *testing.T) { + cfg := DefaultSecureSandboxConfig("") + + if cfg.Image != "cgr.dev/chainguard/wolfi-base:latest" { + t.Fatalf("unexpected default image: %s", cfg.Image) + } +} diff --git a/secrets/rotation_test.go b/secrets/rotation_test.go new file mode 100644 index 00000000..d746ca35 --- /dev/null +++ b/secrets/rotation_test.go @@ -0,0 +1,343 @@ +package secrets + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" +) + +// Compile-time check: VaultProvider must implement RotationProvider. +var _ RotationProvider = (*VaultProvider)(nil) + +// versionedEntry stores all versions of a secret so GetPrevious can be tested. +type versionedEntry struct { + versions []map[string]interface{} // index 0 = version 1, index 1 = version 2, ... +} + +func (e *versionedEntry) current() map[string]interface{} { + if len(e.versions) == 0 { + return nil + } + return e.versions[len(e.versions)-1] +} + +func (e *versionedEntry) currentVersion() int { + return len(e.versions) +} + +func (e *versionedEntry) getVersion(v int) map[string]interface{} { + if v < 1 || v > len(e.versions) { + return nil + } + return e.versions[v-1] +} + +func (e *versionedEntry) put(data map[string]interface{}) { + e.versions = append(e.versions, data) +} + +// newVersionedVaultServer creates an httptest server that tracks all KV v2 versions. +// It supports ?version=N query params on GET for GetVersion requests. +func newVersionedVaultServer(t *testing.T) (*httptest.Server, map[string]*versionedEntry) { + t.Helper() + store := make(map[string]*versionedEntry) + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("X-Vault-Token") + if token == "" { + http.Error(w, `{"errors":["missing client token"]}`, http.StatusForbidden) + return + } + + path := r.URL.Path + switch { + case strings.Contains(path, "/data/"): + handleVersionedData(w, r, path, store) + default: + http.Error(w, `{"errors":["not found"]}`, http.StatusNotFound) + } + }) + + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + return server, store +} + +func handleVersionedData(w http.ResponseWriter, r *http.Request, path string, store map[string]*versionedEntry) { + parts := strings.SplitN(path, "/data/", 2) + if len(parts) < 2 { + http.Error(w, `{"errors":["invalid path"]}`, http.StatusBadRequest) + return + } + key := parts[1] + + switch r.Method { + case http.MethodGet: + entry, ok := store[key] + if !ok || len(entry.versions) == 0 { + http.Error(w, `{"errors":[]}`, http.StatusNotFound) + return + } + + // Check for ?version=N query param + versionParam := r.URL.Query().Get("version") + var data map[string]interface{} + var version int + + if versionParam != "" { + v, err := strconv.Atoi(versionParam) + if err != nil || v < 1 { + http.Error(w, `{"errors":["invalid version"]}`, http.StatusBadRequest) + return + } + data = entry.getVersion(v) + if data == nil { + http.Error(w, `{"errors":[]}`, http.StatusNotFound) + return + } + version = v + } else { + data = entry.current() + version = entry.currentVersion() + } + + resp := map[string]interface{}{ + "data": map[string]interface{}{ + "data": data, + "metadata": map[string]interface{}{ + "version": version, + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + + case http.MethodPost, http.MethodPut: + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, `{"errors":["read body failed"]}`, http.StatusBadRequest) + return + } + var payload struct { + Data map[string]interface{} `json:"data"` + } + if err := json.Unmarshal(body, &payload); err != nil { + http.Error(w, `{"errors":["invalid json"]}`, http.StatusBadRequest) + return + } + + if store[key] == nil { + store[key] = &versionedEntry{} + } + store[key].put(payload.Data) + + newVersion := store[key].currentVersion() + resp := map[string]interface{}{ + "data": map[string]interface{}{ + "version": newVersion, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + + default: + http.Error(w, `{"errors":["method not allowed"]}`, http.StatusMethodNotAllowed) + } +} + +func TestVaultProvider_Rotate(t *testing.T) { + server, store := newVersionedVaultServer(t) + + p, err := NewVaultProvider(VaultConfig{ + Address: server.URL, + Token: "test-token", + MountPath: "secret", + }) + if err != nil { + t.Fatalf("NewVaultProvider: %v", err) + } + + ctx := context.Background() + + // Rotate on a fresh key — should create version 1. + newVal, err := p.Rotate(ctx, "myapp/api-key") + if err != nil { + t.Fatalf("Rotate: %v", err) + } + if newVal == "" { + t.Fatal("expected non-empty rotated value") + } + // 32 bytes hex-encoded = 64 chars + if len(newVal) != 64 { + t.Errorf("expected 64-char hex value, got %d chars: %q", len(newVal), newVal) + } + + entry, ok := store["myapp/api-key"] + if !ok { + t.Fatal("expected key to exist in store") + } + if entry.currentVersion() != 1 { + t.Errorf("expected version 1 after first rotate, got %d", entry.currentVersion()) + } + + // Rotate again — should create version 2. + newVal2, err := p.Rotate(ctx, "myapp/api-key") + if err != nil { + t.Fatalf("Rotate (second): %v", err) + } + if newVal2 == newVal { + t.Error("expected different value after second rotate") + } + if entry.currentVersion() != 2 { + t.Errorf("expected version 2 after second rotate, got %d", entry.currentVersion()) + } +} + +func TestVaultProvider_Rotate_EmptyKey(t *testing.T) { + server, _ := newVersionedVaultServer(t) + + p, err := NewVaultProvider(VaultConfig{ + Address: server.URL, + Token: "test-token", + }) + if err != nil { + t.Fatalf("NewVaultProvider: %v", err) + } + + _, err = p.Rotate(context.Background(), "") + if err != ErrInvalidKey { + t.Errorf("expected ErrInvalidKey, got %v", err) + } +} + +func TestVaultProvider_GetPrevious(t *testing.T) { + server, store := newVersionedVaultServer(t) + + // Pre-populate two versions directly in the store. + store["myapp/db-pass"] = &versionedEntry{ + versions: []map[string]interface{}{ + {"value": "old-secret"}, + {"value": "new-secret"}, + }, + } + + p, err := NewVaultProvider(VaultConfig{ + Address: server.URL, + Token: "test-token", + MountPath: "secret", + }) + if err != nil { + t.Fatalf("NewVaultProvider: %v", err) + } + + ctx := context.Background() + + // GetPrevious should return version 1 (old-secret). + prev, err := p.GetPrevious(ctx, "myapp/db-pass#value") + if err != nil { + t.Fatalf("GetPrevious: %v", err) + } + if prev != "old-secret" { + t.Errorf("expected 'old-secret', got %q", prev) + } +} + +func TestVaultProvider_GetPrevious_NoHistory(t *testing.T) { + server, store := newVersionedVaultServer(t) + + // Only one version exists — GetPrevious should return ErrNotFound. + store["myapp/only-one"] = &versionedEntry{ + versions: []map[string]interface{}{ + {"value": "only"}, + }, + } + + p, err := NewVaultProvider(VaultConfig{ + Address: server.URL, + Token: "test-token", + }) + if err != nil { + t.Fatalf("NewVaultProvider: %v", err) + } + + _, err = p.GetPrevious(context.Background(), "myapp/only-one") + if err == nil { + t.Fatal("expected error when only one version exists") + } +} + +func TestVaultProvider_GetPrevious_EmptyKey(t *testing.T) { + server, _ := newVersionedVaultServer(t) + + p, err := NewVaultProvider(VaultConfig{ + Address: server.URL, + Token: "test-token", + }) + if err != nil { + t.Fatalf("NewVaultProvider: %v", err) + } + + _, err = p.GetPrevious(context.Background(), "") + if err != ErrInvalidKey { + t.Errorf("expected ErrInvalidKey, got %v", err) + } +} + +func TestVaultProvider_GetPrevious_NotFound(t *testing.T) { + server, _ := newVersionedVaultServer(t) + + p, err := NewVaultProvider(VaultConfig{ + Address: server.URL, + Token: "test-token", + }) + if err != nil { + t.Fatalf("NewVaultProvider: %v", err) + } + + _, err = p.GetPrevious(context.Background(), "nonexistent/key") + if err == nil { + t.Fatal("expected error for missing key") + } +} + +func TestVaultProvider_Rotate_ThenGetPrevious(t *testing.T) { + server, _ := newVersionedVaultServer(t) + + p, err := NewVaultProvider(VaultConfig{ + Address: server.URL, + Token: "test-token", + MountPath: "secret", + }) + if err != nil { + t.Fatalf("NewVaultProvider: %v", err) + } + + ctx := context.Background() + + // First rotation creates version 1. + val1, err := p.Rotate(ctx, "svc/token") + if err != nil { + t.Fatalf("first Rotate: %v", err) + } + + // Second rotation creates version 2. + _, err = p.Rotate(ctx, "svc/token") + if err != nil { + t.Fatalf("second Rotate: %v", err) + } + + // GetPrevious should return the version 1 value (stored as "value" field). + prev, err := p.GetPrevious(ctx, "svc/token#value") + if err != nil { + t.Fatalf("GetPrevious: %v", err) + } + if prev != val1 { + t.Errorf("expected previous value %q, got %q", val1, prev) + } +} diff --git a/secrets/secrets.go b/secrets/secrets.go index 2b7f7eef..84532b78 100644 --- a/secrets/secrets.go +++ b/secrets/secrets.go @@ -34,6 +34,15 @@ type Provider interface { List(ctx context.Context) ([]string, error) } +// RotationProvider extends Provider with key rotation capabilities. +type RotationProvider interface { + Provider + // Rotate generates a new secret value and stores it, returning the new value. + Rotate(ctx context.Context, key string) (string, error) + // GetPrevious retrieves the previous version of a rotated secret (for grace periods). + GetPrevious(ctx context.Context, key string) (string, error) +} + // --- Environment Variable Provider --- // EnvProvider reads secrets from environment variables. diff --git a/secrets/vault_provider.go b/secrets/vault_provider.go index 1f7dbb3d..b0c3fa98 100644 --- a/secrets/vault_provider.go +++ b/secrets/vault_provider.go @@ -2,6 +2,8 @@ package secrets import ( "context" + "crypto/rand" + "encoding/hex" "encoding/json" "fmt" "strings" @@ -157,6 +159,85 @@ func (p *VaultProvider) List(ctx context.Context) ([]string, error) { return p.listRecursive(ctx, "") } +// Rotate generates a new random 32-byte hex-encoded secret and stores it at the given key, +// creating a new version in Vault KV v2. It returns the newly generated value. +func (p *VaultProvider) Rotate(ctx context.Context, key string) (string, error) { + if key == "" { + return "", ErrInvalidKey + } + + raw := make([]byte, 32) + if _, err := rand.Read(raw); err != nil { + return "", fmt.Errorf("secrets: failed to generate random secret: %w", err) + } + newValue := hex.EncodeToString(raw) + + path, _ := parseVaultKey(key) + kv := p.client.KVv2(p.config.MountPath) + if _, err := kv.Put(ctx, path, map[string]interface{}{ + "value": newValue, + }); err != nil { + return "", fmt.Errorf("secrets: vault rotate failed: %w", err) + } + + return newValue, nil +} + +// GetPrevious retrieves version N-1 of the secret at the given key from Vault KV v2. +// It reads the current version metadata to determine N, then fetches version N-1. +// Returns ErrNotFound if the secret has only one version or does not exist. +func (p *VaultProvider) GetPrevious(ctx context.Context, key string) (string, error) { + if key == "" { + return "", ErrInvalidKey + } + + path, field := parseVaultKey(key) + kv := p.client.KVv2(p.config.MountPath) + + // Get the current version to determine the previous version number. + current, err := kv.Get(ctx, path) + if err != nil { + if isVaultNotFound(err) { + return "", fmt.Errorf("%w: vault returned not found for key %q", ErrNotFound, key) + } + return "", fmt.Errorf("secrets: vault get (for previous) failed: %w", err) + } + if current == nil || current.VersionMetadata == nil { + return "", fmt.Errorf("%w: no version metadata for key %q", ErrNotFound, key) + } + + currentVersion := current.VersionMetadata.Version + if currentVersion <= 1 { + return "", fmt.Errorf("%w: no previous version exists for key %q (current version is %d)", ErrNotFound, key, currentVersion) + } + + prevVersion := currentVersion - 1 + prev, err := kv.GetVersion(ctx, path, prevVersion) + if err != nil { + if isVaultNotFound(err) { + return "", fmt.Errorf("%w: previous version %d not found for key %q", ErrNotFound, prevVersion, key) + } + return "", fmt.Errorf("secrets: vault get version %d failed: %w", prevVersion, err) + } + if prev == nil || prev.Data == nil { + return "", fmt.Errorf("%w: no data in previous version %d for key %q", ErrNotFound, prevVersion, key) + } + + if field != "" { + val, ok := prev.Data[field] + if !ok { + return "", fmt.Errorf("%w: field %q not found in previous version of key %q", ErrNotFound, field, path) + } + return fmt.Sprintf("%v", val), nil + } + + data, err := json.Marshal(prev.Data) + if err != nil { + return "", fmt.Errorf("secrets: failed to marshal vault previous version data: %w", err) + } + return string(data), nil +} + // listRecursive walks the metadata tree and collects all leaf keys. func (p *VaultProvider) listRecursive(ctx context.Context, prefix string) ([]string, error) { // Construct metadata path: {mount}/metadata or {mount}/metadata/{prefix} From a6042be91a81e3a699b0caa23b8831b7f2c02686 Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sat, 28 Feb 2026 05:10:14 -0500 Subject: [PATCH 2/6] fix: address CI lint and build failures - Fix nilerr: use separate parseErr variable in token revoke step - Fix staticcheck: remove redundant embedded field selector in SCRAM client - Remove unused alwaysErrorApp type and fmt import in test - Update example/go.sum with xdg-go/scram dependency Co-Authored-By: Claude Opus 4.6 --- example/go.mod | 3 +++ example/go.sum | 7 +++++++ module/kafka_scram.go | 2 +- module/pipeline_step_token_revoke.go | 4 ++-- module/pipeline_step_token_revoke_test.go | 10 ---------- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/example/go.mod b/example/go.mod index b02fbb93..0369bc0e 100644 --- a/example/go.mod +++ b/example/go.mod @@ -151,6 +151,9 @@ require ( github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.2.0 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/detectors/gcp v1.39.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect diff --git a/example/go.sum b/example/go.sum index a907e5fd..f97c46e6 100644 --- a/example/go.sum +++ b/example/go.sum @@ -407,6 +407,12 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs= +github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= @@ -499,6 +505,7 @@ golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= diff --git a/module/kafka_scram.go b/module/kafka_scram.go index 34c005fa..9f455cf0 100644 --- a/module/kafka_scram.go +++ b/module/kafka_scram.go @@ -22,7 +22,7 @@ type xDGSCRAMClient struct { } func (x *xDGSCRAMClient) Begin(userName, password, authzID string) error { - client, err := x.HashGeneratorFcn.NewClient(userName, password, authzID) + client, err := x.NewClient(userName, password, authzID) if err != nil { return err } diff --git a/module/pipeline_step_token_revoke.go b/module/pipeline_step_token_revoke.go index c3737943..a3777049 100644 --- a/module/pipeline_step_token_revoke.go +++ b/module/pipeline_step_token_revoke.go @@ -59,8 +59,8 @@ func (s *TokenRevokeStep) Execute(_ context.Context, pc *PipelineContext) (*Step // 3. Parse claims without signature validation (token is being revoked, not authenticated). parser := jwt.NewParser() - token, _, err := parser.ParseUnverified(tokenStr, jwt.MapClaims{}) - if err != nil { + token, _, parseErr := parser.ParseUnverified(tokenStr, jwt.MapClaims{}) + if parseErr != nil { return &StepResult{Output: map[string]any{"revoked": false, "error": "invalid token format"}}, nil } diff --git a/module/pipeline_step_token_revoke_test.go b/module/pipeline_step_token_revoke_test.go index 74650dd5..b3b79a84 100644 --- a/module/pipeline_step_token_revoke_test.go +++ b/module/pipeline_step_token_revoke_test.go @@ -2,7 +2,6 @@ package module import ( "context" - "fmt" "testing" "time" @@ -298,12 +297,3 @@ var _ TokenBlacklist = (*TokenBlacklistModule)(nil) // Compile-time check: mockBlacklist satisfies TokenBlacklist. var _ TokenBlacklist = (*mockBlacklist)(nil) -// fakeTokenRevokeBlacklistError implements TokenBlacklist but always errors on GetService. -type alwaysErrorApp struct { - *MockApplication - serviceErr error -} - -func (a *alwaysErrorApp) GetService(name string, out any) error { - return fmt.Errorf("forced error: %w", a.serviceErr) -} From d5bafa22ad13566a03675c0fb40394057ebdc416 Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sat, 28 Feb 2026 05:22:23 -0500 Subject: [PATCH 3/6] fix: field protection wiring, error handling, and Kafka integration - Require master_key (error instead of zero-key fallback) - Handle error in field-protection module factory - Add field-protection-wiring hook: connects ProtectedFieldManager to KafkaBroker - KafkaBroker.SetFieldProtection() for field-level encrypt/decrypt in JSON payloads - Add Registry.Len() method - Add TestFieldProtectionRequiresMasterKey test - Update wiring hook count in auth plugin tests Co-Authored-By: Claude Opus 4.6 --- module/field_protection.go | 11 +++----- module/field_protection_test.go | 14 ++++++++++ module/kafka_broker.go | 48 +++++++++++++++++++++++++++++---- pkg/fieldcrypt/fieldcrypt.go | 5 ++++ plugins/auth/plugin.go | 31 ++++++++++++++++++++- plugins/auth/plugin_test.go | 10 +++---- 6 files changed, 101 insertions(+), 18 deletions(-) diff --git a/module/field_protection.go b/module/field_protection.go index 8da17141..d0045d88 100644 --- a/module/field_protection.go +++ b/module/field_protection.go @@ -2,7 +2,7 @@ package module import ( "context" - "log" + "fmt" "os" "github.com/CrisisTextLine/modular" @@ -100,13 +100,10 @@ func NewFieldProtectionModule(name string, cfg map[string]any) (*FieldProtection masterKeyStr = os.Getenv("FIELD_ENCRYPTION_KEY") } - var masterKey []byte - if masterKeyStr != "" { - masterKey = []byte(masterKeyStr) - } else { - log.Println("WARNING: field-protection module using zero key — set master_key or FIELD_ENCRYPTION_KEY") - masterKey = make([]byte, 32) + if masterKeyStr == "" { + return nil, fmt.Errorf("field-protection: master_key config or FIELD_ENCRYPTION_KEY env var is required") } + masterKey := []byte(masterKeyStr) scanDepth := 10 if v, ok := cfg["scan_depth"].(int); ok && v > 0 { diff --git a/module/field_protection_test.go b/module/field_protection_test.go index 599015ff..d6ad3090 100644 --- a/module/field_protection_test.go +++ b/module/field_protection_test.go @@ -190,3 +190,17 @@ func TestFieldProtectionProvidesServices(t *testing.T) { t.Error("service instance should be *ProtectedFieldManager") } } + +func TestFieldProtectionRequiresMasterKey(t *testing.T) { + cfg := map[string]any{ + "protected_fields": []any{}, + } + + // Unset env var to ensure it's not picked up. + t.Setenv("FIELD_ENCRYPTION_KEY", "") + + _, err := NewFieldProtectionModule("fp-nokey", cfg) + if err == nil { + t.Fatal("expected error when no master_key is provided") + } +} diff --git a/module/kafka_broker.go b/module/kafka_broker.go index b20d2c7a..4cdfd92e 100644 --- a/module/kafka_broker.go +++ b/module/kafka_broker.go @@ -2,6 +2,7 @@ package module import ( "context" + "encoding/json" "fmt" "sync" @@ -38,8 +39,9 @@ type KafkaBroker struct { logger modular.Logger healthy bool healthMsg string - encryptor *FieldEncryptor - tlsCfg KafkaTLSConfig + encryptor *FieldEncryptor + fieldProtector *ProtectedFieldManager + tlsCfg KafkaTLSConfig } // NewKafkaBroker creates a new Kafka message broker. @@ -109,6 +111,15 @@ func (b *KafkaBroker) SetGroupID(groupID string) { } // SetTLSConfig sets the TLS and SASL configuration for the Kafka broker. +// SetFieldProtection sets the field-level encryption manager for this broker. +// When set, individual protected fields are encrypted/decrypted in JSON payloads +// before the legacy whole-message encryptor runs. +func (b *KafkaBroker) SetFieldProtection(mgr *ProtectedFieldManager) { + b.mu.Lock() + defer b.mu.Unlock() + b.fieldProtector = mgr +} + func (b *KafkaBroker) SetTLSConfig(cfg KafkaTLSConfig) { b.mu.Lock() defer b.mu.Unlock() @@ -298,16 +309,31 @@ func (p *kafkaProducerAdapter) SendMessage(topic string, message []byte) error { p.broker.mu.RLock() producer := p.broker.producer encryptor := p.broker.encryptor + fieldProt := p.broker.fieldProtector p.broker.mu.RUnlock() if producer == nil { return fmt.Errorf("kafka producer not initialized; call Start first") } - // Encrypt the message payload if encryption is enabled payload := message + + // Field-level encryption: encrypt individual protected fields in JSON payloads. + if fieldProt != nil { + var data map[string]any + if err := json.Unmarshal(payload, &data); err == nil { + if encErr := fieldProt.EncryptMap(context.Background(), "", data); encErr != nil { + return fmt.Errorf("failed to field-encrypt kafka message for topic %q: %w", topic, encErr) + } + if out, err := json.Marshal(data); err == nil { + payload = out + } + } + } + + // Legacy whole-message encryption (if configured via ENCRYPTION_KEY). if encryptor != nil && encryptor.Enabled() { - encrypted, err := encryptor.EncryptJSON(message) + encrypted, err := encryptor.EncryptJSON(payload) if err != nil { return fmt.Errorf("failed to encrypt kafka message for topic %q: %w", topic, err) } @@ -367,10 +393,11 @@ func (h *kafkaGroupHandler) ConsumeClaim(session sarama.ConsumerGroupSession, cl h.broker.mu.RLock() handler, ok := h.broker.handlers[msg.Topic] encryptor := h.broker.encryptor + fieldProt := h.broker.fieldProtector h.broker.mu.RUnlock() if ok { - // Decrypt message payload if encryption is enabled + // Legacy whole-message decryption first. payload := msg.Value if encryptor != nil && encryptor.Enabled() { decrypted, err := encryptor.DecryptJSON(payload) @@ -381,6 +408,17 @@ func (h *kafkaGroupHandler) ConsumeClaim(session sarama.ConsumerGroupSession, cl } payload = decrypted } + // Field-level decryption: decrypt individual protected fields. + if fieldProt != nil { + var data map[string]any + if err := json.Unmarshal(payload, &data); err == nil { + if decErr := fieldProt.DecryptMap(context.Background(), "", data); decErr == nil { + if out, err := json.Marshal(data); err == nil { + payload = out + } + } + } + } if err := handler.HandleMessage(payload); err != nil { h.broker.logger.Error("Error handling Kafka message", "topic", msg.Topic, "error", err) diff --git a/pkg/fieldcrypt/fieldcrypt.go b/pkg/fieldcrypt/fieldcrypt.go index 9265ef94..9e47ef9f 100644 --- a/pkg/fieldcrypt/fieldcrypt.go +++ b/pkg/fieldcrypt/fieldcrypt.go @@ -56,6 +56,11 @@ func (r *Registry) GetField(fieldName string) (*ProtectedField, bool) { return &f, true } +// Len returns the number of registered protected fields. +func (r *Registry) Len() int { + return len(r.fields) +} + // ProtectedFields returns all registered protected fields. func (r *Registry) ProtectedFields() []ProtectedField { out := make([]ProtectedField, 0, len(r.fields)) diff --git a/plugins/auth/plugin.go b/plugins/auth/plugin.go index 3e69b721..89f12c93 100644 --- a/plugins/auth/plugin.go +++ b/plugins/auth/plugin.go @@ -1,6 +1,7 @@ package auth import ( + "log" "time" "github.com/CrisisTextLine/modular" @@ -147,7 +148,11 @@ func (p *Plugin) ModuleFactories() map[string]plugin.ModuleFactory { return module.NewTokenBlacklistModule(name, backend, redisURL, cleanupInterval) }, "security.field-protection": func(name string, cfg map[string]any) modular.Module { - mod, _ := module.NewFieldProtectionModule(name, cfg) + mod, err := module.NewFieldProtectionModule(name, cfg) + if err != nil { + log.Printf("ERROR: field-protection module %q: %v", name, err) + return nil + } return mod }, "auth.m2m": func(name string, cfg map[string]any) modular.Module { @@ -280,6 +285,30 @@ func (p *Plugin) WiringHooks() []plugin.WiringHook { return nil }, }, + { + Name: "field-protection-wiring", + Priority: 50, + Hook: func(app modular.Application, _ *config.WorkflowConfig) error { + var mgr *module.ProtectedFieldManager + for _, svc := range app.SvcRegistry() { + if m, ok := svc.(*module.ProtectedFieldManager); ok { + mgr = m + break + } + } + if mgr == nil { + return nil + } + // Wire field protection to Kafka brokers for field-level encryption. + for _, svc := range app.SvcRegistry() { + if kb, ok := svc.(*module.KafkaBroker); ok { + kb.SetFieldProtection(mgr) + } + } + log.Printf("field-protection: wired to %d registered fields", mgr.Registry.Len()) + return nil + }, + }, } } diff --git a/plugins/auth/plugin_test.go b/plugins/auth/plugin_test.go index 456530b3..95be066d 100644 --- a/plugins/auth/plugin_test.go +++ b/plugins/auth/plugin_test.go @@ -24,8 +24,8 @@ func TestPluginManifest(t *testing.T) { if len(m.ModuleTypes) != 6 { t.Errorf("expected 6 module types, got %d", len(m.ModuleTypes)) } - if len(m.WiringHooks) != 3 { - t.Errorf("expected 3 wiring hooks, got %d", len(m.WiringHooks)) + if len(m.WiringHooks) != 4 { + t.Errorf("expected 4 wiring hooks, got %d", len(m.WiringHooks)) } } @@ -83,8 +83,8 @@ func TestModuleFactoryJWTWithConfig(t *testing.T) { func TestWiringHooks(t *testing.T) { p := New() hooks := p.WiringHooks() - if len(hooks) != 3 { - t.Fatalf("expected 3 wiring hooks, got %d", len(hooks)) + if len(hooks) != 4 { + t.Fatalf("expected 4 wiring hooks, got %d", len(hooks)) } hookNames := map[string]bool{} for _, h := range hooks { @@ -93,7 +93,7 @@ func TestWiringHooks(t *testing.T) { t.Errorf("wiring hook %q function is nil", h.Name) } } - for _, expected := range []string{"auth-provider-wiring", "oauth2-jwt-wiring", "token-blacklist-wiring"} { + for _, expected := range []string{"auth-provider-wiring", "oauth2-jwt-wiring", "token-blacklist-wiring", "field-protection-wiring"} { if !hookNames[expected] { t.Errorf("missing wiring hook %q", expected) } From 9bd419f3aed86f423ba56acd6ff29d0cd9794f1b Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sat, 28 Feb 2026 05:24:53 -0500 Subject: [PATCH 4/6] fix: implement CopyIn/CopyOut with CreateContainer/RemoveContainer lifecycle Replace stub CopyIn/CopyOut methods with real implementations that use the Docker API. Added CreateContainer() to create a container and store its ID, and RemoveContainer() to clean up. CopyIn delegates to the existing copyToContainer helper; CopyOut uses client.CopyFromContainer. Co-Authored-By: Claude Opus 4.6 --- sandbox/docker.go | 63 ++++++++++++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/sandbox/docker.go b/sandbox/docker.go index 496997f9..6f2a4ad5 100644 --- a/sandbox/docker.go +++ b/sandbox/docker.go @@ -77,8 +77,9 @@ type ExecResult struct { // DockerSandbox wraps the Docker Engine SDK to execute commands in isolated containers. type DockerSandbox struct { - client *client.Client - config SandboxConfig + client *client.Client + config SandboxConfig + containerID string // set by CreateContainer, used by CopyIn/CopyOut/RemoveContainer } // NewDockerSandbox creates a new DockerSandbox with the given configuration. @@ -180,39 +181,51 @@ func (s *DockerSandbox) Exec(ctx context.Context, cmd []string) (*ExecResult, er }, nil } -// CopyIn copies a file from the host into a running or created container. +// CopyIn copies a file from the host into the active container. +// Call CreateContainer first to set the active container ID. func (s *DockerSandbox) CopyIn(ctx context.Context, srcPath, destPath string) error { - f, err := os.Open(srcPath) - if err != nil { - return fmt.Errorf("sandbox: failed to open source file: %w", err) + if s.containerID == "" { + return fmt.Errorf("sandbox: CopyIn requires an active container; call CreateContainer first") } - defer f.Close() + return s.copyToContainer(ctx, s.containerID, srcPath, destPath) +} - stat, err := f.Stat() +// CopyOut copies a file out of the active container. Returns a ReadCloser with the file contents. +// Call CreateContainer first to set the active container ID. +func (s *DockerSandbox) CopyOut(ctx context.Context, srcPath string) (io.ReadCloser, error) { + if s.containerID == "" { + return nil, fmt.Errorf("sandbox: CopyOut requires an active container; call CreateContainer first") + } + reader, _, err := s.client.CopyFromContainer(ctx, s.containerID, srcPath) if err != nil { - return fmt.Errorf("sandbox: failed to stat source file: %w", err) + return nil, fmt.Errorf("sandbox: CopyOut %q: %w", srcPath, err) } + return reader, nil +} - // Create a tar archive containing the file - tarReader, err := createTarFromFile(f, stat) +// CreateContainer creates and starts a container, storing its ID for use with CopyIn/CopyOut. +// Call RemoveContainer when done to clean up. +func (s *DockerSandbox) CreateContainer(ctx context.Context, cmd []string) error { + hostConfig := s.buildHostConfig() + resp, err := s.client.ContainerCreate(ctx, &container.Config{ + Image: s.config.Image, + Cmd: cmd, + }, hostConfig, nil, nil, "") if err != nil { - return fmt.Errorf("sandbox: failed to create tar archive: %w", err) + return fmt.Errorf("sandbox: create container: %w", err) } - - // We need a container to copy into. This method is intended to be used - // with a container that has been created but the caller manages its lifecycle. - // For the typical use case, the Exec method handles the full lifecycle. - // This is a lower-level utility for advanced usage. - _ = destPath - _ = tarReader - return fmt.Errorf("sandbox: CopyIn requires an active container ID; use Exec for typical workflows") + s.containerID = resp.ID + return nil } -// CopyOut copies a file out of a container. Returns a ReadCloser with the file contents. -func (s *DockerSandbox) CopyOut(ctx context.Context, srcPath string) (io.ReadCloser, error) { - // Similar to CopyIn, this requires an active container. - _ = srcPath - return nil, fmt.Errorf("sandbox: CopyOut requires an active container ID; use Exec for typical workflows") +// RemoveContainer stops and removes the active container. +func (s *DockerSandbox) RemoveContainer(ctx context.Context) error { + if s.containerID == "" { + return nil + } + id := s.containerID + s.containerID = "" + return s.client.ContainerRemove(ctx, id, container.RemoveOptions{Force: true}) } // ExecInContainer creates a container, copies files in, runs the command, and allows file extraction. From ba249f752bf5147980c4a2790cb620ce00fb6e58 Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sat, 28 Feb 2026 05:31:51 -0500 Subject: [PATCH 5/6] fix: return errors from token revoke step for invalid tokens Return parseErr when JWT parsing fails and fmt.Errorf for invalid claims type, instead of swallowing the error. Fixes nilerr lint violation. Co-Authored-By: Claude Opus 4.6 --- module/pipeline_step_token_revoke.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/module/pipeline_step_token_revoke.go b/module/pipeline_step_token_revoke.go index a3777049..2ff7504c 100644 --- a/module/pipeline_step_token_revoke.go +++ b/module/pipeline_step_token_revoke.go @@ -61,12 +61,12 @@ func (s *TokenRevokeStep) Execute(_ context.Context, pc *PipelineContext) (*Step parser := jwt.NewParser() token, _, parseErr := parser.ParseUnverified(tokenStr, jwt.MapClaims{}) if parseErr != nil { - return &StepResult{Output: map[string]any{"revoked": false, "error": "invalid token format"}}, nil + return &StepResult{Output: map[string]any{"revoked": false, "error": "invalid token format"}}, parseErr } claims, ok := token.Claims.(jwt.MapClaims) if !ok { - return &StepResult{Output: map[string]any{"revoked": false, "error": "invalid claims"}}, nil + return &StepResult{Output: map[string]any{"revoked": false, "error": "invalid claims"}}, fmt.Errorf("invalid JWT claims type") } jti, _ := claims["jti"].(string) From bac3b4f1ac02a877a1ccf2f901b91f4d6c9d1144 Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sat, 28 Feb 2026 05:40:42 -0500 Subject: [PATCH 6/6] fix: add field-protection-wiring to manifest and set env in test The manifest WiringHooks list was missing the field-protection-wiring entry, and TestModuleFactories needed FIELD_ENCRYPTION_KEY set for the security.field-protection factory to succeed. Co-Authored-By: Claude Opus 4.6 --- plugins/auth/plugin.go | 2 +- plugins/auth/plugin_test.go | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/plugins/auth/plugin.go b/plugins/auth/plugin.go index 89f12c93..6b87fb18 100644 --- a/plugins/auth/plugin.go +++ b/plugins/auth/plugin.go @@ -57,7 +57,7 @@ func New() *Plugin { {Name: "authentication", Role: "provider", Priority: 10}, {Name: "user-management", Role: "provider", Priority: 10}, }, - WiringHooks: []string{"auth-provider-wiring", "oauth2-jwt-wiring", "token-blacklist-wiring"}, + WiringHooks: []string{"auth-provider-wiring", "oauth2-jwt-wiring", "token-blacklist-wiring", "field-protection-wiring"}, }, }, } diff --git a/plugins/auth/plugin_test.go b/plugins/auth/plugin_test.go index 95be066d..a031daf8 100644 --- a/plugins/auth/plugin_test.go +++ b/plugins/auth/plugin_test.go @@ -47,6 +47,9 @@ func TestPluginCapabilities(t *testing.T) { } func TestModuleFactories(t *testing.T) { + // field-protection factory requires a master key via env var + t.Setenv("FIELD_ENCRYPTION_KEY", "test-master-key-32-bytes-long!!") + p := New() factories := p.ModuleFactories()