diff --git a/module/pipeline_step_db_dynamic.go b/module/pipeline_step_db_dynamic.go index 0175022c..aa93495a 100644 --- a/module/pipeline_step_db_dynamic.go +++ b/module/pipeline_step_db_dynamic.go @@ -7,8 +7,8 @@ import ( // validateSQLIdentifier checks that s is safe to interpolate directly into SQL as an // identifier (e.g. a table name). Only ASCII letters (A-Z, a-z), ASCII digits (0-9), -// underscores (_) and hyphens (-) are permitted. This strict allowlist prevents SQL -// injection when dynamic values are embedded in queries via allow_dynamic_sql. +// and underscores (_) are permitted. This strict allowlist prevents SQL injection +// when dynamic values are embedded in queries via allow_dynamic_sql. func validateSQLIdentifier(s string) error { if s == "" { return fmt.Errorf("dynamic SQL identifier must not be empty") @@ -17,8 +17,8 @@ func validateSQLIdentifier(s string) error { if (c < 'a' || c > 'z') && (c < 'A' || c > 'Z') && (c < '0' || c > '9') && - c != '_' && c != '-' { - return fmt.Errorf("dynamic SQL identifier %q contains unsafe character %q (only ASCII letters, digits, underscores and hyphens are allowed)", s, string(c)) + c != '_' { + return fmt.Errorf("dynamic SQL identifier %q contains unsafe character %q (only ASCII letters, digits, and underscores are allowed)", s, string(c)) } } return nil diff --git a/module/pipeline_step_db_query_cached.go b/module/pipeline_step_db_query_cached.go index 400f1dab..88385392 100644 --- a/module/pipeline_step_db_query_cached.go +++ b/module/pipeline_step_db_query_cached.go @@ -24,6 +24,7 @@ type DBQueryCachedStep struct { database string query string params []string + mode string // "single" or "list" cacheKey string cacheTTL time.Duration scanFields []string @@ -83,6 +84,15 @@ func NewDBQueryCachedStepFactory() StepFactory { } } + mode, _ := config["mode"].(string) + mode = strings.TrimSpace(mode) + // Backwards compatibility: if mode is omitted or empty, leave it blank. + // Execution logic treats a blank mode as the legacy single-row flat + // output, while explicit "single" uses the new row/found envelope. + if mode != "" && mode != "list" && mode != "single" { + return nil, fmt.Errorf("db_query_cached step %q: mode must be 'list' or 'single', got %q", name, mode) + } + var scanFields []string if sf, ok := config["scan_fields"]; ok { if list, ok := sf.([]any); ok { @@ -99,6 +109,7 @@ func NewDBQueryCachedStepFactory() StepFactory { database: database, query: query, params: params, + mode: mode, cacheKey: cacheKey, cacheTTL: cacheTTL, scanFields: scanFields, @@ -145,7 +156,7 @@ func (s *DBQueryCachedStep) Execute(ctx context.Context, pc *PipelineContext) (* s.mu.RUnlock() if found && time.Now().Before(entry.expiresAt) { - output := copyMap(entry.value) + output := deepCopyMap(entry.value) output["cache_hit"] = true return &StepResult{Output: output}, nil } @@ -155,7 +166,7 @@ func (s *DBQueryCachedStep) Execute(ctx context.Context, pc *PipelineContext) (* entry, found = s.cache[key] if found && time.Now().Before(entry.expiresAt) { // Another goroutine populated the cache while we were waiting for the lock - output := copyMap(entry.value) + output := deepCopyMap(entry.value) s.mu.Unlock() output["cache_hit"] = true return &StepResult{Output: output}, nil @@ -175,7 +186,7 @@ func (s *DBQueryCachedStep) Execute(ctx context.Context, pc *PipelineContext) (* // Store in cache (write lock) s.mu.Lock() s.cache[key] = dbQueryCacheEntry{ - value: copyMap(result), + value: deepCopyMap(result), expiresAt: time.Now().Add(s.cacheTTL), } s.mu.Unlock() @@ -236,7 +247,7 @@ func (s *DBQueryCachedStep) runQuery(ctx context.Context, pc *PipelineContext, q fieldSet[f] = true } - output := make(map[string]any) + var results []map[string]any for rows.Next() { values := make([]any, len(columns)) valuePtrs := make([]any, len(columns)) @@ -248,33 +259,74 @@ func (s *DBQueryCachedStep) runQuery(ctx context.Context, pc *PipelineContext, q return nil, fmt.Errorf("db_query_cached step %q: scan failed: %w", s.name, err) } + row := make(map[string]any, len(columns)) for i, col := range columns { if len(fieldSet) > 0 && !fieldSet[col] { continue } val := values[i] if b, ok := val.([]byte); ok { - output[col] = string(b) + row[col] = string(b) } else { - output[col] = val + row[col] = val } } - // Only take the first row - break + results = append(results, row) + if s.mode != "list" { + // Only take the first row for single and legacy modes + break + } } if err := rows.Err(); err != nil { return nil, fmt.Errorf("db_query_cached step %q: row iteration error: %w", s.name, err) } + output := make(map[string]any) + switch s.mode { + case "list": + if results == nil { + results = []map[string]any{} + } + output["rows"] = results + output["count"] = len(results) + case "single": + if len(results) > 0 { + output["row"] = results[0] + output["found"] = true + } else { + output["row"] = map[string]any{} + output["found"] = false + } + default: // "" — legacy flat column map (backward compatible) + if len(results) > 0 { + for k, v := range results[0] { + output[k] = v + } + } + } + return output, nil } -// copyMap creates a shallow copy of a map. -func copyMap(m map[string]any) map[string]any { +// deepCopyMap creates a deep copy of a map, recursively copying nested +// map[string]any values and []map[string]any slices to prevent callers from +// mutating cached data or triggering data races across goroutines. +func deepCopyMap(m map[string]any) map[string]any { cp := make(map[string]any, len(m)) for k, v := range m { - cp[k] = v + switch val := v.(type) { + case map[string]any: + cp[k] = deepCopyMap(val) + case []map[string]any: + sliceCopy := make([]map[string]any, len(val)) + for i, row := range val { + sliceCopy[i] = deepCopyMap(row) + } + cp[k] = sliceCopy + default: + cp[k] = v + } } return cp } diff --git a/module/pipeline_step_db_query_cached_test.go b/module/pipeline_step_db_query_cached_test.go index 14679f83..4dee217c 100644 --- a/module/pipeline_step_db_query_cached_test.go +++ b/module/pipeline_step_db_query_cached_test.go @@ -293,7 +293,7 @@ func TestDBQueryCachedStep_NoRows(t *testing.T) { t.Fatalf("execute error: %v", err) } - // No rows means empty output (no id/name keys), cache_hit=false + // No rows in legacy mode means empty output (no column keys), cache_hit=false if result.Output["cache_hit"] != false { t.Errorf("expected cache_hit=false, got %v", result.Output["cache_hit"]) } @@ -407,6 +407,218 @@ func TestDBQueryCachedStep_NegativeTTLRejected(t *testing.T) { } } +// TestDBQueryCachedStep_InvalidModeRejected verifies that an unknown mode is rejected. +func TestDBQueryCachedStep_InvalidModeRejected(t *testing.T) { + factory := NewDBQueryCachedStepFactory() + _, err := factory("bad", map[string]any{ + "database": "db", + "query": "SELECT 1", + "cache_key": "k", + "mode": "bulk", + }, nil) + if err == nil { + t.Fatal("expected error for invalid mode") + } +} + +// TestDBQueryCachedStep_ListMode verifies that mode: list returns rows/count format. +func TestDBQueryCachedStep_ListMode(t *testing.T) { + db := setupTestDB(t) + app := mockAppWithDB("test-db", db) + + factory := NewDBQueryCachedStepFactory() + step, err := factory("list-companies", map[string]any{ + "database": "test-db", + "query": "SELECT id, name, slug FROM companies WHERE parent_id IS NULL ORDER BY name", + "mode": "list", + "cache_key": "companies:list", + "cache_ttl": "5m", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + if result.Output["cache_hit"] != false { + t.Errorf("expected cache_hit=false on first call, got %v", result.Output["cache_hit"]) + } + rows, ok := result.Output["rows"].([]map[string]any) + if !ok { + t.Fatal("expected rows in output for list mode") + } + count, ok := result.Output["count"].(int) + if !ok { + t.Fatal("expected count in output for list mode") + } + if count != 2 { + t.Errorf("expected count=2, got %d", count) + } + if len(rows) != 2 { + t.Errorf("expected 2 rows, got %d", len(rows)) + } + if rows[0]["name"] != "Acme Corp" { + t.Errorf("expected first row name='Acme Corp', got %v", rows[0]["name"]) + } +} + +// TestDBQueryCachedStep_ListModeCacheHit verifies that list mode results are cached and returned correctly. +func TestDBQueryCachedStep_ListModeCacheHit(t *testing.T) { + db := setupTestDB(t) + app := mockAppWithDB("test-db", db) + + factory := NewDBQueryCachedStepFactory() + step, err := factory("list-companies", map[string]any{ + "database": "test-db", + "query": "SELECT id, name FROM companies WHERE parent_id IS NULL ORDER BY name", + "mode": "list", + "cache_key": "companies:list", + "cache_ttl": "5m", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + + // First call — cache miss + first, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("first execute error: %v", err) + } + if first.Output["cache_hit"] != false { + t.Errorf("expected cache_hit=false on first call") + } + + // Second call — cache hit + second, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("second execute error: %v", err) + } + if second.Output["cache_hit"] != true { + t.Errorf("expected cache_hit=true on second call, got %v", second.Output["cache_hit"]) + } + rows, ok := second.Output["rows"].([]map[string]any) + if !ok { + t.Fatal("expected rows in cached output") + } + if len(rows) != 2 { + t.Errorf("expected 2 rows from cache, got %d", len(rows)) + } +} + +// TestDBQueryCachedStep_ListModeEmpty verifies that list mode returns an empty rows slice when no rows match. +func TestDBQueryCachedStep_ListModeEmpty(t *testing.T) { + db := setupTestDB(t) + app := mockAppWithDB("test-db", db) + + factory := NewDBQueryCachedStepFactory() + step, err := factory("list-empty", map[string]any{ + "database": "test-db", + "query": "SELECT id, name FROM companies WHERE id = ?", + "params": []any{"nonexistent"}, + "mode": "list", + "cache_key": "companies:empty", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + rows, ok := result.Output["rows"].([]map[string]any) + if !ok { + t.Fatal("expected rows in output for list mode even when empty") + } + if len(rows) != 0 { + t.Errorf("expected 0 rows, got %d", len(rows)) + } + count, _ := result.Output["count"].(int) + if count != 0 { + t.Errorf("expected count=0, got %d", count) + } +} + +// TestDBQueryCachedStep_SingleModeFound verifies that mode: single returns row/found format when a row is found. +func TestDBQueryCachedStep_SingleModeFound(t *testing.T) { + db := setupTestDB(t) + app := mockAppWithDB("test-db", db) + + factory := NewDBQueryCachedStepFactory() + step, err := factory("get-company", map[string]any{ + "database": "test-db", + "query": "SELECT id, name FROM companies WHERE id = ?", + "params": []any{"c1"}, + "mode": "single", + "cache_key": "company:c1", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + found, _ := result.Output["found"].(bool) + if !found { + t.Error("expected found=true") + } + row, ok := result.Output["row"].(map[string]any) + if !ok { + t.Fatal("expected row in output") + } + if row["name"] != "Acme Corp" { + t.Errorf("expected name='Acme Corp', got %v", row["name"]) + } +} + +// TestDBQueryCachedStep_SingleModeNotFound verifies that mode: single returns row={}/found=false when no row matches. +func TestDBQueryCachedStep_SingleModeNotFound(t *testing.T) { + db := setupTestDB(t) + app := mockAppWithDB("test-db", db) + + factory := NewDBQueryCachedStepFactory() + step, err := factory("get-missing", map[string]any{ + "database": "test-db", + "query": "SELECT id, name FROM companies WHERE id = ?", + "params": []any{"nonexistent"}, + "mode": "single", + "cache_key": "company:nonexistent", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + found, _ := result.Output["found"].(bool) + if found { + t.Error("expected found=false") + } + row, ok := result.Output["row"].(map[string]any) + if !ok { + t.Fatal("expected row in output even when not found") + } + if len(row) != 0 { + t.Errorf("expected empty row map, got %v", row) + } +} + func TestDBQueryCachedStep_DynamicTableName(t *testing.T) { db := setupTestDB(t) _, err := db.Exec(`CREATE TABLE companies_beta (id TEXT PRIMARY KEY, name TEXT NOT NULL)`) diff --git a/schema/module_schema.go b/schema/module_schema.go index ba910c31..aec88db2 100644 --- a/schema/module_schema.go +++ b/schema/module_schema.go @@ -1035,15 +1035,16 @@ func (r *ModuleSchemaRegistry) registerBuiltins() { Category: "pipeline", Description: "Executes a parameterized SQL SELECT query and caches the result in-process with TTL. On subsequent calls the cached value is returned until the TTL expires.", Inputs: []ServiceIODef{{Name: "context", Type: "PipelineContext", Description: "Pipeline context for template parameter and cache key resolution"}}, - Outputs: []ServiceIODef{{Name: "result", Type: "StepResult", Description: "Query result fields as top-level keys plus cache_hit boolean"}}, + Outputs: []ServiceIODef{{Name: "result", Type: "StepResult", Description: "Query results as rows/count (list mode) or row/found (single mode), plus cache_hit boolean"}}, ConfigFields: []ConfigFieldDef{ {Key: "database", Label: "Database", Type: FieldTypeString, Required: true, Description: "Name of the database service (must implement DBProvider)", Placeholder: "db", InheritFrom: "dependency.name"}, {Key: "query", Label: "SQL Query", Type: FieldTypeSQL, Required: true, Description: "Parameterized SQL SELECT query using $N placeholders (e.g. $1, $2); automatically converted to ? for SQLite drivers. Template expressions are forbidden unless allow_dynamic_sql is true.", Placeholder: "SELECT backend_url, settings FROM routing_config WHERE tenant_id = $1 LIMIT 1"}, {Key: "params", Label: "Parameters", Type: FieldTypeArray, ArrayItemType: "string", Description: "Template-resolved parameter values for query placeholders"}, + {Key: "mode", Label: "Mode", Type: FieldTypeSelect, Options: []string{"single", "list"}, DefaultValue: "single", Description: "Result mode: 'single' returns row/found, 'list' returns rows/count"}, {Key: "cache_key", Label: "Cache Key", Type: FieldTypeString, Required: true, Description: "Template-resolved key used to store/retrieve the cached result", Placeholder: "tenant_config:{{.steps.parse.headers.X-Tenant-Id}}"}, {Key: "cache_ttl", Label: "Cache TTL", Type: FieldTypeString, DefaultValue: "5m", Description: "Duration string for how long to cache the result (e.g. '5m', '30s', '1h')", Placeholder: "5m"}, - {Key: "scan_fields", Label: "Scan Fields", Type: FieldTypeArray, ArrayItemType: "string", Description: "Column names to include in the output map (omit to include all columns)"}, - {Key: "allow_dynamic_sql", Label: "Allow Dynamic SQL", Type: FieldTypeBool, DefaultValue: "false", Description: "When true, template expressions in 'query' are resolved at runtime. Each resolved value must contain only letters, digits, underscores and hyphens to prevent SQL injection."}, + {Key: "scan_fields", Label: "Scan Fields", Type: FieldTypeArray, ArrayItemType: "string", Description: "Column names to include in the output (omit to include all columns)"}, + {Key: "allow_dynamic_sql", Label: "Allow Dynamic SQL", Type: FieldTypeBool, DefaultValue: "false", Description: "When true, template expressions in 'query' are resolved at runtime. Each resolved value must contain only letters, digits, and underscores to prevent SQL injection."}, }, })