From 9ac133779e356affd3a4258955195c9c831d728d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 03:45:58 +0000 Subject: [PATCH 1/5] Initial plan From 67eae7b6813e0430f5bdfedbb1106d407447275d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 04:05:19 +0000 Subject: [PATCH 2/5] feat: add PostgreSQL LIST partitioning support for multi-tenant data isolation - Add `database.partitioned` module type with PartitionedDatabase, PartitionKeyProvider and PartitionManager interfaces for managing PostgreSQL LIST partitions - Add `tenantKey` config to `step.db_query` and `step.db_exec` for automatic tenant scoping via PartitionKeyProvider - Add `step.db_create_partition` step for idempotent runtime partition creation - Add `appendTenantFilter` helper to sql_placeholders.go - Register all new types in storage plugin, pipelinesteps plugin, schema registry, coreModuleTypes/coreStepTypes, and type_registry - Add tests for new functionality Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --- cmd/wfctl/type_registry.go | 15 +- module/cache_redis.go | 2 +- module/database_partitioned.go | 194 ++++++++++++++++++ module/database_partitioned_test.go | 216 ++++++++++++++++++++ module/http_server.go | 8 +- module/kafka_broker.go | 28 +-- module/pipeline_step_db_create_partition.go | 74 +++++++ module/pipeline_step_db_exec.go | 22 +- module/pipeline_step_db_query.go | 51 +++-- module/pipeline_step_db_query_test.go | 2 +- module/pipeline_step_db_tenant_test.go | 185 +++++++++++++++++ module/pipeline_step_sandbox_exec.go | 16 +- module/pipeline_step_token_revoke_test.go | 1 - module/sql_placeholders.go | 13 +- plugins/pipelinesteps/plugin.go | 18 +- plugins/pipelinesteps/plugin_test.go | 1 + plugins/storage/plugin.go | 44 ++++ plugins/storage/plugin_test.go | 12 +- schema/module_schema.go | 33 +++ schema/schema.go | 4 +- schema/snippets_export.go | 18 +- 21 files changed, 885 insertions(+), 72 deletions(-) create mode 100644 module/database_partitioned.go create mode 100644 module/database_partitioned_test.go create mode 100644 module/pipeline_step_db_create_partition.go create mode 100644 module/pipeline_step_db_tenant_test.go diff --git a/cmd/wfctl/type_registry.go b/cmd/wfctl/type_registry.go index 8bbae253..43f7930b 100644 --- a/cmd/wfctl/type_registry.go +++ b/cmd/wfctl/type_registry.go @@ -55,6 +55,12 @@ func KnownModuleTypes() map[string]ModuleTypeInfo { Stateful: true, ConfigKeys: []string{"driver", "dsn", "maxOpenConns", "maxIdleConns"}, }, + "database.partitioned": { + Type: "database.partitioned", + Plugin: "storage", + Stateful: true, + ConfigKeys: []string{"driver", "dsn", "partitionKey", "tables", "maxOpenConns", "maxIdleConns"}, + }, "persistence.store": { Type: "persistence.store", Plugin: "storage", @@ -584,18 +590,23 @@ func KnownStepTypes() map[string]StepTypeInfo { "step.db_query": { Type: "step.db_query", Plugin: "pipelinesteps", - ConfigKeys: []string{"database", "query", "params"}, + ConfigKeys: []string{"database", "query", "params", "tenantKey"}, }, "step.db_exec": { Type: "step.db_exec", Plugin: "pipelinesteps", - ConfigKeys: []string{"database", "query", "params"}, + ConfigKeys: []string{"database", "query", "params", "tenantKey"}, }, "step.db_query_cached": { Type: "step.db_query_cached", Plugin: "pipelinesteps", ConfigKeys: []string{"database", "query", "params", "cache_key", "cache_ttl", "scan_fields"}, }, + "step.db_create_partition": { + Type: "step.db_create_partition", + Plugin: "pipelinesteps", + ConfigKeys: []string{"database", "tenantKey"}, + }, "step.json_response": { Type: "step.json_response", Plugin: "pipelinesteps", diff --git a/module/cache_redis.go b/module/cache_redis.go index 8d750d7e..31281490 100644 --- a/module/cache_redis.go +++ b/module/cache_redis.go @@ -31,7 +31,7 @@ type RedisClient interface { // RedisCacheConfig holds configuration for the cache.redis module. type RedisCacheConfig struct { Address string - Password string //nolint:gosec // G117: config struct field, not a hardcoded secret + Password string //nolint:gosec // G117: config struct field, not a hardcoded secret DB int Prefix string DefaultTTL time.Duration diff --git a/module/database_partitioned.go b/module/database_partitioned.go new file mode 100644 index 00000000..1e0b1dc6 --- /dev/null +++ b/module/database_partitioned.go @@ -0,0 +1,194 @@ +package module + +import ( + "context" + "database/sql" + "fmt" + "regexp" + "strings" + "sync" + + "github.com/CrisisTextLine/modular" +) + +// validPartitionValue matches safe LIST partition values (alphanumeric, hyphens, underscores, dots). +var validPartitionValue = regexp.MustCompile(`^[a-zA-Z0-9_.\-]+$`) + +// PartitionKeyProvider is optionally implemented by database modules that support +// LIST partitioning. Steps can use PartitionKey() to determine the column name +// for automatic tenant scoping. +type PartitionKeyProvider interface { + DBProvider + PartitionKey() string +} + +// PartitionManager is optionally implemented by database modules that support +// runtime creation of LIST partitions. The EnsurePartition method is idempotent — +// if the partition already exists the call succeeds without error. +type PartitionManager interface { + PartitionKeyProvider + EnsurePartition(ctx context.Context, tenantValue string) error +} + +// PartitionedDatabaseConfig holds configuration for the database.partitioned module. +type PartitionedDatabaseConfig struct { + Driver string `json:"driver" yaml:"driver"` + DSN string `json:"dsn" yaml:"dsn"` + MaxOpenConns int `json:"maxOpenConns" yaml:"maxOpenConns"` + MaxIdleConns int `json:"maxIdleConns" yaml:"maxIdleConns"` + PartitionKey string `json:"partitionKey" yaml:"partitionKey"` + Tables []string `json:"tables" yaml:"tables"` +} + +// PartitionedDatabase wraps WorkflowDatabase and adds PostgreSQL LIST partition +// management. It satisfies DBProvider, DBDriverProvider, PartitionKeyProvider, +// and PartitionManager. +type PartitionedDatabase struct { + name string + config PartitionedDatabaseConfig + base *WorkflowDatabase + mu sync.RWMutex +} + +// NewPartitionedDatabase creates a new PartitionedDatabase module. +func NewPartitionedDatabase(name string, cfg PartitionedDatabaseConfig) *PartitionedDatabase { + dbConfig := DatabaseConfig{ + Driver: cfg.Driver, + DSN: cfg.DSN, + MaxOpenConns: cfg.MaxOpenConns, + MaxIdleConns: cfg.MaxIdleConns, + } + return &PartitionedDatabase{ + name: name, + config: cfg, + base: NewWorkflowDatabase(name+"._base", dbConfig), + } +} + +// Name returns the module name. +func (p *PartitionedDatabase) Name() string { return p.name } + +// Init registers this module as a service. +func (p *PartitionedDatabase) Init(app modular.Application) error { + return app.RegisterService(p.name, p) +} + +// ProvidesServices declares the service this module provides. +func (p *PartitionedDatabase) ProvidesServices() []modular.ServiceProvider { + return []modular.ServiceProvider{ + { + Name: p.name, + Description: "Partitioned Database: " + p.name, + Instance: p, + }, + } +} + +// RequiresServices returns no dependencies. +func (p *PartitionedDatabase) RequiresServices() []modular.ServiceDependency { + return nil +} + +// Start opens the database connection during application startup. +func (p *PartitionedDatabase) Start(ctx context.Context) error { + return p.base.Start(ctx) +} + +// Stop closes the database connection during application shutdown. +func (p *PartitionedDatabase) Stop(ctx context.Context) error { + return p.base.Stop(ctx) +} + +// DB returns the underlying *sql.DB (satisfies DBProvider). +func (p *PartitionedDatabase) DB() *sql.DB { + return p.base.DB() +} + +// DriverName returns the configured database driver (satisfies DBDriverProvider). +func (p *PartitionedDatabase) DriverName() string { + return p.config.Driver +} + +// PartitionKey returns the column name used for LIST partitioning (satisfies PartitionKeyProvider). +func (p *PartitionedDatabase) PartitionKey() string { + return p.config.PartitionKey +} + +// Tables returns the list of tables managed by this partitioned database. +func (p *PartitionedDatabase) Tables() []string { + result := make([]string, len(p.config.Tables)) + copy(result, p.config.Tables) + return result +} + +// EnsurePartition creates a LIST partition for the given tenant value on all +// configured tables. The operation is idempotent — IF NOT EXISTS prevents errors +// when the partition already exists. +// +// Only PostgreSQL (pgx, pgx/v5, postgres) is supported. The method validates +// the tenant value and table/column names to prevent SQL injection. +func (p *PartitionedDatabase) EnsurePartition(ctx context.Context, tenantValue string) error { + if !validPartitionValue.MatchString(tenantValue) { + return fmt.Errorf("partitioned database %q: invalid tenant value %q (must match [a-zA-Z0-9_.\\-]+)", p.name, tenantValue) + } + + if !isSupportedPartitionDriver(p.config.Driver) { + return fmt.Errorf("partitioned database %q: driver %q does not support LIST partitioning (use pgx, pgx/v5, or postgres)", p.name, p.config.Driver) + } + + if err := validateIdentifier(p.config.PartitionKey); err != nil { + return fmt.Errorf("partitioned database %q: invalid partition_key: %w", p.name, err) + } + + db := p.base.DB() + if db == nil { + return fmt.Errorf("partitioned database %q: database connection is nil", p.name) + } + + p.mu.Lock() + defer p.mu.Unlock() + + for _, table := range p.config.Tables { + if err := validateIdentifier(table); err != nil { + return fmt.Errorf("partitioned database %q: invalid table name: %w", p.name, err) + } + + // Sanitize the partition suffix: replace hyphens and dots with underscores. + partitionSuffix := sanitizePartitionSuffix(tenantValue) + partitionName := table + "_" + partitionSuffix + + // Use IF NOT EXISTS to make this idempotent. + // The tenant value is embedded as a quoted literal (single-quoted). + // We have already validated tenantValue against validPartitionValue so + // it cannot contain single-quote characters. + sql := fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS %s PARTITION OF %s FOR VALUES IN ('%s')", + partitionName, + table, + strings.ReplaceAll(tenantValue, "'", ""), + ) + + if _, err := db.ExecContext(ctx, sql); err != nil { + return fmt.Errorf("partitioned database %q: failed to create partition %q for table %q: %w", + p.name, partitionName, table, err) + } + } + + return nil +} + +// isSupportedPartitionDriver returns true for PostgreSQL-compatible drivers. +func isSupportedPartitionDriver(driver string) bool { + switch driver { + case "pgx", "pgx/v5", "postgres", "postgresql": + return true + } + return false +} + +// sanitizePartitionSuffix converts a tenant value to a safe PostgreSQL identifier suffix. +// Hyphens and dots are replaced with underscores. +func sanitizePartitionSuffix(tenantValue string) string { + r := strings.NewReplacer("-", "_", ".", "_") + return r.Replace(tenantValue) +} diff --git a/module/database_partitioned_test.go b/module/database_partitioned_test.go new file mode 100644 index 00000000..252b1904 --- /dev/null +++ b/module/database_partitioned_test.go @@ -0,0 +1,216 @@ +package module + +import ( + "context" + "database/sql" + "testing" +) + +func TestPartitionedDatabase_PartitionKey(t *testing.T) { + cfg := PartitionedDatabaseConfig{ + Driver: "pgx", + DSN: "postgres://localhost/test", + PartitionKey: "tenant_id", + Tables: []string{"forms", "submissions"}, + } + pd := NewPartitionedDatabase("db", cfg) + + if pd.PartitionKey() != "tenant_id" { + t.Errorf("expected tenant_id, got %q", pd.PartitionKey()) + } + if pd.Name() != "db" { + t.Errorf("expected name 'db', got %q", pd.Name()) + } + tables := pd.Tables() + if len(tables) != 2 { + t.Errorf("expected 2 tables, got %d", len(tables)) + } +} + +func TestPartitionedDatabase_EnsurePartition_InvalidDriver(t *testing.T) { + cfg := PartitionedDatabaseConfig{ + Driver: "sqlite3", + PartitionKey: "tenant_id", + Tables: []string{"forms"}, + } + pd := NewPartitionedDatabase("db", cfg) + + err := pd.EnsurePartition(context.Background(), "org-alpha") + if err == nil { + t.Fatal("expected error for non-postgres driver") + } +} + +func TestPartitionedDatabase_EnsurePartition_InvalidTenantValue(t *testing.T) { + cfg := PartitionedDatabaseConfig{ + Driver: "pgx", + PartitionKey: "tenant_id", + Tables: []string{"forms"}, + } + pd := NewPartitionedDatabase("db", cfg) + + err := pd.EnsurePartition(context.Background(), "org'; DROP TABLE forms; --") + if err == nil { + t.Fatal("expected error for invalid tenant value") + } +} + +func TestPartitionedDatabase_EnsurePartition_InvalidPartitionKey(t *testing.T) { + cfg := PartitionedDatabaseConfig{ + Driver: "pgx", + PartitionKey: "bad column name!", + Tables: []string{"forms"}, + } + pd := NewPartitionedDatabase("db", cfg) + + err := pd.EnsurePartition(context.Background(), "org-alpha") + if err == nil { + t.Fatal("expected error for invalid partition key") + } +} + +func TestSanitizePartitionSuffix(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"org-alpha", "org_alpha"}, + {"org.beta", "org_beta"}, + {"tenant_1", "tenant_1"}, + {"org-my.tenant", "org_my_tenant"}, + } + + for _, tc := range tests { + got := sanitizePartitionSuffix(tc.input) + if got != tc.expected { + t.Errorf("sanitizePartitionSuffix(%q) = %q, want %q", tc.input, got, tc.expected) + } + } +} + +func TestIsSupportedPartitionDriver(t *testing.T) { + supported := []string{"pgx", "pgx/v5", "postgres", "postgresql"} + for _, d := range supported { + if !isSupportedPartitionDriver(d) { + t.Errorf("expected %q to be supported", d) + } + } + + unsupported := []string{"sqlite3", "sqlite", "mysql", ""} + for _, d := range unsupported { + if isSupportedPartitionDriver(d) { + t.Errorf("expected %q to be unsupported", d) + } + } +} + +// testPartitionManager is a mock PartitionManager for testing step.db_create_partition. +type testPartitionManager struct { + partitionKey string + partitions map[string]bool +} + +func (p *testPartitionManager) DB() *sql.DB { return nil } +func (p *testPartitionManager) PartitionKey() string { return p.partitionKey } +func (p *testPartitionManager) EnsurePartition(_ context.Context, tenantValue string) error { + if p.partitions == nil { + p.partitions = make(map[string]bool) + } + p.partitions[tenantValue] = true + return nil +} + +func TestDBCreatePartitionStep_Execute(t *testing.T) { + mgr := &testPartitionManager{partitionKey: "tenant_id"} + app := NewMockApplication() + app.Services["part-db"] = mgr + + factory := NewDBCreatePartitionStepFactory() + step, err := factory("create-part", map[string]any{ + "database": "part-db", + "tenantKey": "steps.body.tenant_id", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("body", map[string]any{"tenant_id": "new-org"}) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + if result.Output["tenant"] != "new-org" { + t.Errorf("expected tenant='new-org', got %v", result.Output["tenant"]) + } + if !mgr.partitions["new-org"] { + t.Error("expected partition to be created for new-org") + } +} + +func TestDBCreatePartitionStep_MissingDatabase(t *testing.T) { + factory := NewDBCreatePartitionStepFactory() + _, err := factory("create-part", map[string]any{ + "tenantKey": "body.tenant_id", + }, nil) + if err == nil { + t.Fatal("expected error for missing database") + } +} + +func TestDBCreatePartitionStep_MissingTenantKey(t *testing.T) { + factory := NewDBCreatePartitionStepFactory() + _, err := factory("create-part", map[string]any{ + "database": "part-db", + }, nil) + if err == nil { + t.Fatal("expected error for missing tenantKey") + } +} + +func TestDBCreatePartitionStep_NotPartitionManager(t *testing.T) { + db := setupTenantTestDB(t) + app := mockAppWithDB("plain-db", db) + + factory := NewDBCreatePartitionStepFactory() + step, err := factory("create-part", map[string]any{ + "database": "plain-db", + "tenantKey": "body.tenant_id", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("body", map[string]any{"tenant_id": "new-org"}) + + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error when service does not implement PartitionManager") + } +} + +func TestDBCreatePartitionStep_NilTenantValue(t *testing.T) { + mgr := &testPartitionManager{partitionKey: "tenant_id"} + app := NewMockApplication() + app.Services["part-db"] = mgr + + factory := NewDBCreatePartitionStepFactory() + step, err := factory("create-part", map[string]any{ + "database": "part-db", + "tenantKey": "steps.body.tenant_id", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + // No tenant_id in context + pc := NewPipelineContext(nil, nil) + + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error when tenant value is nil") + } +} diff --git a/module/http_server.go b/module/http_server.go index f562becc..c58b045a 100644 --- a/module/http_server.go +++ b/module/http_server.go @@ -15,11 +15,11 @@ import ( // HTTPServerTLSConfig holds TLS configuration for the HTTP server. type HTTPServerTLSConfig struct { - Mode string `yaml:"mode" json:"mode"` // manual | autocert | disabled - Manual tlsutil.TLSConfig `yaml:"manual" json:"manual"` + Mode string `yaml:"mode" json:"mode"` // manual | autocert | disabled + Manual tlsutil.TLSConfig `yaml:"manual" json:"manual"` Autocert tlsutil.AutocertConfig `yaml:"autocert" json:"autocert"` - ClientCAFile string `yaml:"client_ca_file" json:"client_ca_file"` - ClientAuth string `yaml:"client_auth" json:"client_auth"` // require | request | none + ClientCAFile string `yaml:"client_ca_file" json:"client_ca_file"` + ClientAuth string `yaml:"client_auth" json:"client_auth"` // require | request | none } // StandardHTTPServer implements the HTTPServer interface and modular.Module interfaces diff --git a/module/kafka_broker.go b/module/kafka_broker.go index 4cdfd92e..f81ba345 100644 --- a/module/kafka_broker.go +++ b/module/kafka_broker.go @@ -7,8 +7,8 @@ import ( "sync" "github.com/CrisisTextLine/modular" - "github.com/IBM/sarama" "github.com/GoCodeAlone/workflow/pkg/tlsutil" + "github.com/IBM/sarama" ) // KafkaSASLConfig holds SASL authentication configuration for Kafka. @@ -26,19 +26,19 @@ type KafkaTLSConfig struct { // KafkaBroker implements the MessageBroker interface using Apache Kafka via Sarama. type KafkaBroker struct { - name string - brokers []string - groupID string - producer sarama.SyncProducer - consumerGroup sarama.ConsumerGroup - handlers map[string]MessageHandler - mu sync.RWMutex - kafkaProducer *kafkaProducerAdapter - kafkaConsumer *kafkaConsumerAdapter - cancelFunc context.CancelFunc - logger modular.Logger - healthy bool - healthMsg string + name string + brokers []string + groupID string + producer sarama.SyncProducer + consumerGroup sarama.ConsumerGroup + handlers map[string]MessageHandler + mu sync.RWMutex + kafkaProducer *kafkaProducerAdapter + kafkaConsumer *kafkaConsumerAdapter + cancelFunc context.CancelFunc + logger modular.Logger + healthy bool + healthMsg string encryptor *FieldEncryptor fieldProtector *ProtectedFieldManager tlsCfg KafkaTLSConfig diff --git a/module/pipeline_step_db_create_partition.go b/module/pipeline_step_db_create_partition.go new file mode 100644 index 00000000..c76f193a --- /dev/null +++ b/module/pipeline_step_db_create_partition.go @@ -0,0 +1,74 @@ +package module + +import ( + "context" + "fmt" + + "github.com/CrisisTextLine/modular" +) + +// DBCreatePartitionStep creates a PostgreSQL LIST partition for a given tenant value +// on all tables managed by a database.partitioned module. +type DBCreatePartitionStep struct { + name string + database string + tenantKey string // dot-path in PipelineContext to resolve the tenant value + app modular.Application + tmpl *TemplateEngine +} + +// NewDBCreatePartitionStepFactory returns a StepFactory for DBCreatePartitionStep. +func NewDBCreatePartitionStepFactory() StepFactory { + return func(name string, config map[string]any, app modular.Application) (PipelineStep, error) { + database, _ := config["database"].(string) + if database == "" { + return nil, fmt.Errorf("db_create_partition step %q: 'database' is required", name) + } + + tenantKey, _ := config["tenantKey"].(string) + if tenantKey == "" { + return nil, fmt.Errorf("db_create_partition step %q: 'tenantKey' is required", name) + } + + return &DBCreatePartitionStep{ + name: name, + database: database, + tenantKey: tenantKey, + app: app, + tmpl: NewTemplateEngine(), + }, nil + } +} + +func (s *DBCreatePartitionStep) Name() string { return s.name } + +func (s *DBCreatePartitionStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) { + if s.app == nil { + return nil, fmt.Errorf("db_create_partition step %q: no application context", s.name) + } + + svc, ok := s.app.SvcRegistry()[s.database] + if !ok { + return nil, fmt.Errorf("db_create_partition step %q: database service %q not found", s.name, s.database) + } + + mgr, ok := svc.(PartitionManager) + if !ok { + return nil, fmt.Errorf("db_create_partition step %q: service %q does not implement PartitionManager (use database.partitioned)", s.name, s.database) + } + + tenantVal := resolveBodyFrom(s.tenantKey, pc) + if tenantVal == nil { + return nil, fmt.Errorf("db_create_partition step %q: tenantKey %q resolved to nil in pipeline context", s.name, s.tenantKey) + } + tenantStr := fmt.Sprintf("%v", tenantVal) + + if err := mgr.EnsurePartition(ctx, tenantStr); err != nil { + return nil, fmt.Errorf("db_create_partition step %q: %w", s.name, err) + } + + return &StepResult{Output: map[string]any{ + "tenant": tenantStr, + "partition": "created", + }}, nil +} diff --git a/module/pipeline_step_db_exec.go b/module/pipeline_step_db_exec.go index 588564a4..b9a2d85d 100644 --- a/module/pipeline_step_db_exec.go +++ b/module/pipeline_step_db_exec.go @@ -15,6 +15,7 @@ type DBExecStep struct { query string params []string ignoreError bool + tenantKey string // dot-path to resolve tenant value for automatic scoping app modular.Application tmpl *TemplateEngine } @@ -49,6 +50,7 @@ func NewDBExecStepFactory() StepFactory { } ignoreError, _ := config["ignore_error"].(bool) + tenantKey, _ := config["tenantKey"].(string) return &DBExecStep{ name: name, @@ -56,6 +58,7 @@ func NewDBExecStepFactory() StepFactory { query: query, params: params, ignoreError: ignoreError, + tenantKey: tenantKey, app: app, tmpl: NewTemplateEngine(), }, nil @@ -100,9 +103,26 @@ func (s *DBExecStep) Execute(_ context.Context, pc *PipelineContext) (*StepResul resolvedParams[i] = resolved } + // Apply automatic tenant scoping when tenantKey is configured. + query := s.query + if s.tenantKey != "" { + pkp, ok := svc.(PartitionKeyProvider) + if !ok { + return nil, fmt.Errorf("db_exec step %q: tenantKey requires database %q to implement PartitionKeyProvider (use database.partitioned)", s.name, s.database) + } + tenantVal := resolveBodyFrom(s.tenantKey, pc) + if tenantVal == nil { + return nil, fmt.Errorf("db_exec step %q: tenantKey %q resolved to nil in pipeline context", s.name, s.tenantKey) + } + tenantStr := fmt.Sprintf("%v", tenantVal) + nextParam := len(resolvedParams) + 1 + query = appendTenantFilter(query, pkp.PartitionKey(), nextParam) + resolvedParams = append(resolvedParams, tenantStr) + } + // Normalize SQL placeholders: users write $1,$2,$3 (PostgreSQL style), // engine converts to ? for SQLite automatically. - query := normalizePlaceholders(s.query, driver) + query = normalizePlaceholders(query, driver) // Execute statement result, err := db.Exec(query, resolvedParams...) diff --git a/module/pipeline_step_db_query.go b/module/pipeline_step_db_query.go index 43ed813a..d54569a8 100644 --- a/module/pipeline_step_db_query.go +++ b/module/pipeline_step_db_query.go @@ -25,13 +25,14 @@ type DBDriverProvider interface { // DBQueryStep executes a parameterized SQL SELECT against a named database service. type DBQueryStep struct { - name string - database string - query string - params []string - mode string // "list" or "single" - app modular.Application - tmpl *TemplateEngine + name string + database string + query string + params []string + mode string // "list" or "single" + tenantKey string // dot-path to resolve tenant value for automatic scoping + app modular.Application + tmpl *TemplateEngine } // NewDBQueryStepFactory returns a StepFactory that creates DBQueryStep instances. @@ -71,14 +72,17 @@ func NewDBQueryStepFactory() StepFactory { return nil, fmt.Errorf("db_query step %q: mode must be 'list' or 'single', got %q", name, mode) } + tenantKey, _ := config["tenantKey"].(string) + return &DBQueryStep{ - name: name, - database: database, - query: query, - params: params, - mode: mode, - app: app, - tmpl: NewTemplateEngine(), + name: name, + database: database, + query: query, + params: params, + mode: mode, + tenantKey: tenantKey, + app: app, + tmpl: NewTemplateEngine(), }, nil } } @@ -122,9 +126,26 @@ func (s *DBQueryStep) Execute(_ context.Context, pc *PipelineContext) (*StepResu resolvedParams[i] = resolved } + // Apply automatic tenant scoping when tenantKey is configured. + query := s.query + if s.tenantKey != "" { + pkp, ok := svc.(PartitionKeyProvider) + if !ok { + return nil, fmt.Errorf("db_query step %q: tenantKey requires database %q to implement PartitionKeyProvider (use database.partitioned)", s.name, s.database) + } + tenantVal := resolveBodyFrom(s.tenantKey, pc) + if tenantVal == nil { + return nil, fmt.Errorf("db_query step %q: tenantKey %q resolved to nil in pipeline context", s.name, s.tenantKey) + } + tenantStr := fmt.Sprintf("%v", tenantVal) + nextParam := len(resolvedParams) + 1 + query = appendTenantFilter(query, pkp.PartitionKey(), nextParam) + resolvedParams = append(resolvedParams, tenantStr) + } + // Normalize SQL placeholders: users write $1,$2,$3 (PostgreSQL style), // engine converts to ? for SQLite automatically. - query := normalizePlaceholders(s.query, driver) + query = normalizePlaceholders(query, driver) // Execute query rows, err := db.Query(query, resolvedParams...) diff --git a/module/pipeline_step_db_query_test.go b/module/pipeline_step_db_query_test.go index acef121d..c3d19825 100644 --- a/module/pipeline_step_db_query_test.go +++ b/module/pipeline_step_db_query_test.go @@ -21,7 +21,7 @@ type testDBDriverProvider struct { driver string } -func (p *testDBDriverProvider) DB() *sql.DB { return p.db } +func (p *testDBDriverProvider) DB() *sql.DB { return p.db } func (p *testDBDriverProvider) DriverName() string { return p.driver } // mockAppWithDBDriver creates a MockApplication with a named database that reports its driver diff --git a/module/pipeline_step_db_tenant_test.go b/module/pipeline_step_db_tenant_test.go new file mode 100644 index 00000000..d1d0ae17 --- /dev/null +++ b/module/pipeline_step_db_tenant_test.go @@ -0,0 +1,185 @@ +package module + +import ( + "context" + "database/sql" + "testing" +) + +// testPartitionKeyProvider wraps a *sql.DB to satisfy PartitionKeyProvider. +type testPartitionKeyProvider struct { + db *sql.DB + partitionKey string +} + +func (p *testPartitionKeyProvider) DB() *sql.DB { return p.db } +func (p *testPartitionKeyProvider) PartitionKey() string { return p.partitionKey } + +// mockAppWithPartitionDB creates a MockApplication with a PartitionKeyProvider service. +func mockAppWithPartitionDB(name string, db *sql.DB, partitionKey string) *MockApplication { + app := NewMockApplication() + app.Services[name] = &testPartitionKeyProvider{db: db, partitionKey: partitionKey} + return app +} + +// setupTenantTestDB creates an in-memory SQLite database with tenant-scoped test data. +func setupTenantTestDB(t *testing.T) *sql.DB { + t.Helper() + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open db: %v", err) + } + t.Cleanup(func() { db.Close() }) + + _, err = db.Exec(` + CREATE TABLE forms ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL, + slug TEXT NOT NULL + ); + INSERT INTO forms (id, tenant_id, slug) VALUES ('f1', 'org-alpha', 'contact'); + INSERT INTO forms (id, tenant_id, slug) VALUES ('f2', 'org-alpha', 'feedback'); + INSERT INTO forms (id, tenant_id, slug) VALUES ('f3', 'org-beta', 'signup'); + `) + if err != nil { + t.Fatalf("setup tenant db: %v", err) + } + return db +} + +func TestDBQueryStep_TenantKey_AutoFilter(t *testing.T) { + db := setupTenantTestDB(t) + app := mockAppWithPartitionDB("part-db", db, "tenant_id") + + factory := NewDBQueryStepFactory() + step, err := factory("list-forms", map[string]any{ + "database": "part-db", + "query": "SELECT id, slug FROM forms", + "tenantKey": "steps.auth.tenant_id", + "mode": "list", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("auth", map[string]any{"tenant_id": "org-alpha"}) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + rows, ok := result.Output["rows"].([]map[string]any) + if !ok { + t.Fatal("expected rows in output") + } + if len(rows) != 2 { + t.Errorf("expected 2 rows for org-alpha, got %d", len(rows)) + } +} + +func TestDBQueryStep_TenantKey_NoPartitionKeyProvider(t *testing.T) { + db := setupTenantTestDB(t) + // Use a plain DBProvider (no PartitionKeyProvider) + app := mockAppWithDB("plain-db", db) + + factory := NewDBQueryStepFactory() + step, err := factory("list-forms", map[string]any{ + "database": "plain-db", + "query": "SELECT id FROM forms", + "tenantKey": "steps.auth.tenant_id", + "mode": "list", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("auth", map[string]any{"tenant_id": "org-alpha"}) + + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error when database does not implement PartitionKeyProvider") + } +} + +func TestDBQueryStep_TenantKey_NilTenantValue(t *testing.T) { + db := setupTenantTestDB(t) + app := mockAppWithPartitionDB("part-db", db, "tenant_id") + + factory := NewDBQueryStepFactory() + step, err := factory("list-forms", map[string]any{ + "database": "part-db", + "query": "SELECT id FROM forms", + "tenantKey": "steps.auth.tenant_id", + "mode": "list", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + // Pipeline context without auth.tenant_id set + pc := NewPipelineContext(nil, nil) + + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error when tenant value is nil") + } +} + +func TestDBExecStep_TenantKey_AutoFilter(t *testing.T) { + db := setupTenantTestDB(t) + app := mockAppWithPartitionDB("part-db", db, "tenant_id") + + factory := NewDBExecStepFactory() + step, err := factory("update-form", map[string]any{ + "database": "part-db", + "query": "UPDATE forms SET slug = $1 WHERE id = $2", + "params": []any{"new-slug", "f1"}, + "tenantKey": "steps.auth.tenant_id", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("auth", map[string]any{"tenant_id": "org-alpha"}) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + affected, _ := result.Output["affected_rows"].(int64) + if affected != 1 { + t.Errorf("expected 1 affected row, got %d", affected) + } +} + +func TestAppendTenantFilter_WithWhereClause(t *testing.T) { + query := "SELECT * FROM forms WHERE active = true" + result := appendTenantFilter(query, "tenant_id", 1) + expected := "SELECT * FROM forms WHERE active = true AND tenant_id = $1" + if result != expected { + t.Errorf("expected %q, got %q", expected, result) + } +} + +func TestAppendTenantFilter_WithoutWhereClause(t *testing.T) { + query := "SELECT * FROM forms" + result := appendTenantFilter(query, "tenant_id", 2) + expected := "SELECT * FROM forms WHERE tenant_id = $2" + if result != expected { + t.Errorf("expected %q, got %q", expected, result) + } +} + +func TestAppendTenantFilter_TrailingWhitespace(t *testing.T) { + query := "SELECT * FROM forms ORDER BY created_at " + result := appendTenantFilter(query, "tenant_id", 1) + expected := "SELECT * FROM forms ORDER BY created_at WHERE tenant_id = $1" + if result != expected { + t.Errorf("expected %q, got %q", expected, result) + } +} diff --git a/module/pipeline_step_sandbox_exec.go b/module/pipeline_step_sandbox_exec.go index d63a4a88..5011e01b 100644 --- a/module/pipeline_step_sandbox_exec.go +++ b/module/pipeline_step_sandbox_exec.go @@ -163,15 +163,15 @@ func (s *SandboxExecStep) buildSandboxConfig() sandbox.SandboxConfig { } case "standard": cfg = sandbox.SandboxConfig{ - Image: s.image, - MemoryLimit: 256 * 1024 * 1024, - CPULimit: 0.5, - NetworkMode: "bridge", - CapDrop: []string{"NET_ADMIN", "SYS_ADMIN", "SYS_PTRACE", "SETUID", "SETGID"}, - CapAdd: []string{"NET_BIND_SERVICE"}, + Image: s.image, + MemoryLimit: 256 * 1024 * 1024, + CPULimit: 0.5, + NetworkMode: "bridge", + CapDrop: []string{"NET_ADMIN", "SYS_ADMIN", "SYS_PTRACE", "SETUID", "SETGID"}, + CapAdd: []string{"NET_BIND_SERVICE"}, NoNewPrivileges: true, - PidsLimit: 64, - Timeout: 5 * time.Minute, + PidsLimit: 64, + Timeout: 5 * time.Minute, } default: // "strict" cfg = sandbox.DefaultSecureSandboxConfig(s.image) diff --git a/module/pipeline_step_token_revoke_test.go b/module/pipeline_step_token_revoke_test.go index b3b79a84..d83a8ea3 100644 --- a/module/pipeline_step_token_revoke_test.go +++ b/module/pipeline_step_token_revoke_test.go @@ -296,4 +296,3 @@ var _ TokenBlacklist = (*TokenBlacklistModule)(nil) // Compile-time check: mockBlacklist satisfies TokenBlacklist. var _ TokenBlacklist = (*mockBlacklist)(nil) - diff --git a/module/sql_placeholders.go b/module/sql_placeholders.go index 7dea9667..0f20a926 100644 --- a/module/sql_placeholders.go +++ b/module/sql_placeholders.go @@ -82,7 +82,18 @@ func normalizePlaceholders(query, driver string) string { return result } -// validatePlaceholderCount checks that the number of params matches the +// appendTenantFilter appends "AND = $N" to a SQL query, where N is +// the next positional parameter index. The function handles queries with and +// without an existing WHERE clause. +func appendTenantFilter(query, column string, paramIndex int) string { + query = strings.TrimRight(query, " \t\n\r;") + upper := strings.ToUpper(query) + if strings.Contains(upper, " WHERE ") { + return fmt.Sprintf("%s AND %s = $%d", query, column, paramIndex) + } + return fmt.Sprintf("%s WHERE %s = $%d", query, column, paramIndex) +} + // placeholder count in the query. Returns an error if there's a mismatch. func validatePlaceholderCount(query, driver string, paramCount int) error { if paramCount == 0 { diff --git a/plugins/pipelinesteps/plugin.go b/plugins/pipelinesteps/plugin.go index 03dbce15..31b45842 100644 --- a/plugins/pipelinesteps/plugin.go +++ b/plugins/pipelinesteps/plugin.go @@ -66,6 +66,7 @@ func New() *Plugin { "step.db_query", "step.db_exec", "step.db_query_cached", + "step.db_create_partition", "step.json_response", "step.raw_response", "step.workflow_call", @@ -127,6 +128,7 @@ func (p *Plugin) StepFactories() map[string]plugin.StepFactory { "step.db_query": wrapStepFactory(module.NewDBQueryStepFactory()), "step.db_exec": wrapStepFactory(module.NewDBExecStepFactory()), "step.db_query_cached": wrapStepFactory(module.NewDBQueryCachedStepFactory()), + "step.db_create_partition": wrapStepFactory(module.NewDBCreatePartitionStepFactory()), "step.json_response": wrapStepFactory(module.NewJSONResponseStepFactory()), "step.raw_response": wrapStepFactory(module.NewRawResponseStepFactory()), "step.validate_path_param": wrapStepFactory(module.NewValidatePathParamStepFactory()), @@ -137,8 +139,8 @@ func (p *Plugin) StepFactories() map[string]plugin.StepFactory { "step.foreach": wrapStepFactory(module.NewForEachStepFactory(func() *module.StepRegistry { return p.concreteStepRegistry })), - "step.webhook_verify": wrapStepFactory(module.NewWebhookVerifyStepFactory()), - "step.base64_decode": wrapStepFactory(module.NewBase64DecodeStepFactory()), + "step.webhook_verify": wrapStepFactory(module.NewWebhookVerifyStepFactory()), + "step.base64_decode": wrapStepFactory(module.NewBase64DecodeStepFactory()), "step.cache_get": wrapStepFactory(module.NewCacheGetStepFactory()), "step.cache_set": wrapStepFactory(module.NewCacheSetStepFactory()), "step.cache_delete": wrapStepFactory(module.NewCacheDeleteStepFactory()), @@ -153,12 +155,12 @@ func (p *Plugin) StepFactories() map[string]plugin.StepFactory { "step.resilient_circuit_breaker": wrapStepFactory(module.NewResilienceCircuitBreakerStepFactory(func() *module.StepRegistry { return p.concreteStepRegistry })), - "step.s3_upload": wrapStepFactory(module.NewS3UploadStepFactory()), - "step.auth_validate": wrapStepFactory(module.NewAuthValidateStepFactory()), - "step.token_revoke": wrapStepFactory(module.NewTokenRevokeStepFactory()), - "step.field_reencrypt": wrapStepFactory(module.NewFieldReencryptStepFactory()), - "step.sandbox_exec": wrapStepFactory(module.NewSandboxExecStepFactory()), - "step.http_proxy": wrapStepFactory(module.NewHTTPProxyStepFactory()), + "step.s3_upload": wrapStepFactory(module.NewS3UploadStepFactory()), + "step.auth_validate": wrapStepFactory(module.NewAuthValidateStepFactory()), + "step.token_revoke": wrapStepFactory(module.NewTokenRevokeStepFactory()), + "step.field_reencrypt": wrapStepFactory(module.NewFieldReencryptStepFactory()), + "step.sandbox_exec": wrapStepFactory(module.NewSandboxExecStepFactory()), + "step.http_proxy": wrapStepFactory(module.NewHTTPProxyStepFactory()), } } diff --git a/plugins/pipelinesteps/plugin_test.go b/plugins/pipelinesteps/plugin_test.go index e06c3819..69ac2b76 100644 --- a/plugins/pipelinesteps/plugin_test.go +++ b/plugins/pipelinesteps/plugin_test.go @@ -45,6 +45,7 @@ func TestStepFactories(t *testing.T) { "step.db_query", "step.db_exec", "step.db_query_cached", + "step.db_create_partition", "step.json_response", "step.raw_response", "step.validate_path_param", diff --git a/plugins/storage/plugin.go b/plugins/storage/plugin.go index 37afa2cf..33b9a40b 100644 --- a/plugins/storage/plugin.go +++ b/plugins/storage/plugin.go @@ -40,6 +40,7 @@ func New() *Plugin { "storage.sqlite", "storage.artifact", "database.workflow", + "database.partitioned", "persistence.store", "cache.redis", }, @@ -149,6 +150,32 @@ func (p *Plugin) ModuleFactories() map[string]plugin.ModuleFactory { } return module.NewWorkflowDatabase(name, dbConfig) }, + "database.partitioned": func(name string, cfg map[string]any) modular.Module { + partCfg := module.PartitionedDatabaseConfig{} + if driver, ok := cfg["driver"].(string); ok { + partCfg.Driver = driver + } + if dsn, ok := cfg["dsn"].(string); ok { + partCfg.DSN = dsn + } + if maxOpen, ok := cfg["maxOpenConns"].(float64); ok { + partCfg.MaxOpenConns = int(maxOpen) + } + if maxIdle, ok := cfg["maxIdleConns"].(float64); ok { + partCfg.MaxIdleConns = int(maxIdle) + } + if pk, ok := cfg["partitionKey"].(string); ok { + partCfg.PartitionKey = pk + } + if tables, ok := cfg["tables"].([]any); ok { + for _, t := range tables { + if s, ok := t.(string); ok { + partCfg.Tables = append(partCfg.Tables, s) + } + } + } + return module.NewPartitionedDatabase(name, partCfg) + }, "persistence.store": func(name string, cfg map[string]any) modular.Module { dbServiceName := "database" if n, ok := cfg["database"].(string); ok && n != "" { @@ -312,6 +339,23 @@ func (p *Plugin) ModuleSchemas() []*schema.ModuleSchema { }, DefaultConfig: map[string]any{"maxOpenConns": 25, "maxIdleConns": 5}, }, + { + Type: "database.partitioned", + Label: "Partitioned Database", + Category: "database", + Description: "PostgreSQL LIST-partitioned database for multi-tenant data isolation. Automatically manages per-tenant partitions and enables automatic tenant scoping in step.db_query and step.db_exec.", + Inputs: []schema.ServiceIODef{{Name: "query", Type: "SQL", Description: "SQL query to execute"}}, + Outputs: []schema.ServiceIODef{{Name: "database", Type: "sql.DB", Description: "SQL database connection pool"}}, + ConfigFields: []schema.ConfigFieldDef{ + {Key: "driver", Label: "Driver", Type: schema.FieldTypeSelect, Options: []string{"pgx", "pgx/v5", "postgres"}, Required: true, Description: "PostgreSQL database driver"}, + {Key: "dsn", Label: "DSN", Type: schema.FieldTypeString, Required: true, Description: "Data source name / connection string", Placeholder: "postgres://user:pass@localhost/db?sslmode=disable", Sensitive: true}, //nolint:gosec // G101: placeholder DSN example in schema documentation + {Key: "partitionKey", Label: "Partition Key", Type: schema.FieldTypeString, Required: true, Description: "Column name used for LIST partitioning (e.g. tenant_id)", Placeholder: "tenant_id"}, + {Key: "tables", Label: "Tables", Type: schema.FieldTypeArray, ArrayItemType: "string", Required: true, Description: "Tables to manage LIST partitions for", Placeholder: "forms"}, + {Key: "maxOpenConns", Label: "Max Open Connections", Type: schema.FieldTypeNumber, DefaultValue: 25, Description: "Maximum number of open database connections"}, + {Key: "maxIdleConns", Label: "Max Idle Connections", Type: schema.FieldTypeNumber, DefaultValue: 5, Description: "Maximum number of idle connections in the pool"}, + }, + DefaultConfig: map[string]any{"maxOpenConns": 25, "maxIdleConns": 5}, + }, { Type: "persistence.store", Label: "Persistence Store", diff --git a/plugins/storage/plugin_test.go b/plugins/storage/plugin_test.go index a6c6a92d..83080a23 100644 --- a/plugins/storage/plugin_test.go +++ b/plugins/storage/plugin_test.go @@ -21,8 +21,8 @@ func TestPluginManifest(t *testing.T) { if m.Name != "storage" { t.Errorf("expected name %q, got %q", "storage", m.Name) } - if len(m.ModuleTypes) != 8 { - t.Errorf("expected 8 module types, got %d", len(m.ModuleTypes)) + if len(m.ModuleTypes) != 9 { + t.Errorf("expected 9 module types, got %d", len(m.ModuleTypes)) } if len(m.StepTypes) != 4 { t.Errorf("expected 4 step types, got %d", len(m.StepTypes)) @@ -134,8 +134,8 @@ func TestStepFactories(t *testing.T) { func TestModuleSchemas(t *testing.T) { p := New() schemas := p.ModuleSchemas() - if len(schemas) != 8 { - t.Fatalf("expected 8 module schemas, got %d", len(schemas)) + if len(schemas) != 9 { + t.Fatalf("expected 9 module schemas, got %d", len(schemas)) } types := map[string]bool{} @@ -144,8 +144,8 @@ func TestModuleSchemas(t *testing.T) { } expectedTypes := []string{ "storage.s3", "storage.local", "storage.gcs", - "storage.sqlite", "database.workflow", "persistence.store", - "cache.redis", + "storage.sqlite", "database.workflow", "database.partitioned", + "persistence.store", "cache.redis", } for _, expected := range expectedTypes { if !types[expected] { diff --git a/schema/module_schema.go b/schema/module_schema.go index 83c670e3..53c72738 100644 --- a/schema/module_schema.go +++ b/schema/module_schema.go @@ -512,6 +512,24 @@ func (r *ModuleSchemaRegistry) registerBuiltins() { DefaultConfig: map[string]any{"maxOpenConns": 25, "maxIdleConns": 5}, }) + r.Register(&ModuleSchema{ + Type: "database.partitioned", + Label: "Partitioned Database", + Category: "database", + Description: "PostgreSQL LIST-partitioned database for multi-tenant data isolation. Automatically manages per-tenant partitions and enables automatic tenant scoping in step.db_query and step.db_exec.", + Inputs: []ServiceIODef{{Name: "query", Type: "SQL", Description: "SQL query to execute"}}, + Outputs: []ServiceIODef{{Name: "database", Type: "sql.DB", Description: "SQL database connection pool"}}, + ConfigFields: []ConfigFieldDef{ + {Key: "driver", Label: "Driver", Type: FieldTypeSelect, Options: []string{"pgx", "pgx/v5", "postgres"}, Required: true, Description: "PostgreSQL database driver"}, + {Key: "dsn", Label: "DSN", Type: FieldTypeString, Required: true, Description: "Data source name / connection string", Placeholder: "postgres://user:pass@localhost/db?sslmode=disable", Sensitive: true}, //nolint:gosec // G101: placeholder DSN example in schema documentation + {Key: "partitionKey", Label: "Partition Key", Type: FieldTypeString, Required: true, Description: "Column name used for LIST partitioning (e.g. tenant_id)", Placeholder: "tenant_id"}, + {Key: "tables", Label: "Tables", Type: FieldTypeArray, ArrayItemType: "string", Required: true, Description: "Tables to manage LIST partitions for", Placeholder: "forms"}, + {Key: "maxOpenConns", Label: "Max Open Connections", Type: FieldTypeNumber, DefaultValue: 25, Description: "Maximum number of open database connections"}, + {Key: "maxIdleConns", Label: "Max Idle Connections", Type: FieldTypeNumber, DefaultValue: 5, Description: "Maximum number of idle connections in the pool"}, + }, + DefaultConfig: map[string]any{"maxOpenConns": 25, "maxIdleConns": 5}, + }) + r.Register(&ModuleSchema{ Type: "persistence.store", Label: "Persistence Store", @@ -1002,6 +1020,7 @@ func (r *ModuleSchemaRegistry) registerBuiltins() { {Key: "query", Label: "SQL Query", Type: FieldTypeSQL, Required: true, Description: "Parameterized SQL SELECT query (use ? for placeholders, no template expressions allowed)", Placeholder: "SELECT id, name FROM companies WHERE id = ?"}, {Key: "params", Label: "Parameters", Type: FieldTypeArray, ArrayItemType: "string", Description: "Template-resolved parameter values for ? placeholders in query"}, {Key: "mode", Label: "Mode", Type: FieldTypeSelect, Options: []string{"list", "single"}, DefaultValue: "list", Description: "Result mode: 'list' returns rows/count, 'single' returns row/found"}, + {Key: "tenantKey", Label: "Tenant Key", Type: FieldTypeString, Description: "Dot-path in pipeline context to resolve the tenant value for automatic scoping (requires database.partitioned)", Placeholder: "auth.tenant_id"}, }, }) @@ -1033,6 +1052,20 @@ func (r *ModuleSchemaRegistry) registerBuiltins() { {Key: "database", Label: "Database", Type: FieldTypeString, Required: true, Description: "Name of the database service (must implement DBProvider)", Placeholder: "admin-db", InheritFrom: "dependency.name"}, {Key: "query", Label: "SQL Statement", Type: FieldTypeSQL, Required: true, Description: "Parameterized SQL INSERT/UPDATE/DELETE statement (use ? for placeholders)", Placeholder: "INSERT INTO companies (id, name) VALUES (?, ?)"}, {Key: "params", Label: "Parameters", Type: FieldTypeArray, ArrayItemType: "string", Description: "Template-resolved parameter values for ? placeholders"}, + {Key: "tenantKey", Label: "Tenant Key", Type: FieldTypeString, Description: "Dot-path in pipeline context to resolve the tenant value for automatic scoping (requires database.partitioned)", Placeholder: "auth.tenant_id"}, + }, + }) + + r.Register(&ModuleSchema{ + Type: "step.db_create_partition", + Label: "Create Database Partition", + Category: "pipeline", + Description: "Creates a PostgreSQL LIST partition for a tenant on all tables managed by a database.partitioned module. Idempotent — safe to call when a partition may already exist.", + Inputs: []ServiceIODef{{Name: "context", Type: "PipelineContext", Description: "Pipeline context for tenant key resolution"}}, + Outputs: []ServiceIODef{{Name: "result", Type: "StepResult", Description: "Partition creation result with tenant and partition fields"}}, + ConfigFields: []ConfigFieldDef{ + {Key: "database", Label: "Database", Type: FieldTypeString, Required: true, Description: "Name of a database.partitioned service", Placeholder: "db", InheritFrom: "dependency.name"}, + {Key: "tenantKey", Label: "Tenant Key", Type: FieldTypeString, Required: true, Description: "Dot-path in pipeline context to resolve the tenant value (e.g. the new tenant's ID)", Placeholder: "body.tenant_id"}, }, }) diff --git a/schema/schema.go b/schema/schema.go index 81d70764..38bb3e46 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -152,6 +152,7 @@ var coreModuleTypes = []string{ "cache.modular", "config.provider", "data.transformer", + "database.partitioned", "database.workflow", "dlq.service", "dynamic.component", @@ -209,6 +210,7 @@ var coreModuleTypes = []string{ "step.circuit_breaker", "step.conditional", "step.constraint_check", + "step.db_create_partition", "step.db_exec", "step.db_query", "step.db_query_cached", @@ -612,7 +614,7 @@ func GenerateWorkflowSchema() *Schema { Type: "object", Description: "Workflow handler configurations keyed by workflow type (e.g. http, messaging, statemachine, scheduler, integration)", }, - "triggers": triggerSchema, + "triggers": triggerSchema, "pipelines": buildPipelinesSchema(pipelineSchema), "imports": { Type: "array", diff --git a/schema/snippets_export.go b/schema/snippets_export.go index bfd499ec..99ac0b9d 100644 --- a/schema/snippets_export.go +++ b/schema/snippets_export.go @@ -31,19 +31,19 @@ type vscodeSnippet struct { // jetbrainsTemplateSet is the root XML element for JetBrains live templates. type jetbrainsTemplateSet struct { - XMLName xml.Name `xml:"templateSet"` - Group string `xml:"group,attr"` + XMLName xml.Name `xml:"templateSet"` + Group string `xml:"group,attr"` Templates []jetbrainsTemplate `xml:"template"` } type jetbrainsTemplate struct { - Name string `xml:"name,attr"` - Value string `xml:"value,attr"` - Description string `xml:"description,attr"` - ToReformat bool `xml:"toReformat,attr"` - ToShortenFQ bool `xml:"toShortenFQNames,attr"` - Variables []jetbrainsVariable `xml:"variable,omitempty"` - Contexts []jetbrainsContext `xml:"context"` + Name string `xml:"name,attr"` + Value string `xml:"value,attr"` + Description string `xml:"description,attr"` + ToReformat bool `xml:"toReformat,attr"` + ToShortenFQ bool `xml:"toShortenFQNames,attr"` + Variables []jetbrainsVariable `xml:"variable,omitempty"` + Contexts []jetbrainsContext `xml:"context"` } type jetbrainsVariable struct { From 09f44e3bcad0e06686958ae0244c079ca9ec0be9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 04:10:48 +0000 Subject: [PATCH 3/5] security: validate partition key identifier in db_query and db_exec steps Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --- module/pipeline_step_db_exec.go | 6 +++++- module/pipeline_step_db_query.go | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/module/pipeline_step_db_exec.go b/module/pipeline_step_db_exec.go index b9a2d85d..6ee6ff17 100644 --- a/module/pipeline_step_db_exec.go +++ b/module/pipeline_step_db_exec.go @@ -110,13 +110,17 @@ func (s *DBExecStep) Execute(_ context.Context, pc *PipelineContext) (*StepResul if !ok { return nil, fmt.Errorf("db_exec step %q: tenantKey requires database %q to implement PartitionKeyProvider (use database.partitioned)", s.name, s.database) } + partKey := pkp.PartitionKey() + if err := validateIdentifier(partKey); err != nil { + return nil, fmt.Errorf("db_exec step %q: invalid partition key %q: %w", s.name, partKey, err) + } tenantVal := resolveBodyFrom(s.tenantKey, pc) if tenantVal == nil { return nil, fmt.Errorf("db_exec step %q: tenantKey %q resolved to nil in pipeline context", s.name, s.tenantKey) } tenantStr := fmt.Sprintf("%v", tenantVal) nextParam := len(resolvedParams) + 1 - query = appendTenantFilter(query, pkp.PartitionKey(), nextParam) + query = appendTenantFilter(query, partKey, nextParam) resolvedParams = append(resolvedParams, tenantStr) } diff --git a/module/pipeline_step_db_query.go b/module/pipeline_step_db_query.go index d54569a8..5fe9b2b4 100644 --- a/module/pipeline_step_db_query.go +++ b/module/pipeline_step_db_query.go @@ -133,13 +133,17 @@ func (s *DBQueryStep) Execute(_ context.Context, pc *PipelineContext) (*StepResu if !ok { return nil, fmt.Errorf("db_query step %q: tenantKey requires database %q to implement PartitionKeyProvider (use database.partitioned)", s.name, s.database) } + partKey := pkp.PartitionKey() + if err := validateIdentifier(partKey); err != nil { + return nil, fmt.Errorf("db_query step %q: invalid partition key %q: %w", s.name, partKey, err) + } tenantVal := resolveBodyFrom(s.tenantKey, pc) if tenantVal == nil { return nil, fmt.Errorf("db_query step %q: tenantKey %q resolved to nil in pipeline context", s.name, s.tenantKey) } tenantStr := fmt.Sprintf("%v", tenantVal) nextParam := len(resolvedParams) + 1 - query = appendTenantFilter(query, pkp.PartitionKey(), nextParam) + query = appendTenantFilter(query, partKey, nextParam) resolvedParams = append(resolvedParams, tenantStr) } From d4402788da6f011b01544961083bff30bd11ba85 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 05:18:09 +0000 Subject: [PATCH 4/5] feat: address PR review feedback and add partition features - Fix appendTenantFilter to insert tenant predicate before ORDER BY/ LIMIT/GROUP BY/HAVING/OFFSET clauses instead of blindly appending - Reject tenantKey for INSERT statements in step.db_exec with clear error - Remove "postgresql" from isSupportedPartitionDriver for consistency - Fix schema tenantKey placeholders to use "steps." prefix consistently - Add partitionType config (list/range) with RANGE partition DDL support - Add partitionNameFormat config ({table}_{tenant}, {tenant}_{table}, etc.) - Add PartitionTableName method to PartitionKeyProvider interface - Add sourceTable/sourceColumn config for auto-partition sync - Add SyncPartitionsFromSource method to PartitionManager interface - Add step.db_sync_partitions for triggering partition sync from source - Add comprehensive tests for all new functionality Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --- cmd/wfctl/type_registry.go | 7 +- module/database_partitioned.go | 179 ++++++++++++++++++--- module/database_partitioned_test.go | 171 +++++++++++++++++++- module/pipeline_step_db_exec.go | 5 + module/pipeline_step_db_sync_partitions.go | 59 +++++++ module/pipeline_step_db_tenant_test.go | 72 ++++++++- module/sql_placeholders.go | 68 +++++++- plugins/pipelinesteps/plugin.go | 2 + plugins/pipelinesteps/plugin_test.go | 1 + plugins/storage/plugin.go | 24 ++- schema/module_schema.go | 34 ++-- schema/schema.go | 1 + 12 files changed, 571 insertions(+), 52 deletions(-) create mode 100644 module/pipeline_step_db_sync_partitions.go diff --git a/cmd/wfctl/type_registry.go b/cmd/wfctl/type_registry.go index 43f7930b..ec59006d 100644 --- a/cmd/wfctl/type_registry.go +++ b/cmd/wfctl/type_registry.go @@ -59,7 +59,7 @@ func KnownModuleTypes() map[string]ModuleTypeInfo { Type: "database.partitioned", Plugin: "storage", Stateful: true, - ConfigKeys: []string{"driver", "dsn", "partitionKey", "tables", "maxOpenConns", "maxIdleConns"}, + ConfigKeys: []string{"driver", "dsn", "partitionKey", "tables", "partitionType", "partitionNameFormat", "sourceTable", "sourceColumn", "maxOpenConns", "maxIdleConns"}, }, "persistence.store": { Type: "persistence.store", @@ -607,6 +607,11 @@ func KnownStepTypes() map[string]StepTypeInfo { Plugin: "pipelinesteps", ConfigKeys: []string{"database", "tenantKey"}, }, + "step.db_sync_partitions": { + Type: "step.db_sync_partitions", + Plugin: "pipelinesteps", + ConfigKeys: []string{"database"}, + }, "step.json_response": { Type: "step.json_response", Plugin: "pipelinesteps", diff --git a/module/database_partitioned.go b/module/database_partitioned.go index 1e0b1dc6..fe02eec1 100644 --- a/module/database_partitioned.go +++ b/module/database_partitioned.go @@ -14,20 +14,35 @@ import ( // validPartitionValue matches safe LIST partition values (alphanumeric, hyphens, underscores, dots). var validPartitionValue = regexp.MustCompile(`^[a-zA-Z0-9_.\-]+$`) +// Partition types supported by PostgreSQL. +const ( + PartitionTypeList = "list" + PartitionTypeRange = "range" +) + // PartitionKeyProvider is optionally implemented by database modules that support -// LIST partitioning. Steps can use PartitionKey() to determine the column name -// for automatic tenant scoping. +// partitioning. Steps can use PartitionKey() to determine the column name +// for automatic tenant scoping, and PartitionTableName() to resolve +// tenant-specific partition table names at query time. type PartitionKeyProvider interface { DBProvider PartitionKey() string + // PartitionTableName resolves the partition table name for a given parent + // table and tenant value, using the configured partitionNameFormat. + // Returns the parent table name unchanged when no format is configured. + PartitionTableName(parentTable, tenantValue string) string } // PartitionManager is optionally implemented by database modules that support -// runtime creation of LIST partitions. The EnsurePartition method is idempotent — +// runtime creation of partitions. The EnsurePartition method is idempotent — // if the partition already exists the call succeeds without error. type PartitionManager interface { PartitionKeyProvider EnsurePartition(ctx context.Context, tenantValue string) error + // SyncPartitionsFromSource queries the configured sourceTable for all + // distinct tenant values and ensures that partitions exist for each one. + // No-ops if sourceTable is not configured. + SyncPartitionsFromSource(ctx context.Context) error } // PartitionedDatabaseConfig holds configuration for the database.partitioned module. @@ -38,9 +53,25 @@ type PartitionedDatabaseConfig struct { MaxIdleConns int `json:"maxIdleConns" yaml:"maxIdleConns"` PartitionKey string `json:"partitionKey" yaml:"partitionKey"` Tables []string `json:"tables" yaml:"tables"` + // PartitionType is "list" (default) or "range". + // LIST partitions are created with FOR VALUES IN ('value'). + // RANGE partitions are created with FOR VALUES FROM ('value') TO ('value_next'). + PartitionType string `json:"partitionType" yaml:"partitionType"` + // PartitionNameFormat is a template for generating partition table names. + // Supports {table} and {tenant} placeholders. + // Default: "{table}_{tenant}" (e.g. forms_org_alpha). + PartitionNameFormat string `json:"partitionNameFormat" yaml:"partitionNameFormat"` + // SourceTable is the table that contains all tenant IDs. + // When set, SyncPartitionsFromSource queries this table for all distinct + // values in the partition key column and ensures partitions exist. + // Example: "tenants" — will query "SELECT DISTINCT tenant_id FROM tenants". + SourceTable string `json:"sourceTable" yaml:"sourceTable"` + // SourceColumn overrides the column queried in sourceTable. + // Defaults to PartitionKey if empty. + SourceColumn string `json:"sourceColumn" yaml:"sourceColumn"` } -// PartitionedDatabase wraps WorkflowDatabase and adds PostgreSQL LIST partition +// PartitionedDatabase wraps WorkflowDatabase and adds PostgreSQL partition // management. It satisfies DBProvider, DBDriverProvider, PartitionKeyProvider, // and PartitionManager. type PartitionedDatabase struct { @@ -58,6 +89,12 @@ func NewPartitionedDatabase(name string, cfg PartitionedDatabaseConfig) *Partiti MaxOpenConns: cfg.MaxOpenConns, MaxIdleConns: cfg.MaxIdleConns, } + if cfg.PartitionType == "" { + cfg.PartitionType = PartitionTypeList + } + if cfg.PartitionNameFormat == "" { + cfg.PartitionNameFormat = "{table}_{tenant}" + } return &PartitionedDatabase{ name: name, config: cfg, @@ -109,11 +146,31 @@ func (p *PartitionedDatabase) DriverName() string { return p.config.Driver } -// PartitionKey returns the column name used for LIST partitioning (satisfies PartitionKeyProvider). +// PartitionKey returns the column name used for partitioning (satisfies PartitionKeyProvider). func (p *PartitionedDatabase) PartitionKey() string { return p.config.PartitionKey } +// PartitionType returns the partition type ("list" or "range"). +func (p *PartitionedDatabase) PartitionType() string { + return p.config.PartitionType +} + +// PartitionNameFormat returns the configured partition name format template. +func (p *PartitionedDatabase) PartitionNameFormat() string { + return p.config.PartitionNameFormat +} + +// PartitionTableName resolves the partition table name for a given parent +// table and tenant value using the configured partitionNameFormat. +func (p *PartitionedDatabase) PartitionTableName(parentTable, tenantValue string) string { + suffix := sanitizePartitionSuffix(tenantValue) + name := p.config.PartitionNameFormat + name = strings.ReplaceAll(name, "{table}", parentTable) + name = strings.ReplaceAll(name, "{tenant}", suffix) + return name +} + // Tables returns the list of tables managed by this partitioned database. func (p *PartitionedDatabase) Tables() []string { result := make([]string, len(p.config.Tables)) @@ -121,10 +178,13 @@ func (p *PartitionedDatabase) Tables() []string { return result } -// EnsurePartition creates a LIST partition for the given tenant value on all +// EnsurePartition creates a partition for the given tenant value on all // configured tables. The operation is idempotent — IF NOT EXISTS prevents errors // when the partition already exists. // +// For LIST partitions: CREATE TABLE IF NOT EXISTS PARTITION OF FOR VALUES IN ('') +// For RANGE partitions: CREATE TABLE IF NOT EXISTS PARTITION OF
FOR VALUES FROM ('') TO ('\x00') +// // Only PostgreSQL (pgx, pgx/v5, postgres) is supported. The method validates // the tenant value and table/column names to prevent SQL injection. func (p *PartitionedDatabase) EnsurePartition(ctx context.Context, tenantValue string) error { @@ -133,7 +193,7 @@ func (p *PartitionedDatabase) EnsurePartition(ctx context.Context, tenantValue s } if !isSupportedPartitionDriver(p.config.Driver) { - return fmt.Errorf("partitioned database %q: driver %q does not support LIST partitioning (use pgx, pgx/v5, or postgres)", p.name, p.config.Driver) + return fmt.Errorf("partitioned database %q: driver %q does not support partitioning (use pgx, pgx/v5, or postgres)", p.name, p.config.Driver) } if err := validateIdentifier(p.config.PartitionKey); err != nil { @@ -153,22 +213,39 @@ func (p *PartitionedDatabase) EnsurePartition(ctx context.Context, tenantValue s return fmt.Errorf("partitioned database %q: invalid table name: %w", p.name, err) } - // Sanitize the partition suffix: replace hyphens and dots with underscores. - partitionSuffix := sanitizePartitionSuffix(tenantValue) - partitionName := table + "_" + partitionSuffix + partitionName := p.PartitionTableName(table, tenantValue) - // Use IF NOT EXISTS to make this idempotent. - // The tenant value is embedded as a quoted literal (single-quoted). + // Validate the computed partition name is a safe identifier. + if err := validateIdentifier(partitionName); err != nil { + return fmt.Errorf("partitioned database %q: invalid partition name %q: %w", p.name, partitionName, err) + } + + var ddl string // We have already validated tenantValue against validPartitionValue so // it cannot contain single-quote characters. - sql := fmt.Sprintf( - "CREATE TABLE IF NOT EXISTS %s PARTITION OF %s FOR VALUES IN ('%s')", - partitionName, - table, - strings.ReplaceAll(tenantValue, "'", ""), - ) - - if _, err := db.ExecContext(ctx, sql); err != nil { + safeValue := strings.ReplaceAll(tenantValue, "'", "") + + switch p.config.PartitionType { + case PartitionTypeList: + ddl = fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS %s PARTITION OF %s FOR VALUES IN ('%s')", + partitionName, table, safeValue, + ) + case PartitionTypeRange: + // RANGE partition: from the tenant value (inclusive) to the same + // value followed by a null byte (exclusive). This creates a + // single-value range partition, which is the closest equivalent + // to LIST semantics for RANGE-partitioned tables. + ddl = fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS %s PARTITION OF %s FOR VALUES FROM ('%s') TO ('%s\\x00')", + partitionName, table, safeValue, safeValue, + ) + default: + return fmt.Errorf("partitioned database %q: unsupported partition type %q (use %q or %q)", + p.name, p.config.PartitionType, PartitionTypeList, PartitionTypeRange) + } + + if _, err := db.ExecContext(ctx, ddl); err != nil { return fmt.Errorf("partitioned database %q: failed to create partition %q for table %q: %w", p.name, partitionName, table, err) } @@ -177,10 +254,70 @@ func (p *PartitionedDatabase) EnsurePartition(ctx context.Context, tenantValue s return nil } +// SyncPartitionsFromSource queries the configured sourceTable for all distinct +// tenant values and ensures that partitions exist for each one. +// This enables automatic partition creation when new tenants are added to a +// source table (e.g., a "tenants" table). +// +// No-ops if sourceTable is not configured. +func (p *PartitionedDatabase) SyncPartitionsFromSource(ctx context.Context) error { + if p.config.SourceTable == "" { + return nil + } + + if err := validateIdentifier(p.config.SourceTable); err != nil { + return fmt.Errorf("partitioned database %q: invalid source table: %w", p.name, err) + } + + srcCol := p.config.SourceColumn + if srcCol == "" { + srcCol = p.config.PartitionKey + } + if err := validateIdentifier(srcCol); err != nil { + return fmt.Errorf("partitioned database %q: invalid source column: %w", p.name, err) + } + + db := p.base.DB() + if db == nil { + return fmt.Errorf("partitioned database %q: database connection is nil", p.name) + } + + // All identifiers (srcCol, SourceTable) have been validated by validateIdentifier above. + query := fmt.Sprintf("SELECT DISTINCT %s FROM %s WHERE %s IS NOT NULL", //nolint:gosec // G201: identifiers validated above + srcCol, p.config.SourceTable, srcCol) + + rows, err := db.QueryContext(ctx, query) + if err != nil { + return fmt.Errorf("partitioned database %q: failed to query source table %q: %w", + p.name, p.config.SourceTable, err) + } + defer rows.Close() + + var tenants []string + for rows.Next() { + var val string + if err := rows.Scan(&val); err != nil { + return fmt.Errorf("partitioned database %q: failed to scan tenant value: %w", p.name, err) + } + tenants = append(tenants, val) + } + if err := rows.Err(); err != nil { + return fmt.Errorf("partitioned database %q: row iteration error: %w", p.name, err) + } + + for _, tenant := range tenants { + if err := p.EnsurePartition(ctx, tenant); err != nil { + return err + } + } + + return nil +} + // isSupportedPartitionDriver returns true for PostgreSQL-compatible drivers. func isSupportedPartitionDriver(driver string) bool { switch driver { - case "pgx", "pgx/v5", "postgres", "postgresql": + case "pgx", "pgx/v5", "postgres": return true } return false diff --git a/module/database_partitioned_test.go b/module/database_partitioned_test.go index 252b1904..5298d13d 100644 --- a/module/database_partitioned_test.go +++ b/module/database_partitioned_test.go @@ -3,6 +3,7 @@ package module import ( "context" "database/sql" + "strings" "testing" ) @@ -27,6 +28,49 @@ func TestPartitionedDatabase_PartitionKey(t *testing.T) { } } +func TestPartitionedDatabase_Defaults(t *testing.T) { + cfg := PartitionedDatabaseConfig{ + Driver: "pgx", + PartitionKey: "tenant_id", + } + pd := NewPartitionedDatabase("db", cfg) + + if pd.PartitionType() != PartitionTypeList { + t.Errorf("expected default partition type %q, got %q", PartitionTypeList, pd.PartitionType()) + } + if pd.PartitionNameFormat() != "{table}_{tenant}" { + t.Errorf("expected default format, got %q", pd.PartitionNameFormat()) + } +} + +func TestPartitionedDatabase_PartitionTableName(t *testing.T) { + tests := []struct { + format string + table string + tenant string + expected string + }{ + {"{table}_{tenant}", "forms", "org-alpha", "forms_org_alpha"}, + {"{tenant}_{table}", "forms", "org-alpha", "org_alpha_forms"}, + {"{table}_{tenant}", "submissions", "org.beta", "submissions_org_beta"}, + {"", "forms", "org-alpha", "forms_org_alpha"}, // default format + } + + for _, tc := range tests { + cfg := PartitionedDatabaseConfig{ + Driver: "pgx", + PartitionKey: "tenant_id", + PartitionNameFormat: tc.format, + } + pd := NewPartitionedDatabase("db", cfg) + got := pd.PartitionTableName(tc.table, tc.tenant) + if got != tc.expected { + t.Errorf("PartitionTableName(format=%q, table=%q, tenant=%q) = %q, want %q", + tc.format, tc.table, tc.tenant, got, tc.expected) + } + } +} + func TestPartitionedDatabase_EnsurePartition_InvalidDriver(t *testing.T) { cfg := PartitionedDatabaseConfig{ Driver: "sqlite3", @@ -69,6 +113,51 @@ func TestPartitionedDatabase_EnsurePartition_InvalidPartitionKey(t *testing.T) { } } +func TestPartitionedDatabase_EnsurePartition_UnsupportedType(t *testing.T) { + cfg := PartitionedDatabaseConfig{ + Driver: "pgx", + PartitionKey: "tenant_id", + Tables: []string{"forms"}, + PartitionType: "hash", + } + pd := NewPartitionedDatabase("db", cfg) + // DB is nil — but the partition type check should happen before the nil check + err := pd.EnsurePartition(context.Background(), "org-alpha") + if err == nil { + t.Fatal("expected error for unsupported partition type") + } +} + +func TestPartitionedDatabase_SyncPartitionsFromSource_NoSourceTable(t *testing.T) { + cfg := PartitionedDatabaseConfig{ + Driver: "pgx", + PartitionKey: "tenant_id", + Tables: []string{"forms"}, + } + pd := NewPartitionedDatabase("db", cfg) + + // No source table => no-op + err := pd.SyncPartitionsFromSource(context.Background()) + if err != nil { + t.Fatalf("expected no-op when sourceTable is empty, got: %v", err) + } +} + +func TestPartitionedDatabase_SyncPartitionsFromSource_InvalidSourceTable(t *testing.T) { + cfg := PartitionedDatabaseConfig{ + Driver: "pgx", + PartitionKey: "tenant_id", + Tables: []string{"forms"}, + SourceTable: "invalid table!", + } + pd := NewPartitionedDatabase("db", cfg) + + err := pd.SyncPartitionsFromSource(context.Background()) + if err == nil { + t.Fatal("expected error for invalid source table name") + } +} + func TestSanitizePartitionSuffix(t *testing.T) { tests := []struct { input string @@ -89,14 +178,14 @@ func TestSanitizePartitionSuffix(t *testing.T) { } func TestIsSupportedPartitionDriver(t *testing.T) { - supported := []string{"pgx", "pgx/v5", "postgres", "postgresql"} + supported := []string{"pgx", "pgx/v5", "postgres"} for _, d := range supported { if !isSupportedPartitionDriver(d) { t.Errorf("expected %q to be supported", d) } } - unsupported := []string{"sqlite3", "sqlite", "mysql", ""} + unsupported := []string{"sqlite3", "sqlite", "mysql", "", "postgresql"} for _, d := range unsupported { if isSupportedPartitionDriver(d) { t.Errorf("expected %q to be unsupported", d) @@ -106,12 +195,24 @@ func TestIsSupportedPartitionDriver(t *testing.T) { // testPartitionManager is a mock PartitionManager for testing step.db_create_partition. type testPartitionManager struct { - partitionKey string - partitions map[string]bool + partitionKey string + partitionNameFormat string + partitions map[string]bool + syncCalled bool } func (p *testPartitionManager) DB() *sql.DB { return nil } func (p *testPartitionManager) PartitionKey() string { return p.partitionKey } +func (p *testPartitionManager) PartitionTableName(parentTable, tenantValue string) string { + format := p.partitionNameFormat + if format == "" { + format = "{table}_{tenant}" + } + suffix := sanitizePartitionSuffix(tenantValue) + name := strings.ReplaceAll(format, "{table}", parentTable) + name = strings.ReplaceAll(name, "{tenant}", suffix) + return name +} func (p *testPartitionManager) EnsurePartition(_ context.Context, tenantValue string) error { if p.partitions == nil { p.partitions = make(map[string]bool) @@ -119,6 +220,10 @@ func (p *testPartitionManager) EnsurePartition(_ context.Context, tenantValue st p.partitions[tenantValue] = true return nil } +func (p *testPartitionManager) SyncPartitionsFromSource(_ context.Context) error { + p.syncCalled = true + return nil +} func TestDBCreatePartitionStep_Execute(t *testing.T) { mgr := &testPartitionManager{partitionKey: "tenant_id"} @@ -153,7 +258,7 @@ func TestDBCreatePartitionStep_Execute(t *testing.T) { func TestDBCreatePartitionStep_MissingDatabase(t *testing.T) { factory := NewDBCreatePartitionStepFactory() _, err := factory("create-part", map[string]any{ - "tenantKey": "body.tenant_id", + "tenantKey": "steps.body.tenant_id", }, nil) if err == nil { t.Fatal("expected error for missing database") @@ -177,7 +282,7 @@ func TestDBCreatePartitionStep_NotPartitionManager(t *testing.T) { factory := NewDBCreatePartitionStepFactory() step, err := factory("create-part", map[string]any{ "database": "plain-db", - "tenantKey": "body.tenant_id", + "tenantKey": "steps.body.tenant_id", }, app) if err != nil { t.Fatalf("factory error: %v", err) @@ -214,3 +319,57 @@ func TestDBCreatePartitionStep_NilTenantValue(t *testing.T) { t.Fatal("expected error when tenant value is nil") } } + +func TestDBSyncPartitionsStep_Execute(t *testing.T) { + mgr := &testPartitionManager{partitionKey: "tenant_id"} + app := NewMockApplication() + app.Services["part-db"] = mgr + + factory := NewDBSyncPartitionsStepFactory() + step, err := factory("sync-parts", map[string]any{ + "database": "part-db", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + if !mgr.syncCalled { + t.Error("expected SyncPartitionsFromSource to be called") + } + if result.Output["synced"] != true { + t.Errorf("expected synced=true, got %v", result.Output["synced"]) + } +} + +func TestDBSyncPartitionsStep_MissingDatabase(t *testing.T) { + factory := NewDBSyncPartitionsStepFactory() + _, err := factory("sync-parts", map[string]any{}, nil) + if err == nil { + t.Fatal("expected error for missing database") + } +} + +func TestDBSyncPartitionsStep_NotPartitionManager(t *testing.T) { + db := setupTenantTestDB(t) + app := mockAppWithDB("plain-db", db) + + factory := NewDBSyncPartitionsStepFactory() + step, err := factory("sync-parts", map[string]any{ + "database": "plain-db", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error when service does not implement PartitionManager") + } +} diff --git a/module/pipeline_step_db_exec.go b/module/pipeline_step_db_exec.go index 6ee6ff17..7dfde5b0 100644 --- a/module/pipeline_step_db_exec.go +++ b/module/pipeline_step_db_exec.go @@ -106,6 +106,11 @@ func (s *DBExecStep) Execute(_ context.Context, pc *PipelineContext) (*StepResul // Apply automatic tenant scoping when tenantKey is configured. query := s.query if s.tenantKey != "" { + // Reject tenantKey for INSERT statements — WHERE doesn't apply. + upperQ := strings.TrimLeft(strings.ToUpper(strings.TrimSpace(s.query)), "(") + if strings.HasPrefix(upperQ, "INSERT") { + return nil, fmt.Errorf("db_exec step %q: tenantKey is not supported for INSERT statements (include the tenant column in your VALUES instead)", s.name) + } pkp, ok := svc.(PartitionKeyProvider) if !ok { return nil, fmt.Errorf("db_exec step %q: tenantKey requires database %q to implement PartitionKeyProvider (use database.partitioned)", s.name, s.database) diff --git a/module/pipeline_step_db_sync_partitions.go b/module/pipeline_step_db_sync_partitions.go new file mode 100644 index 00000000..6d7bc414 --- /dev/null +++ b/module/pipeline_step_db_sync_partitions.go @@ -0,0 +1,59 @@ +package module + +import ( + "context" + "fmt" + + "github.com/CrisisTextLine/modular" +) + +// DBSyncPartitionsStep synchronizes partitions from a source table (e.g., tenants) +// for all tables managed by a database.partitioned module. This enables automatic +// partition creation when new tenants are onboarded. +type DBSyncPartitionsStep struct { + name string + database string + app modular.Application +} + +// NewDBSyncPartitionsStepFactory returns a StepFactory for DBSyncPartitionsStep. +func NewDBSyncPartitionsStepFactory() StepFactory { + return func(name string, config map[string]any, app modular.Application) (PipelineStep, error) { + database, _ := config["database"].(string) + if database == "" { + return nil, fmt.Errorf("db_sync_partitions step %q: 'database' is required", name) + } + + return &DBSyncPartitionsStep{ + name: name, + database: database, + app: app, + }, nil + } +} + +func (s *DBSyncPartitionsStep) Name() string { return s.name } + +func (s *DBSyncPartitionsStep) Execute(ctx context.Context, _ *PipelineContext) (*StepResult, error) { + if s.app == nil { + return nil, fmt.Errorf("db_sync_partitions step %q: no application context", s.name) + } + + svc, ok := s.app.SvcRegistry()[s.database] + if !ok { + return nil, fmt.Errorf("db_sync_partitions step %q: database service %q not found", s.name, s.database) + } + + mgr, ok := svc.(PartitionManager) + if !ok { + return nil, fmt.Errorf("db_sync_partitions step %q: service %q does not implement PartitionManager (use database.partitioned)", s.name, s.database) + } + + if err := mgr.SyncPartitionsFromSource(ctx); err != nil { + return nil, fmt.Errorf("db_sync_partitions step %q: %w", s.name, err) + } + + return &StepResult{Output: map[string]any{ + "synced": true, + }}, nil +} diff --git a/module/pipeline_step_db_tenant_test.go b/module/pipeline_step_db_tenant_test.go index d1d0ae17..5fec2c52 100644 --- a/module/pipeline_step_db_tenant_test.go +++ b/module/pipeline_step_db_tenant_test.go @@ -3,17 +3,29 @@ package module import ( "context" "database/sql" + "strings" "testing" ) // testPartitionKeyProvider wraps a *sql.DB to satisfy PartitionKeyProvider. type testPartitionKeyProvider struct { - db *sql.DB - partitionKey string + db *sql.DB + partitionKey string + partitionNameFormat string } func (p *testPartitionKeyProvider) DB() *sql.DB { return p.db } func (p *testPartitionKeyProvider) PartitionKey() string { return p.partitionKey } +func (p *testPartitionKeyProvider) PartitionTableName(parentTable, tenantValue string) string { + format := p.partitionNameFormat + if format == "" { + format = "{table}_{tenant}" + } + suffix := sanitizePartitionSuffix(tenantValue) + name := strings.ReplaceAll(format, "{table}", parentTable) + name = strings.ReplaceAll(name, "{tenant}", suffix) + return name +} // mockAppWithPartitionDB creates a MockApplication with a PartitionKeyProvider service. func mockAppWithPartitionDB(name string, db *sql.DB, partitionKey string) *MockApplication { @@ -178,8 +190,62 @@ func TestAppendTenantFilter_WithoutWhereClause(t *testing.T) { func TestAppendTenantFilter_TrailingWhitespace(t *testing.T) { query := "SELECT * FROM forms ORDER BY created_at " result := appendTenantFilter(query, "tenant_id", 1) - expected := "SELECT * FROM forms ORDER BY created_at WHERE tenant_id = $1" + expected := "SELECT * FROM forms WHERE tenant_id = $1 ORDER BY created_at" + if result != expected { + t.Errorf("expected %q, got %q", expected, result) + } +} + +func TestAppendTenantFilter_OrderByWithWhere(t *testing.T) { + query := "SELECT * FROM forms WHERE active = true ORDER BY created_at DESC" + result := appendTenantFilter(query, "tenant_id", 2) + expected := "SELECT * FROM forms WHERE active = true AND tenant_id = $2 ORDER BY created_at DESC" + if result != expected { + t.Errorf("expected %q, got %q", expected, result) + } +} + +func TestAppendTenantFilter_LimitOffset(t *testing.T) { + query := "SELECT * FROM forms WHERE active = true ORDER BY id LIMIT 10 OFFSET 20" + result := appendTenantFilter(query, "tenant_id", 2) + expected := "SELECT * FROM forms WHERE active = true AND tenant_id = $2 ORDER BY id LIMIT 10 OFFSET 20" + if result != expected { + t.Errorf("expected %q, got %q", expected, result) + } +} + +func TestAppendTenantFilter_GroupByHaving(t *testing.T) { + query := "SELECT tenant_id, COUNT(*) FROM forms GROUP BY tenant_id HAVING COUNT(*) > 1" + result := appendTenantFilter(query, "tenant_id", 1) + expected := "SELECT tenant_id, COUNT(*) FROM forms WHERE tenant_id = $1 GROUP BY tenant_id HAVING COUNT(*) > 1" if result != expected { t.Errorf("expected %q, got %q", expected, result) } } + +func TestDBExecStep_TenantKey_RejectsInsert(t *testing.T) { + db := setupTenantTestDB(t) + app := mockAppWithPartitionDB("part-db", db, "tenant_id") + + factory := NewDBExecStepFactory() + step, err := factory("insert-form", map[string]any{ + "database": "part-db", + "query": "INSERT INTO forms (id, tenant_id, slug) VALUES ($1, $2, $3)", + "params": []any{"f4", "org-alpha", "new-form"}, + "tenantKey": "steps.auth.tenant_id", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("auth", map[string]any{"tenant_id": "org-alpha"}) + + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error when using tenantKey with INSERT") + } + if !strings.Contains(err.Error(), "INSERT") { + t.Errorf("expected error to mention INSERT, got: %v", err) + } +} diff --git a/module/sql_placeholders.go b/module/sql_placeholders.go index 0f20a926..710d73c4 100644 --- a/module/sql_placeholders.go +++ b/module/sql_placeholders.go @@ -82,16 +82,68 @@ func normalizePlaceholders(query, driver string) string { return result } -// appendTenantFilter appends "AND = $N" to a SQL query, where N is -// the next positional parameter index. The function handles queries with and -// without an existing WHERE clause. +// sqlTrailingClauses are SQL clause keywords that must come after WHERE. +// We search for the last occurrence of these to insert the tenant predicate +// before them. The order does not matter; we find the earliest position among +// all matches that appear after any existing WHERE. +var sqlTrailingClauses = []string{ + " ORDER BY ", + " GROUP BY ", + " HAVING ", + " LIMIT ", + " OFFSET ", + " UNION ", + " INTERSECT ", + " EXCEPT ", + " FOR UPDATE", + " FOR SHARE", + " FOR NO KEY UPDATE", + " RETURNING ", +} + +// appendTenantFilter inserts a tenant predicate into a SQL SELECT/UPDATE/DELETE +// query. The predicate is placed: +// - After an existing WHERE clause and before any trailing clause +// (ORDER BY, LIMIT, etc.), or +// - As a new WHERE clause before any trailing clause when none exists. +// +// Returns an error string (empty on success) when the query is an INSERT or +// other unsupported statement type. func appendTenantFilter(query, column string, paramIndex int) string { - query = strings.TrimRight(query, " \t\n\r;") - upper := strings.ToUpper(query) - if strings.Contains(upper, " WHERE ") { - return fmt.Sprintf("%s AND %s = $%d", query, column, paramIndex) + trimmed := strings.TrimRight(query, " \t\n\r;") + upper := strings.ToUpper(trimmed) + + // Find the position right after the WHERE clause (if any). + whereIdx := strings.Index(upper, " WHERE ") + hasWhere := whereIdx >= 0 + + // Find the earliest trailing clause position that appears after the WHERE. + insertPos := len(trimmed) + whereLen := len(" WHERE ") + for _, kw := range sqlTrailingClauses { + // Search starting from the position after WHERE (or from the start if no WHERE). + searchStart := 0 + if hasWhere { + searchStart = whereIdx + whereLen + } + idx := strings.Index(upper[searchStart:], kw) + if idx >= 0 { + absPos := searchStart + idx + if absPos < insertPos { + insertPos = absPos + } + } + } + + predicate := fmt.Sprintf("%s = $%d", column, paramIndex) + + before := trimmed[:insertPos] + after := trimmed[insertPos:] + + if hasWhere { + return fmt.Sprintf("%s AND %s%s", before, predicate, after) } - return fmt.Sprintf("%s WHERE %s = $%d", query, column, paramIndex) + return fmt.Sprintf("%s WHERE %s%s", before, predicate, after) } // placeholder count in the query. Returns an error if there's a mismatch. diff --git a/plugins/pipelinesteps/plugin.go b/plugins/pipelinesteps/plugin.go index 31b45842..9588bbdc 100644 --- a/plugins/pipelinesteps/plugin.go +++ b/plugins/pipelinesteps/plugin.go @@ -67,6 +67,7 @@ func New() *Plugin { "step.db_exec", "step.db_query_cached", "step.db_create_partition", + "step.db_sync_partitions", "step.json_response", "step.raw_response", "step.workflow_call", @@ -129,6 +130,7 @@ func (p *Plugin) StepFactories() map[string]plugin.StepFactory { "step.db_exec": wrapStepFactory(module.NewDBExecStepFactory()), "step.db_query_cached": wrapStepFactory(module.NewDBQueryCachedStepFactory()), "step.db_create_partition": wrapStepFactory(module.NewDBCreatePartitionStepFactory()), + "step.db_sync_partitions": wrapStepFactory(module.NewDBSyncPartitionsStepFactory()), "step.json_response": wrapStepFactory(module.NewJSONResponseStepFactory()), "step.raw_response": wrapStepFactory(module.NewRawResponseStepFactory()), "step.validate_path_param": wrapStepFactory(module.NewValidatePathParamStepFactory()), diff --git a/plugins/pipelinesteps/plugin_test.go b/plugins/pipelinesteps/plugin_test.go index 69ac2b76..9af2f58f 100644 --- a/plugins/pipelinesteps/plugin_test.go +++ b/plugins/pipelinesteps/plugin_test.go @@ -46,6 +46,7 @@ func TestStepFactories(t *testing.T) { "step.db_exec", "step.db_query_cached", "step.db_create_partition", + "step.db_sync_partitions", "step.json_response", "step.raw_response", "step.validate_path_param", diff --git a/plugins/storage/plugin.go b/plugins/storage/plugin.go index 33b9a40b..d4e1d399 100644 --- a/plugins/storage/plugin.go +++ b/plugins/storage/plugin.go @@ -174,6 +174,18 @@ func (p *Plugin) ModuleFactories() map[string]plugin.ModuleFactory { } } } + if pt, ok := cfg["partitionType"].(string); ok { + partCfg.PartitionType = pt + } + if pnf, ok := cfg["partitionNameFormat"].(string); ok { + partCfg.PartitionNameFormat = pnf + } + if st, ok := cfg["sourceTable"].(string); ok { + partCfg.SourceTable = st + } + if sc, ok := cfg["sourceColumn"].(string); ok { + partCfg.SourceColumn = sc + } return module.NewPartitionedDatabase(name, partCfg) }, "persistence.store": func(name string, cfg map[string]any) modular.Module { @@ -343,18 +355,22 @@ func (p *Plugin) ModuleSchemas() []*schema.ModuleSchema { Type: "database.partitioned", Label: "Partitioned Database", Category: "database", - Description: "PostgreSQL LIST-partitioned database for multi-tenant data isolation. Automatically manages per-tenant partitions and enables automatic tenant scoping in step.db_query and step.db_exec.", + Description: "PostgreSQL partitioned database for multi-tenant data isolation. Supports LIST and RANGE partitions with configurable naming format and optional source-table-driven auto-partition creation.", Inputs: []schema.ServiceIODef{{Name: "query", Type: "SQL", Description: "SQL query to execute"}}, Outputs: []schema.ServiceIODef{{Name: "database", Type: "sql.DB", Description: "SQL database connection pool"}}, ConfigFields: []schema.ConfigFieldDef{ {Key: "driver", Label: "Driver", Type: schema.FieldTypeSelect, Options: []string{"pgx", "pgx/v5", "postgres"}, Required: true, Description: "PostgreSQL database driver"}, {Key: "dsn", Label: "DSN", Type: schema.FieldTypeString, Required: true, Description: "Data source name / connection string", Placeholder: "postgres://user:pass@localhost/db?sslmode=disable", Sensitive: true}, //nolint:gosec // G101: placeholder DSN example in schema documentation - {Key: "partitionKey", Label: "Partition Key", Type: schema.FieldTypeString, Required: true, Description: "Column name used for LIST partitioning (e.g. tenant_id)", Placeholder: "tenant_id"}, - {Key: "tables", Label: "Tables", Type: schema.FieldTypeArray, ArrayItemType: "string", Required: true, Description: "Tables to manage LIST partitions for", Placeholder: "forms"}, + {Key: "partitionKey", Label: "Partition Key", Type: schema.FieldTypeString, Required: true, Description: "Column name used for partitioning (e.g. tenant_id)", Placeholder: "tenant_id"}, + {Key: "tables", Label: "Tables", Type: schema.FieldTypeArray, ArrayItemType: "string", Required: true, Description: "Tables to manage partitions for", Placeholder: "forms"}, + {Key: "partitionType", Label: "Partition Type", Type: schema.FieldTypeSelect, Options: []string{"list", "range"}, DefaultValue: "list", Description: "PostgreSQL partition type: list (FOR VALUES IN) or range (FOR VALUES FROM/TO)"}, + {Key: "partitionNameFormat", Label: "Partition Name Format", Type: schema.FieldTypeString, DefaultValue: "{table}_{tenant}", Description: "Template for partition table names. Supports {table} and {tenant} placeholders.", Placeholder: "{table}_{tenant}"}, + {Key: "sourceTable", Label: "Source Table", Type: schema.FieldTypeString, Description: "Table containing all tenant IDs for auto-partition sync (e.g. tenants)", Placeholder: "tenants"}, + {Key: "sourceColumn", Label: "Source Column", Type: schema.FieldTypeString, Description: "Column in source table to query for tenant values. Defaults to partitionKey.", Placeholder: "id"}, {Key: "maxOpenConns", Label: "Max Open Connections", Type: schema.FieldTypeNumber, DefaultValue: 25, Description: "Maximum number of open database connections"}, {Key: "maxIdleConns", Label: "Max Idle Connections", Type: schema.FieldTypeNumber, DefaultValue: 5, Description: "Maximum number of idle connections in the pool"}, }, - DefaultConfig: map[string]any{"maxOpenConns": 25, "maxIdleConns": 5}, + DefaultConfig: map[string]any{"maxOpenConns": 25, "maxIdleConns": 5, "partitionType": "list", "partitionNameFormat": "{table}_{tenant}"}, }, { Type: "persistence.store", diff --git a/schema/module_schema.go b/schema/module_schema.go index 53c72738..0c685fe6 100644 --- a/schema/module_schema.go +++ b/schema/module_schema.go @@ -516,18 +516,22 @@ func (r *ModuleSchemaRegistry) registerBuiltins() { Type: "database.partitioned", Label: "Partitioned Database", Category: "database", - Description: "PostgreSQL LIST-partitioned database for multi-tenant data isolation. Automatically manages per-tenant partitions and enables automatic tenant scoping in step.db_query and step.db_exec.", + Description: "PostgreSQL partitioned database for multi-tenant data isolation. Supports LIST and RANGE partitions with configurable naming format and optional source-table-driven auto-partition creation.", Inputs: []ServiceIODef{{Name: "query", Type: "SQL", Description: "SQL query to execute"}}, Outputs: []ServiceIODef{{Name: "database", Type: "sql.DB", Description: "SQL database connection pool"}}, ConfigFields: []ConfigFieldDef{ {Key: "driver", Label: "Driver", Type: FieldTypeSelect, Options: []string{"pgx", "pgx/v5", "postgres"}, Required: true, Description: "PostgreSQL database driver"}, {Key: "dsn", Label: "DSN", Type: FieldTypeString, Required: true, Description: "Data source name / connection string", Placeholder: "postgres://user:pass@localhost/db?sslmode=disable", Sensitive: true}, //nolint:gosec // G101: placeholder DSN example in schema documentation - {Key: "partitionKey", Label: "Partition Key", Type: FieldTypeString, Required: true, Description: "Column name used for LIST partitioning (e.g. tenant_id)", Placeholder: "tenant_id"}, - {Key: "tables", Label: "Tables", Type: FieldTypeArray, ArrayItemType: "string", Required: true, Description: "Tables to manage LIST partitions for", Placeholder: "forms"}, + {Key: "partitionKey", Label: "Partition Key", Type: FieldTypeString, Required: true, Description: "Column name used for partitioning (e.g. tenant_id)", Placeholder: "tenant_id"}, + {Key: "tables", Label: "Tables", Type: FieldTypeArray, ArrayItemType: "string", Required: true, Description: "Tables to manage partitions for", Placeholder: "forms"}, + {Key: "partitionType", Label: "Partition Type", Type: FieldTypeSelect, Options: []string{"list", "range"}, DefaultValue: "list", Description: "PostgreSQL partition type: list (FOR VALUES IN) or range (FOR VALUES FROM/TO)"}, + {Key: "partitionNameFormat", Label: "Partition Name Format", Type: FieldTypeString, DefaultValue: "{table}_{tenant}", Description: "Template for partition table names. Supports {table} and {tenant} placeholders.", Placeholder: "{table}_{tenant}"}, + {Key: "sourceTable", Label: "Source Table", Type: FieldTypeString, Description: "Table containing all tenant IDs for auto-partition sync (e.g. tenants)", Placeholder: "tenants"}, + {Key: "sourceColumn", Label: "Source Column", Type: FieldTypeString, Description: "Column in source table to query for tenant values. Defaults to partitionKey.", Placeholder: "id"}, {Key: "maxOpenConns", Label: "Max Open Connections", Type: FieldTypeNumber, DefaultValue: 25, Description: "Maximum number of open database connections"}, {Key: "maxIdleConns", Label: "Max Idle Connections", Type: FieldTypeNumber, DefaultValue: 5, Description: "Maximum number of idle connections in the pool"}, }, - DefaultConfig: map[string]any{"maxOpenConns": 25, "maxIdleConns": 5}, + DefaultConfig: map[string]any{"maxOpenConns": 25, "maxIdleConns": 5, "partitionType": "list", "partitionNameFormat": "{table}_{tenant}"}, }) r.Register(&ModuleSchema{ @@ -1020,7 +1024,7 @@ func (r *ModuleSchemaRegistry) registerBuiltins() { {Key: "query", Label: "SQL Query", Type: FieldTypeSQL, Required: true, Description: "Parameterized SQL SELECT query (use ? for placeholders, no template expressions allowed)", Placeholder: "SELECT id, name FROM companies WHERE id = ?"}, {Key: "params", Label: "Parameters", Type: FieldTypeArray, ArrayItemType: "string", Description: "Template-resolved parameter values for ? placeholders in query"}, {Key: "mode", Label: "Mode", Type: FieldTypeSelect, Options: []string{"list", "single"}, DefaultValue: "list", Description: "Result mode: 'list' returns rows/count, 'single' returns row/found"}, - {Key: "tenantKey", Label: "Tenant Key", Type: FieldTypeString, Description: "Dot-path in pipeline context to resolve the tenant value for automatic scoping (requires database.partitioned)", Placeholder: "auth.tenant_id"}, + {Key: "tenantKey", Label: "Tenant Key", Type: FieldTypeString, Description: "Dot-path in pipeline context to resolve the tenant value for automatic scoping (requires database.partitioned)", Placeholder: "steps.auth.tenant_id"}, }, }) @@ -1045,14 +1049,14 @@ func (r *ModuleSchemaRegistry) registerBuiltins() { Type: "step.db_exec", Label: "Database Execute", Category: "pipeline", - Description: "Executes a parameterized SQL INSERT/UPDATE/DELETE against a named database service", + Description: "Executes a parameterized SQL UPDATE/DELETE against a named database service. tenantKey is supported for UPDATE/DELETE but rejected for INSERT.", Inputs: []ServiceIODef{{Name: "context", Type: "PipelineContext", Description: "Pipeline context for template parameter resolution"}}, Outputs: []ServiceIODef{{Name: "result", Type: "StepResult", Description: "Execution result with affected_rows and last_id"}}, ConfigFields: []ConfigFieldDef{ {Key: "database", Label: "Database", Type: FieldTypeString, Required: true, Description: "Name of the database service (must implement DBProvider)", Placeholder: "admin-db", InheritFrom: "dependency.name"}, {Key: "query", Label: "SQL Statement", Type: FieldTypeSQL, Required: true, Description: "Parameterized SQL INSERT/UPDATE/DELETE statement (use ? for placeholders)", Placeholder: "INSERT INTO companies (id, name) VALUES (?, ?)"}, {Key: "params", Label: "Parameters", Type: FieldTypeArray, ArrayItemType: "string", Description: "Template-resolved parameter values for ? placeholders"}, - {Key: "tenantKey", Label: "Tenant Key", Type: FieldTypeString, Description: "Dot-path in pipeline context to resolve the tenant value for automatic scoping (requires database.partitioned)", Placeholder: "auth.tenant_id"}, + {Key: "tenantKey", Label: "Tenant Key", Type: FieldTypeString, Description: "Dot-path in pipeline context to resolve the tenant value for automatic scoping. Supported for UPDATE/DELETE only (requires database.partitioned)", Placeholder: "steps.auth.tenant_id"}, }, }) @@ -1060,12 +1064,24 @@ func (r *ModuleSchemaRegistry) registerBuiltins() { Type: "step.db_create_partition", Label: "Create Database Partition", Category: "pipeline", - Description: "Creates a PostgreSQL LIST partition for a tenant on all tables managed by a database.partitioned module. Idempotent — safe to call when a partition may already exist.", + Description: "Creates a PostgreSQL partition for a tenant on all tables managed by a database.partitioned module. Supports both LIST and RANGE partition types. Idempotent — safe to call when a partition may already exist.", Inputs: []ServiceIODef{{Name: "context", Type: "PipelineContext", Description: "Pipeline context for tenant key resolution"}}, Outputs: []ServiceIODef{{Name: "result", Type: "StepResult", Description: "Partition creation result with tenant and partition fields"}}, ConfigFields: []ConfigFieldDef{ {Key: "database", Label: "Database", Type: FieldTypeString, Required: true, Description: "Name of a database.partitioned service", Placeholder: "db", InheritFrom: "dependency.name"}, - {Key: "tenantKey", Label: "Tenant Key", Type: FieldTypeString, Required: true, Description: "Dot-path in pipeline context to resolve the tenant value (e.g. the new tenant's ID)", Placeholder: "body.tenant_id"}, + {Key: "tenantKey", Label: "Tenant Key", Type: FieldTypeString, Required: true, Description: "Dot-path in pipeline context to resolve the tenant value (e.g. the new tenant's ID)", Placeholder: "steps.body.tenant_id"}, + }, + }) + + r.Register(&ModuleSchema{ + Type: "step.db_sync_partitions", + Label: "Sync Database Partitions", + Category: "pipeline", + Description: "Synchronizes partitions from the configured source table in a database.partitioned module. Queries all distinct tenant values from the source table and creates missing partitions for all managed tables.", + Inputs: []ServiceIODef{{Name: "context", Type: "PipelineContext", Description: "Pipeline context (not used but required for step interface)"}}, + Outputs: []ServiceIODef{{Name: "result", Type: "StepResult", Description: "Sync result with synced boolean"}}, + ConfigFields: []ConfigFieldDef{ + {Key: "database", Label: "Database", Type: FieldTypeString, Required: true, Description: "Name of a database.partitioned service with sourceTable configured", Placeholder: "db", InheritFrom: "dependency.name"}, }, }) diff --git a/schema/schema.go b/schema/schema.go index 38bb3e46..33bd75ae 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -214,6 +214,7 @@ var coreModuleTypes = []string{ "step.db_exec", "step.db_query", "step.db_query_cached", + "step.db_sync_partitions", "step.delegate", "step.deploy", "step.dlq_replay", From 6a2a32d4574fe63fdd6a3347f785a49a077b279c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 07:04:22 +0000 Subject: [PATCH 5/5] fix: resolve merge conflicts with main branch Merge features from main (allow_dynamic_sql, auth.m2m custom claims) with our partitioning features (tenantKey, database.partitioned, step.db_create_partition, step.db_sync_partitions). Both db_query and db_exec now support both allow_dynamic_sql and tenantKey simultaneously. Schema includes all fields from both feature branches. Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --- module/auth_m2m.go | 11 +- module/auth_m2m_test.go | 140 +++++++++++++++++++ module/pipeline_step_db_dynamic.go | 76 ++++++++++ module/pipeline_step_db_exec.go | 53 ++++--- module/pipeline_step_db_exec_test.go | 63 +++++++++ module/pipeline_step_db_query.go | 53 ++++--- module/pipeline_step_db_query_cached.go | 65 +++++---- module/pipeline_step_db_query_cached_test.go | 66 +++++++++ module/pipeline_step_db_query_test.go | 110 +++++++++++++++ plugins/auth/plugin.go | 5 +- plugins/auth/plugin_test.go | 48 +++++++ schema/module_schema.go | 11 +- 12 files changed, 629 insertions(+), 72 deletions(-) create mode 100644 module/pipeline_step_db_dynamic.go diff --git a/module/auth_m2m.go b/module/auth_m2m.go index 16fc8c74..8c84cc0d 100644 --- a/module/auth_m2m.go +++ b/module/auth_m2m.go @@ -40,10 +40,11 @@ const ( // M2MClient represents a registered machine-to-machine OAuth2 client. type M2MClient struct { - ClientID string `json:"clientId"` - ClientSecret string `json:"clientSecret"` //nolint:gosec // G117: config DTO field - Description string `json:"description,omitempty"` - Scopes []string `json:"scopes,omitempty"` + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` //nolint:gosec // G117: config DTO field + Description string `json:"description,omitempty"` + Scopes []string `json:"scopes,omitempty"` + Claims map[string]any `json:"claims,omitempty"` } // M2MAuthModule provides machine-to-machine (server-to-server) OAuth2 authentication. @@ -251,7 +252,7 @@ func (m *M2MAuthModule) handleClientCredentials(w http.ResponseWriter, r *http.R return } - token, err := m.issueToken(clientID, grantedScopes, nil) + token, err := m.issueToken(clientID, grantedScopes, client.Claims) if err != nil { w.WriteHeader(http.StatusInternalServerError) _ = json.NewEncoder(w).Encode(oauthError("server_error", "failed to issue token")) diff --git a/module/auth_m2m_test.go b/module/auth_m2m_test.go index 4978883d..6de23b19 100644 --- a/module/auth_m2m_test.go +++ b/module/auth_m2m_test.go @@ -1171,3 +1171,143 @@ func TestM2M_ClientCredentials_SubMatchesClientID(t *testing.T) { t.Errorf("expected sub=test-client, got %v", claims["sub"]) } } + +// --- per-client custom claims --- + +// TestM2M_ClientCredentials_CustomClaimsInToken verifies that a client's Claims +// map is included in the issued access token. +func TestM2M_ClientCredentials_CustomClaimsInToken(t *testing.T) { + m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", time.Hour, "test-issuer") + m.RegisterClient(M2MClient{ + ClientID: "org-alpha", + ClientSecret: "secret-org-alpha", //nolint:gosec // test credential + Scopes: []string{"read"}, + Claims: map[string]any{"tenant_id": "alpha"}, + }) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"org-alpha"}, + "client_secret": {"secret-org-alpha"}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String()) + } + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + tokenStr, _ := resp["access_token"].(string) + + _, claims, err := m.Authenticate(tokenStr) + if err != nil { + t.Fatalf("authenticate: %v", err) + } + if claims["tenant_id"] != "alpha" { + t.Errorf("expected tenant_id=alpha, got %v", claims["tenant_id"]) + } +} + +// TestM2M_ClientCredentials_MultipleCustomClaims verifies that multiple custom +// claims are all present in the issued token. +func TestM2M_ClientCredentials_MultipleCustomClaims(t *testing.T) { + m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", time.Hour, "test-issuer") + m.RegisterClient(M2MClient{ + ClientID: "org-beta", + ClientSecret: "secret-org-beta", //nolint:gosec // test credential + Scopes: []string{"read", "write"}, + Claims: map[string]any{ + "tenant_id": "beta", + "affiliate_id": "partner-42", + }, + }) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"org-beta"}, + "client_secret": {"secret-org-beta"}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String()) + } + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + tokenStr, _ := resp["access_token"].(string) + + _, claims, err := m.Authenticate(tokenStr) + if err != nil { + t.Fatalf("authenticate: %v", err) + } + if claims["tenant_id"] != "beta" { + t.Errorf("expected tenant_id=beta, got %v", claims["tenant_id"]) + } + if claims["affiliate_id"] != "partner-42" { + t.Errorf("expected affiliate_id=partner-42, got %v", claims["affiliate_id"]) + } +} + +// TestM2M_ClientCredentials_CustomClaimsDoNotOverrideStandard verifies that +// custom claims on a client cannot override standard JWT claims. +func TestM2M_ClientCredentials_CustomClaimsDoNotOverrideStandard(t *testing.T) { + m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", time.Hour, "trusted-issuer") + m.RegisterClient(M2MClient{ + ClientID: "attacker", + ClientSecret: "attacker-secret-here", //nolint:gosec // test credential + Scopes: []string{"read"}, + Claims: map[string]any{ + "iss": "evil-issuer", + "sub": "admin", + }, + }) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"attacker"}, + "client_secret": {"attacker-secret-here"}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String()) + } + var resp map[string]any + json.NewDecoder(w.Body).Decode(&resp) + tokenStr, _ := resp["access_token"].(string) + + _, claims, err := m.Authenticate(tokenStr) + if err != nil { + t.Fatalf("authenticate: %v", err) + } + // Standard claims must not be overridden by client.Claims. + if claims["iss"] != "trusted-issuer" { + t.Errorf("iss must not be overridable via client claims, got %v", claims["iss"]) + } + if claims["sub"] != "attacker" { + t.Errorf("sub must not be overridable via client claims, got %v", claims["sub"]) + } +} + +// TestM2M_ClientCredentials_NilClaimsOK verifies that a client with nil Claims +// still issues tokens without error. +func TestM2M_ClientCredentials_NilClaimsOK(t *testing.T) { + m := NewM2MAuthModule("m2m", "this-is-a-valid-secret-32-bytes!", time.Hour, "test-issuer") + m.RegisterClient(M2MClient{ + ClientID: "plain-client", + ClientSecret: "plain-client-secret!", //nolint:gosec // test credential + Scopes: []string{"read"}, + Claims: nil, + }) + + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"plain-client"}, + "client_secret": {"plain-client-secret!"}, + } + w := postToken(t, m, params) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String()) + } +} diff --git a/module/pipeline_step_db_dynamic.go b/module/pipeline_step_db_dynamic.go new file mode 100644 index 00000000..0175022c --- /dev/null +++ b/module/pipeline_step_db_dynamic.go @@ -0,0 +1,76 @@ +package module + +import ( + "fmt" + "strings" +) + +// validateSQLIdentifier checks that s is safe to interpolate directly into SQL as an +// identifier (e.g. a table name). Only ASCII letters (A-Z, a-z), ASCII digits (0-9), +// underscores (_) and hyphens (-) are permitted. This strict allowlist prevents SQL +// injection when dynamic values are embedded in queries via allow_dynamic_sql. +func validateSQLIdentifier(s string) error { + if s == "" { + return fmt.Errorf("dynamic SQL identifier must not be empty") + } + for _, c := range s { + if (c < 'a' || c > 'z') && + (c < 'A' || c > 'Z') && + (c < '0' || c > '9') && + c != '_' && c != '-' { + return fmt.Errorf("dynamic SQL identifier %q contains unsafe character %q (only ASCII letters, digits, underscores and hyphens are allowed)", s, string(c)) + } + } + return nil +} + +// resolveDynamicSQL resolves every {{ }} template expression found in query against +// pc and validates that each resolved value is a safe SQL identifier. The validated +// values are substituted back into the query in left-to-right order and the final +// SQL string is returned. +// +// Each occurrence of a template expression is resolved independently, so +// non-deterministic functions like {{uuid}} or {{now}} produce a distinct value +// per occurrence. +// +// This is only called when allow_dynamic_sql is true (explicit opt-in). Callers +// are responsible for ensuring that the query has already passed template parsing. +func resolveDynamicSQL(tmpl *TemplateEngine, query string, pc *PipelineContext) (string, error) { + if !strings.Contains(query, "{{") { + return query, nil + } + + // Process template expressions in left-to-right order. Each occurrence is + // resolved and validated independently to preserve correct semantics for + // non-deterministic template functions (e.g. {{uuid}}, {{now}}). + var result strings.Builder + rest := query + for { + openIdx := strings.Index(rest, "{{") + if openIdx < 0 { + result.WriteString(rest) + break + } + closeIdx := strings.Index(rest[openIdx:], "}}") + if closeIdx < 0 { + return "", fmt.Errorf("dynamic SQL: unclosed template action in query (missing closing '}}')") + } + closeIdx += openIdx + + // Write the literal SQL text before this expression. + result.WriteString(rest[:openIdx]) + + expr := rest[openIdx : closeIdx+2] + + resolved, err := tmpl.Resolve(expr, pc) + if err != nil { + return "", fmt.Errorf("dynamic SQL: failed to resolve %q: %w", expr, err) + } + if err := validateSQLIdentifier(resolved); err != nil { + return "", fmt.Errorf("dynamic SQL: %w", err) + } + result.WriteString(resolved) + rest = rest[closeIdx+2:] + } + return result.String(), nil +} diff --git a/module/pipeline_step_db_exec.go b/module/pipeline_step_db_exec.go index 7dfde5b0..710b60f0 100644 --- a/module/pipeline_step_db_exec.go +++ b/module/pipeline_step_db_exec.go @@ -10,14 +10,15 @@ import ( // DBExecStep executes parameterized SQL INSERT/UPDATE/DELETE against a named database service. type DBExecStep struct { - name string - database string - query string - params []string - ignoreError bool - tenantKey string // dot-path to resolve tenant value for automatic scoping - app modular.Application - tmpl *TemplateEngine + name string + database string + query string + params []string + ignoreError bool + tenantKey string // dot-path to resolve tenant value for automatic scoping + allowDynamicSQL bool + app modular.Application + tmpl *TemplateEngine } // NewDBExecStepFactory returns a StepFactory that creates DBExecStep instances. @@ -33,8 +34,10 @@ func NewDBExecStepFactory() StepFactory { return nil, fmt.Errorf("db_exec step %q: 'query' is required", name) } - // Safety: reject template expressions in SQL to prevent injection - if strings.Contains(query, "{{") { + // Safety: reject template expressions in SQL to prevent injection, + // unless allow_dynamic_sql is explicitly enabled. + allowDynamicSQL, _ := config["allow_dynamic_sql"].(bool) + if !allowDynamicSQL && strings.Contains(query, "{{") { return nil, fmt.Errorf("db_exec step %q: query must not contain template expressions (use params instead)", name) } @@ -53,14 +56,15 @@ func NewDBExecStepFactory() StepFactory { tenantKey, _ := config["tenantKey"].(string) return &DBExecStep{ - name: name, - database: database, - query: query, - params: params, - ignoreError: ignoreError, - tenantKey: tenantKey, - app: app, - tmpl: NewTemplateEngine(), + name: name, + database: database, + query: query, + params: params, + ignoreError: ignoreError, + tenantKey: tenantKey, + allowDynamicSQL: allowDynamicSQL, + app: app, + tmpl: NewTemplateEngine(), }, nil } } @@ -68,6 +72,18 @@ func NewDBExecStepFactory() StepFactory { func (s *DBExecStep) Name() string { return s.name } func (s *DBExecStep) Execute(_ context.Context, pc *PipelineContext) (*StepResult, error) { + // Resolve template expressions in the query early (before any DB access) when + // dynamic SQL is enabled. This validates resolved identifiers against an + // allowlist before any database interaction. + query := s.query + if s.allowDynamicSQL { + var err error + query, err = resolveDynamicSQL(s.tmpl, query, pc) + if err != nil { + return nil, fmt.Errorf("db_exec step %q: %w", s.name, err) + } + } + if s.app == nil { return nil, fmt.Errorf("db_exec step %q: no application context", s.name) } @@ -104,7 +120,6 @@ func (s *DBExecStep) Execute(_ context.Context, pc *PipelineContext) (*StepResul } // Apply automatic tenant scoping when tenantKey is configured. - query := s.query if s.tenantKey != "" { // Reject tenantKey for INSERT statements — WHERE doesn't apply. upperQ := strings.TrimLeft(strings.ToUpper(strings.TrimSpace(s.query)), "(") diff --git a/module/pipeline_step_db_exec_test.go b/module/pipeline_step_db_exec_test.go index 1e5512ee..77dc5483 100644 --- a/module/pipeline_step_db_exec_test.go +++ b/module/pipeline_step_db_exec_test.go @@ -3,6 +3,7 @@ package module import ( "context" "database/sql" + "strings" "testing" _ "modernc.org/sqlite" @@ -195,6 +196,68 @@ func TestDBExecStep_RejectsTemplateInQuery(t *testing.T) { } } +func TestDBExecStep_DynamicTableName(t *testing.T) { + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open db: %v", err) + } + defer db.Close() + + _, err = db.Exec(`CREATE TABLE items_alpha (id TEXT PRIMARY KEY, name TEXT NOT NULL)`) + if err != nil { + t.Fatalf("create table: %v", err) + } + + app := mockAppWithDB("test-db", db) + factory := NewDBExecStepFactory() + step, err := factory("dynamic-insert", map[string]any{ + "database": "test-db", + "query": `INSERT INTO items_{{.steps.auth.tenant}} (id, name) VALUES (?, ?)`, + "params": []any{"i1", "Widget"}, + "allow_dynamic_sql": true, + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("auth", map[string]any{"tenant": "alpha"}) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + affected, _ := result.Output["affected_rows"].(int64) + if affected != 1 { + t.Errorf("expected affected_rows=1, got %v", affected) + } +} + +func TestDBExecStep_DynamicSQL_RejectsInjection(t *testing.T) { + factory := NewDBExecStepFactory() + step, err := factory("injection-exec", map[string]any{ + "database": "test-db", + "query": `DELETE FROM items_{{.steps.auth.tenant}} WHERE id = ?`, + "params": []any{"i1"}, + "allow_dynamic_sql": true, + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("auth", map[string]any{"tenant": "alpha'; DROP TABLE items;--"}) + + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for unsafe SQL identifier") + } + if !strings.Contains(err.Error(), "unsafe character") { + t.Errorf("expected 'unsafe character' in error, got: %v", err) + } +} + func TestDBExecStep_MissingDatabase(t *testing.T) { factory := NewDBExecStepFactory() _, err := factory("no-db", map[string]any{ diff --git a/module/pipeline_step_db_query.go b/module/pipeline_step_db_query.go index 5fe9b2b4..89a61662 100644 --- a/module/pipeline_step_db_query.go +++ b/module/pipeline_step_db_query.go @@ -25,14 +25,15 @@ type DBDriverProvider interface { // DBQueryStep executes a parameterized SQL SELECT against a named database service. type DBQueryStep struct { - name string - database string - query string - params []string - mode string // "list" or "single" - tenantKey string // dot-path to resolve tenant value for automatic scoping - app modular.Application - tmpl *TemplateEngine + name string + database string + query string + params []string + mode string // "list" or "single" + tenantKey string // dot-path to resolve tenant value for automatic scoping + allowDynamicSQL bool + app modular.Application + tmpl *TemplateEngine } // NewDBQueryStepFactory returns a StepFactory that creates DBQueryStep instances. @@ -48,8 +49,10 @@ func NewDBQueryStepFactory() StepFactory { return nil, fmt.Errorf("db_query step %q: 'query' is required", name) } - // Safety: reject template expressions in SQL to prevent injection - if strings.Contains(query, "{{") { + // Safety: reject template expressions in SQL to prevent injection, + // unless allow_dynamic_sql is explicitly enabled. + allowDynamicSQL, _ := config["allow_dynamic_sql"].(bool) + if !allowDynamicSQL && strings.Contains(query, "{{") { return nil, fmt.Errorf("db_query step %q: query must not contain template expressions (use params instead)", name) } @@ -75,14 +78,15 @@ func NewDBQueryStepFactory() StepFactory { tenantKey, _ := config["tenantKey"].(string) return &DBQueryStep{ - name: name, - database: database, - query: query, - params: params, - mode: mode, - tenantKey: tenantKey, - app: app, - tmpl: NewTemplateEngine(), + name: name, + database: database, + query: query, + params: params, + mode: mode, + tenantKey: tenantKey, + allowDynamicSQL: allowDynamicSQL, + app: app, + tmpl: NewTemplateEngine(), }, nil } } @@ -90,6 +94,18 @@ func NewDBQueryStepFactory() StepFactory { func (s *DBQueryStep) Name() string { return s.name } func (s *DBQueryStep) Execute(_ context.Context, pc *PipelineContext) (*StepResult, error) { + // Resolve template expressions in the query early (before any DB access) when + // dynamic SQL is enabled. This validates resolved identifiers against an + // allowlist before any database interaction. + query := s.query + if s.allowDynamicSQL { + var err error + query, err = resolveDynamicSQL(s.tmpl, query, pc) + if err != nil { + return nil, fmt.Errorf("db_query step %q: %w", s.name, err) + } + } + // Resolve database service if s.app == nil { return nil, fmt.Errorf("db_query step %q: no application context", s.name) @@ -127,7 +143,6 @@ func (s *DBQueryStep) Execute(_ context.Context, pc *PipelineContext) (*StepResu } // Apply automatic tenant scoping when tenantKey is configured. - query := s.query if s.tenantKey != "" { pkp, ok := svc.(PartitionKeyProvider) if !ok { diff --git a/module/pipeline_step_db_query_cached.go b/module/pipeline_step_db_query_cached.go index 2b40fe55..400f1dab 100644 --- a/module/pipeline_step_db_query_cached.go +++ b/module/pipeline_step_db_query_cached.go @@ -20,15 +20,16 @@ type dbQueryCacheEntry struct { // in an in-process, TTL-aware cache keyed by a template-resolved cache key. // Concurrent pipeline executions are safe: access is protected by a read-write mutex. type DBQueryCachedStep struct { - name string - database string - query string - params []string - cacheKey string - cacheTTL time.Duration - scanFields []string - app modular.Application - tmpl *TemplateEngine + name string + database string + query string + params []string + cacheKey string + cacheTTL time.Duration + scanFields []string + allowDynamicSQL bool + app modular.Application + tmpl *TemplateEngine mu sync.RWMutex cache map[string]dbQueryCacheEntry @@ -47,8 +48,10 @@ func NewDBQueryCachedStepFactory() StepFactory { return nil, fmt.Errorf("db_query_cached step %q: 'query' is required", name) } - // Safety: reject template expressions in SQL to prevent injection - if strings.Contains(query, "{{") { + // Safety: reject template expressions in SQL to prevent injection, + // unless allow_dynamic_sql is explicitly enabled. + allowDynamicSQL, _ := config["allow_dynamic_sql"].(bool) + if !allowDynamicSQL && strings.Contains(query, "{{") { return nil, fmt.Errorf("db_query_cached step %q: query must not contain template expressions (use params instead)", name) } @@ -92,16 +95,17 @@ func NewDBQueryCachedStepFactory() StepFactory { } return &DBQueryCachedStep{ - name: name, - database: database, - query: query, - params: params, - cacheKey: cacheKey, - cacheTTL: cacheTTL, - scanFields: scanFields, - app: app, - tmpl: NewTemplateEngine(), - cache: make(map[string]dbQueryCacheEntry), + name: name, + database: database, + query: query, + params: params, + cacheKey: cacheKey, + cacheTTL: cacheTTL, + scanFields: scanFields, + allowDynamicSQL: allowDynamicSQL, + app: app, + tmpl: NewTemplateEngine(), + cache: make(map[string]dbQueryCacheEntry), }, nil } } @@ -112,6 +116,18 @@ func (s *DBQueryCachedStep) Name() string { return s.name } // Execute checks the in-memory cache first; on a miss (or expiry) it queries // the database, stores the result, and returns it. func (s *DBQueryCachedStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) { + // Resolve template expressions in the query early (before any DB access) when + // dynamic SQL is enabled. This validates resolved identifiers against an + // allowlist before any database interaction. + query := s.query + if s.allowDynamicSQL { + var err error + query, err = resolveDynamicSQL(s.tmpl, query, pc) + if err != nil { + return nil, fmt.Errorf("db_query_cached step %q: %w", s.name, err) + } + } + if s.app == nil { return nil, fmt.Errorf("db_query_cached step %q: no application context", s.name) } @@ -151,7 +167,7 @@ func (s *DBQueryCachedStep) Execute(ctx context.Context, pc *PipelineContext) (* s.mu.Unlock() // Query the database - result, err := s.runQuery(ctx, pc) + result, err := s.runQuery(ctx, pc, query) if err != nil { return nil, err } @@ -169,7 +185,8 @@ func (s *DBQueryCachedStep) Execute(ctx context.Context, pc *PipelineContext) (* } // runQuery executes the SQL query and returns the result as a map. -func (s *DBQueryCachedStep) runQuery(ctx context.Context, pc *PipelineContext) (map[string]any, error) { +// query is the (already dynamic-SQL-resolved) query string to execute. +func (s *DBQueryCachedStep) runQuery(ctx context.Context, pc *PipelineContext, query string) (map[string]any, error) { svc, ok := s.app.SvcRegistry()[s.database] if !ok { return nil, fmt.Errorf("db_query_cached step %q: database service %q not found", s.name, s.database) @@ -200,7 +217,7 @@ func (s *DBQueryCachedStep) runQuery(ctx context.Context, pc *PipelineContext) ( resolvedParams[i] = resolved } - query := normalizePlaceholders(s.query, driver) + query = normalizePlaceholders(query, driver) rows, err := db.QueryContext(ctx, query, resolvedParams...) if err != nil { diff --git a/module/pipeline_step_db_query_cached_test.go b/module/pipeline_step_db_query_cached_test.go index ab21aebf..14679f83 100644 --- a/module/pipeline_step_db_query_cached_test.go +++ b/module/pipeline_step_db_query_cached_test.go @@ -2,6 +2,7 @@ package module import ( "context" + "strings" "testing" "time" ) @@ -405,3 +406,68 @@ func TestDBQueryCachedStep_NegativeTTLRejected(t *testing.T) { t.Fatal("expected error for negative cache_ttl") } } + +func TestDBQueryCachedStep_DynamicTableName(t *testing.T) { + db := setupTestDB(t) + _, err := db.Exec(`CREATE TABLE companies_beta (id TEXT PRIMARY KEY, name TEXT NOT NULL)`) + if err != nil { + t.Fatalf("create table: %v", err) + } + _, err = db.Exec(`INSERT INTO companies_beta (id, name) VALUES ('b1', 'Beta LLC')`) + if err != nil { + t.Fatalf("insert: %v", err) + } + + app := mockAppWithDB("test-db", db) + factory := NewDBQueryCachedStepFactory() + step, err := factory("dynamic-cached", map[string]any{ + "database": "test-db", + "query": `SELECT id, name FROM companies_{{.steps.auth.tenant}} WHERE id = $1`, + "params": []any{"b1"}, + "cache_key": `tenant:{{.steps.auth.tenant}}:b1`, + "cache_ttl": "5m", + "allow_dynamic_sql": true, + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("auth", map[string]any{"tenant": "beta"}) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + if result.Output["cache_hit"] != false { + t.Errorf("expected cache_hit=false on first call, got %v", result.Output["cache_hit"]) + } + if result.Output["name"] != "Beta LLC" { + t.Errorf("expected name='Beta LLC', got %v", result.Output["name"]) + } +} + +func TestDBQueryCachedStep_DynamicSQL_RejectsInjection(t *testing.T) { + factory := NewDBQueryCachedStepFactory() + step, err := factory("injection-cached", map[string]any{ + "database": "test-db", + "query": `SELECT * FROM companies_{{.steps.auth.tenant}}`, + "cache_key": "k", + "allow_dynamic_sql": true, + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("auth", map[string]any{"tenant": "alpha'; DROP TABLE companies;--"}) + + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for unsafe SQL identifier") + } + if !strings.Contains(err.Error(), "unsafe character") { + t.Errorf("expected 'unsafe character' in error, got: %v", err) + } +} diff --git a/module/pipeline_step_db_query_test.go b/module/pipeline_step_db_query_test.go index c3d19825..e704faf4 100644 --- a/module/pipeline_step_db_query_test.go +++ b/module/pipeline_step_db_query_test.go @@ -3,6 +3,7 @@ package module import ( "context" "database/sql" + "strings" "testing" _ "modernc.org/sqlite" @@ -218,6 +219,93 @@ func TestDBQueryStep_RejectsTemplateInQuery(t *testing.T) { } } +func TestDBQueryStep_DynamicTableName(t *testing.T) { + db := setupTestDB(t) + // Create a second table whose name is derived from a "tenant" value. + _, err := db.Exec(`CREATE TABLE companies_alpha (id TEXT PRIMARY KEY, name TEXT NOT NULL)`) + if err != nil { + t.Fatalf("create tenant table: %v", err) + } + _, err = db.Exec(`INSERT INTO companies_alpha (id, name) VALUES ('a1', 'Alpha Corp')`) + if err != nil { + t.Fatalf("insert: %v", err) + } + + app := mockAppWithDB("test-db", db) + factory := NewDBQueryStepFactory() + step, err := factory("dynamic-table", map[string]any{ + "database": "test-db", + "query": `SELECT id, name FROM companies_{{.steps.auth.tenant}} WHERE id = ?`, + "params": []any{"a1"}, + "mode": "single", + "allow_dynamic_sql": true, + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("auth", map[string]any{"tenant": "alpha"}) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + found, _ := result.Output["found"].(bool) + if !found { + t.Error("expected found=true") + } + row, _ := result.Output["row"].(map[string]any) + if row["name"] != "Alpha Corp" { + t.Errorf("expected name='Alpha Corp', got %v", row["name"]) + } +} + +func TestDBQueryStep_DynamicSQL_RejectsInjection(t *testing.T) { + factory := NewDBQueryStepFactory() + step, err := factory("injection-attempt", map[string]any{ + "database": "test-db", + "query": `SELECT * FROM companies_{{.steps.auth.tenant}}`, + "mode": "list", + "allow_dynamic_sql": true, + }, nil) // nil app is fine – we expect an error before the DB is touched + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("auth", map[string]any{"tenant": "alpha; DROP TABLE companies;--"}) + + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for unsafe SQL identifier") + } + if !strings.Contains(err.Error(), "unsafe character") { + t.Errorf("expected 'unsafe character' in error, got: %v", err) + } +} + +func TestDBQueryStep_DynamicSQL_RejectsEmpty(t *testing.T) { + factory := NewDBQueryStepFactory() + step, err := factory("empty-ident", map[string]any{ + "database": "test-db", + "query": `SELECT * FROM companies_{{.steps.auth.tenant}}`, + "mode": "list", + "allow_dynamic_sql": true, + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + // Tenant resolves to empty string (missing key → zero value) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for empty SQL identifier") + } +} + func TestDBQueryStep_MissingDatabase(t *testing.T) { factory := NewDBQueryStepFactory() _, err := factory("no-db", map[string]any{ @@ -228,6 +316,28 @@ func TestDBQueryStep_MissingDatabase(t *testing.T) { } } +func TestDBQueryStep_DynamicSQL_UnclosedAction(t *testing.T) { + factory := NewDBQueryStepFactory() + step, err := factory("unclosed", map[string]any{ + "database": "test-db", + "query": `SELECT * FROM companies_{{.steps.auth.tenant`, + "mode": "list", + "allow_dynamic_sql": true, + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for unclosed template action") + } + if !strings.Contains(err.Error(), "unclosed template action") { + t.Errorf("expected 'unclosed template action' in error, got: %v", err) + } +} + func TestDBQueryStep_EmptyResult(t *testing.T) { db := setupTestDB(t) app := mockAppWithDB("test-db", db) diff --git a/plugins/auth/plugin.go b/plugins/auth/plugin.go index 6b87fb18..8df826e8 100644 --- a/plugins/auth/plugin.go +++ b/plugins/auth/plugin.go @@ -196,6 +196,9 @@ func (p *Plugin) ModuleFactories() map[string]plugin.ModuleFactory { } } } + if claimsRaw, ok := cm["claims"].(map[string]any); ok { + client.Claims = claimsRaw + } if client.ClientID != "" && client.ClientSecret != "" { m.RegisterClient(client) } @@ -366,7 +369,7 @@ func (p *Plugin) ModuleSchemas() []*schema.ModuleSchema { {Key: "privateKey", Label: "EC Private Key (PEM)", Type: schema.FieldTypeString, Description: "PEM-encoded EC private key for ES256 signing; if omitted a key is auto-generated", Sensitive: true}, {Key: "tokenExpiry", Label: "Token Expiry", Type: schema.FieldTypeDuration, DefaultValue: "1h", Description: "Access token expiration duration (e.g. 15m, 1h)", Placeholder: "1h"}, {Key: "issuer", Label: "Issuer", Type: schema.FieldTypeString, DefaultValue: "workflow", Description: "Token issuer (iss) claim", Placeholder: "workflow"}, - {Key: "clients", Label: "Registered Clients", Type: schema.FieldTypeJSON, Description: "List of OAuth2 clients: [{clientId, clientSecret, scopes, description}]"}, + {Key: "clients", Label: "Registered Clients", Type: schema.FieldTypeJSON, Description: "List of OAuth2 clients: [{clientId, clientSecret, scopes, description, claims}]"}, }, DefaultConfig: map[string]any{"algorithm": "ES256", "tokenExpiry": "1h", "issuer": "workflow", "clients": []any{}}, }, diff --git a/plugins/auth/plugin_test.go b/plugins/auth/plugin_test.go index a031daf8..441cf771 100644 --- a/plugins/auth/plugin_test.go +++ b/plugins/auth/plugin_test.go @@ -1,8 +1,13 @@ package auth import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" "testing" + "github.com/GoCodeAlone/workflow/module" "github.com/GoCodeAlone/workflow/plugin" ) @@ -120,3 +125,46 @@ func TestModuleSchemas(t *testing.T) { } } } + +func TestModuleFactoryM2MWithClaims(t *testing.T) { + p := New() + factories := p.ModuleFactories() + + mod := factories["auth.m2m"]("m2m-test", map[string]any{ + "algorithm": "HS256", + "secret": "this-is-a-valid-secret-32-bytes!", + "clients": []any{ + map[string]any{ + "clientId": "org-alpha", + "clientSecret": "secret-alpha", + "scopes": []any{"read"}, + "claims": map[string]any{ + "tenant_id": "alpha", + }, + }, + }, + }) + if mod == nil { + t.Fatal("auth.m2m factory returned nil") + } + + m2mMod, ok := mod.(*module.M2MAuthModule) + if !ok { + t.Fatal("expected *module.M2MAuthModule") + } + + // Issue a token via the Handle method. + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"org-alpha"}, + "client_secret": {"secret-alpha"}, + } + req := httptest.NewRequest(http.MethodPost, "/oauth/token", strings.NewReader(params.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := httptest.NewRecorder() + m2mMod.Handle(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d; body: %s", w.Code, w.Body.String()) + } +} diff --git a/schema/module_schema.go b/schema/module_schema.go index 0c685fe6..ba910c31 100644 --- a/schema/module_schema.go +++ b/schema/module_schema.go @@ -1021,10 +1021,11 @@ func (r *ModuleSchemaRegistry) registerBuiltins() { Outputs: []ServiceIODef{{Name: "result", Type: "StepResult", Description: "Query results as rows/count (list mode) or row/found (single mode)"}}, ConfigFields: []ConfigFieldDef{ {Key: "database", Label: "Database", Type: FieldTypeString, Required: true, Description: "Name of the database service (must implement DBProvider)", Placeholder: "admin-db", InheritFrom: "dependency.name"}, - {Key: "query", Label: "SQL Query", Type: FieldTypeSQL, Required: true, Description: "Parameterized SQL SELECT query (use ? for placeholders, no template expressions allowed)", Placeholder: "SELECT id, name FROM companies WHERE id = ?"}, + {Key: "query", Label: "SQL Query", Type: FieldTypeSQL, Required: true, Description: "Parameterized SQL SELECT query (use ? for placeholders). Template expressions are forbidden unless allow_dynamic_sql is true.", Placeholder: "SELECT id, name FROM companies WHERE id = ?"}, {Key: "params", Label: "Parameters", Type: FieldTypeArray, ArrayItemType: "string", Description: "Template-resolved parameter values for ? placeholders in query"}, {Key: "mode", Label: "Mode", Type: FieldTypeSelect, Options: []string{"list", "single"}, DefaultValue: "list", Description: "Result mode: 'list' returns rows/count, 'single' returns row/found"}, {Key: "tenantKey", Label: "Tenant Key", Type: FieldTypeString, Description: "Dot-path in pipeline context to resolve the tenant value for automatic scoping (requires database.partitioned)", Placeholder: "steps.auth.tenant_id"}, + {Key: "allow_dynamic_sql", Label: "Allow Dynamic SQL", Type: FieldTypeBool, DefaultValue: "false", Description: "When true, template expressions in 'query' are resolved at runtime. Each resolved value must contain only letters, digits, underscores and hyphens to prevent SQL injection."}, }, }) @@ -1037,11 +1038,12 @@ func (r *ModuleSchemaRegistry) registerBuiltins() { Outputs: []ServiceIODef{{Name: "result", Type: "StepResult", Description: "Query result fields as top-level keys plus cache_hit boolean"}}, ConfigFields: []ConfigFieldDef{ {Key: "database", Label: "Database", Type: FieldTypeString, Required: true, Description: "Name of the database service (must implement DBProvider)", Placeholder: "db", InheritFrom: "dependency.name"}, - {Key: "query", Label: "SQL Query", Type: FieldTypeSQL, Required: true, Description: "Parameterized SQL SELECT query using $N placeholders (e.g. $1, $2); automatically converted to ? for SQLite drivers. No template expressions allowed.", Placeholder: "SELECT backend_url, settings FROM routing_config WHERE tenant_id = $1 LIMIT 1"}, + {Key: "query", Label: "SQL Query", Type: FieldTypeSQL, Required: true, Description: "Parameterized SQL SELECT query using $N placeholders (e.g. $1, $2); automatically converted to ? for SQLite drivers. Template expressions are forbidden unless allow_dynamic_sql is true.", Placeholder: "SELECT backend_url, settings FROM routing_config WHERE tenant_id = $1 LIMIT 1"}, {Key: "params", Label: "Parameters", Type: FieldTypeArray, ArrayItemType: "string", Description: "Template-resolved parameter values for query placeholders"}, {Key: "cache_key", Label: "Cache Key", Type: FieldTypeString, Required: true, Description: "Template-resolved key used to store/retrieve the cached result", Placeholder: "tenant_config:{{.steps.parse.headers.X-Tenant-Id}}"}, {Key: "cache_ttl", Label: "Cache TTL", Type: FieldTypeString, DefaultValue: "5m", Description: "Duration string for how long to cache the result (e.g. '5m', '30s', '1h')", Placeholder: "5m"}, {Key: "scan_fields", Label: "Scan Fields", Type: FieldTypeArray, ArrayItemType: "string", Description: "Column names to include in the output map (omit to include all columns)"}, + {Key: "allow_dynamic_sql", Label: "Allow Dynamic SQL", Type: FieldTypeBool, DefaultValue: "false", Description: "When true, template expressions in 'query' are resolved at runtime. Each resolved value must contain only letters, digits, underscores and hyphens to prevent SQL injection."}, }, }) @@ -1049,14 +1051,15 @@ func (r *ModuleSchemaRegistry) registerBuiltins() { Type: "step.db_exec", Label: "Database Execute", Category: "pipeline", - Description: "Executes a parameterized SQL UPDATE/DELETE against a named database service. tenantKey is supported for UPDATE/DELETE but rejected for INSERT.", + Description: "Executes a parameterized SQL INSERT/UPDATE/DELETE against a named database service", Inputs: []ServiceIODef{{Name: "context", Type: "PipelineContext", Description: "Pipeline context for template parameter resolution"}}, Outputs: []ServiceIODef{{Name: "result", Type: "StepResult", Description: "Execution result with affected_rows and last_id"}}, ConfigFields: []ConfigFieldDef{ {Key: "database", Label: "Database", Type: FieldTypeString, Required: true, Description: "Name of the database service (must implement DBProvider)", Placeholder: "admin-db", InheritFrom: "dependency.name"}, - {Key: "query", Label: "SQL Statement", Type: FieldTypeSQL, Required: true, Description: "Parameterized SQL INSERT/UPDATE/DELETE statement (use ? for placeholders)", Placeholder: "INSERT INTO companies (id, name) VALUES (?, ?)"}, + {Key: "query", Label: "SQL Statement", Type: FieldTypeSQL, Required: true, Description: "Parameterized SQL INSERT/UPDATE/DELETE statement (use ? for placeholders). Template expressions are forbidden unless allow_dynamic_sql is true.", Placeholder: "INSERT INTO companies (id, name) VALUES (?, ?)"}, {Key: "params", Label: "Parameters", Type: FieldTypeArray, ArrayItemType: "string", Description: "Template-resolved parameter values for ? placeholders"}, {Key: "tenantKey", Label: "Tenant Key", Type: FieldTypeString, Description: "Dot-path in pipeline context to resolve the tenant value for automatic scoping. Supported for UPDATE/DELETE only (requires database.partitioned)", Placeholder: "steps.auth.tenant_id"}, + {Key: "allow_dynamic_sql", Label: "Allow Dynamic SQL", Type: FieldTypeBool, DefaultValue: "false", Description: "When true, template expressions in 'query' are resolved at runtime. Each resolved value must contain only letters, digits, underscores and hyphens to prevent SQL injection."}, }, })