diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 6cbc8fb9..b19802f9 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -147,7 +147,7 @@ flowchart TD | `step.delegate` | Delegates to a named service | pipelinesteps | | `step.request_parse` | Extracts path params, query params, and request body from HTTP requests | pipelinesteps | | `step.db_query` | Executes parameterized SQL SELECT queries against a named database | pipelinesteps | -| `step.db_exec` | Executes parameterized SQL INSERT/UPDATE/DELETE against a named database | pipelinesteps | +| `step.db_exec` | Executes parameterized SQL INSERT/UPDATE/DELETE against a named database. Supports `returning: true` with `mode: single` or `mode: list` to capture rows from a `RETURNING` clause | pipelinesteps | | `step.db_query_cached` | Executes a cached SQL SELECT query | pipelinesteps | | `step.db_create_partition` | Creates a time-based table partition | pipelinesteps | | `step.db_sync_partitions` | Ensures future partitions exist for a partitioned table | pipelinesteps | diff --git a/module/pipeline_step_db_exec.go b/module/pipeline_step_db_exec.go index 5bee9a90..c44b043b 100644 --- a/module/pipeline_step_db_exec.go +++ b/module/pipeline_step_db_exec.go @@ -17,6 +17,8 @@ type DBExecStep struct { ignoreError bool tenantKey string // dot-path to resolve tenant value for automatic scoping allowDynamicSQL bool + returning bool // when true, uses Query() and returns rows (for RETURNING clause) + mode string // "list" or "single" — used only when returning is true app modular.Application tmpl *TemplateEngine } @@ -54,6 +56,17 @@ func NewDBExecStepFactory() StepFactory { ignoreError, _ := config["ignore_error"].(bool) tenantKey, _ := config["tenantKey"].(string) + returning, _ := config["returning"].(bool) + + mode, _ := config["mode"].(string) + if returning { + if mode == "" { + mode = "list" + } + if mode != "list" && mode != "single" { + return nil, fmt.Errorf("db_exec step %q: mode must be 'list' or 'single', got %q", name, mode) + } + } return &DBExecStep{ name: name, @@ -63,6 +76,8 @@ func NewDBExecStepFactory() StepFactory { ignoreError: ignoreError, tenantKey: tenantKey, allowDynamicSQL: allowDynamicSQL, + returning: returning, + mode: mode, app: app, tmpl: NewTemplateEngine(), }, nil @@ -71,7 +86,7 @@ func NewDBExecStepFactory() StepFactory { func (s *DBExecStep) Name() string { return s.name } -func (s *DBExecStep) Execute(_ context.Context, pc *PipelineContext) (*StepResult, error) { +func (s *DBExecStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) { // Resolve template expressions in the query early (before any DB access) when // dynamic SQL is enabled. This validates resolved identifiers against an // allowlist before any database interaction. @@ -148,8 +163,35 @@ func (s *DBExecStep) Execute(_ context.Context, pc *PipelineContext) (*StepResul // engine converts to ? for SQLite automatically. query = normalizePlaceholders(query, driver) + // When returning is true, use QueryContext() so that RETURNING clause rows are available. + if s.returning { + rows, err := db.QueryContext(ctx, query, resolvedParams...) + if err != nil { + if s.ignoreError { + output := map[string]any{"ignored_error": err.Error()} + if s.mode == "single" { + output["row"] = map[string]any{} + output["found"] = false + } else { + output["rows"] = []map[string]any{} + output["count"] = 0 + } + return &StepResult{Output: output}, nil + } + return nil, fmt.Errorf("db_exec step %q: query failed: %w", s.name, err) + } + defer rows.Close() + + results, err := scanSQLRows(rows) + if err != nil { + return nil, fmt.Errorf("db_exec step %q: %w", s.name, err) + } + + return &StepResult{Output: formatQueryOutput(results, s.mode)}, nil + } + // Execute statement - result, err := db.Exec(query, resolvedParams...) + result, err := db.ExecContext(ctx, query, resolvedParams...) if err != nil { if s.ignoreError { return &StepResult{Output: map[string]any{ diff --git a/module/pipeline_step_db_exec_test.go b/module/pipeline_step_db_exec_test.go index 77dc5483..18542bae 100644 --- a/module/pipeline_step_db_exec_test.go +++ b/module/pipeline_step_db_exec_test.go @@ -314,3 +314,126 @@ func TestDBExecStep_PostgresPlaceholdersOnSQLite(t *testing.T) { t.Errorf("expected name='PostgresStyleWidget', got %q", name) } } + +// TestDBExecStep_Returning_SingleMode verifies that returning:true with mode:single +// uses Query() and returns the first row via RETURNING clause. +func TestDBExecStep_Returning_SingleMode(t *testing.T) { + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open db: %v", err) + } + defer db.Close() + + _, err = db.Exec(`CREATE TABLE items (id TEXT PRIMARY KEY, name TEXT NOT NULL, created_at TEXT NOT NULL DEFAULT 'now')`) + if err != nil { + t.Fatalf("create table: %v", err) + } + + app := mockAppWithDB("test-db", db) + factory := NewDBExecStepFactory() + step, err := factory("insert-returning", map[string]any{ + "database": "test-db", + "query": "INSERT INTO items (id, name) VALUES (?, ?) RETURNING id, name", + "params": []any{"r1", "ReturnedItem"}, + "returning": true, + "mode": "single", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + found, _ := result.Output["found"].(bool) + if !found { + t.Error("expected found=true") + } + row, ok := result.Output["row"].(map[string]any) + if !ok { + t.Fatal("expected 'row' in output") + } + if row["id"] != "r1" { + t.Errorf("expected row.id='r1', got %v", row["id"]) + } + if row["name"] != "ReturnedItem" { + t.Errorf("expected row.name='ReturnedItem', got %v", row["name"]) + } +} + +// TestDBExecStep_Returning_ListMode verifies that returning:true with mode:list (default) +// returns all affected rows via RETURNING clause. +func TestDBExecStep_Returning_ListMode(t *testing.T) { + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open db: %v", err) + } + defer db.Close() + + _, err = db.Exec(` + CREATE TABLE items (id TEXT PRIMARY KEY, name TEXT NOT NULL, status TEXT NOT NULL DEFAULT 'active'); + INSERT INTO items (id, name) VALUES ('i1', 'Alpha'); + INSERT INTO items (id, name) VALUES ('i2', 'Beta'); + INSERT INTO items (id, name) VALUES ('i3', 'Gamma'); + `) + if err != nil { + t.Fatalf("setup: %v", err) + } + + app := mockAppWithDB("test-db", db) + factory := NewDBExecStepFactory() + step, err := factory("update-returning", map[string]any{ + "database": "test-db", + "query": "UPDATE items SET status = ? RETURNING id, name, status", + "params": []any{"archived"}, + "returning": true, + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + count, _ := result.Output["count"].(int) + if count != 3 { + t.Errorf("expected count=3, got %v", result.Output["count"]) + } + rows, ok := result.Output["rows"].([]map[string]any) + if !ok { + t.Fatal("expected 'rows' in output") + } + if len(rows) != 3 { + t.Errorf("expected 3 rows, got %d", len(rows)) + } + // Verify status was updated + for _, row := range rows { + if row["status"] != "archived" { + t.Errorf("expected status='archived', got %v", row["status"]) + } + } +} + +// TestDBExecStep_Returning_InvalidMode verifies that an invalid mode is rejected at factory time. +func TestDBExecStep_Returning_InvalidMode(t *testing.T) { + factory := NewDBExecStepFactory() + _, err := factory("bad-mode", map[string]any{ + "database": "test-db", + "query": "INSERT INTO x VALUES (?) RETURNING id", + "params": []any{"1"}, + "returning": true, + "mode": "invalid", + }, nil) + if err == nil { + t.Fatal("expected error for invalid mode") + } + if !strings.Contains(err.Error(), "mode must be") { + t.Errorf("expected 'mode must be' in error, got: %v", err) + } +} diff --git a/module/pipeline_step_db_helpers.go b/module/pipeline_step_db_helpers.go new file mode 100644 index 00000000..3c929d93 --- /dev/null +++ b/module/pipeline_step_db_helpers.go @@ -0,0 +1,73 @@ +package module + +import ( + "database/sql" + "fmt" +) + +// scanSQLRows iterates over rows and returns a slice of column→value maps. +// []byte values are decoded via parseJSONBytesOrString, which transparently +// handles PostgreSQL json/jsonb columns (returned as raw JSON bytes by pgx) +// and falls back to string conversion for binary data (e.g. bytea). Callers +// are responsible for closing rows after this function returns. +func scanSQLRows(rows *sql.Rows) ([]map[string]any, error) { + columns, err := rows.Columns() + if err != nil { + return nil, fmt.Errorf("failed to get columns: %w", err) + } + + var results []map[string]any + for rows.Next() { + values := make([]any, len(columns)) + valuePtrs := make([]any, len(columns)) + for i := range values { + valuePtrs[i] = &values[i] + } + + if err := rows.Scan(valuePtrs...); err != nil { + return nil, fmt.Errorf("scan failed: %w", err) + } + + row := make(map[string]any, len(columns)) + for i, col := range columns { + val := values[i] + // Convert []byte: try JSON parse first (handles PostgreSQL json/jsonb + // column types returned by the pgx driver as raw JSON bytes), then + // fall back to string conversion for non-JSON byte data (e.g. bytea). + if b, ok := val.([]byte); ok { + row[col] = parseJSONBytesOrString(b) + } else { + row[col] = val + } + } + results = append(results, row) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("row iteration error: %w", err) + } + + return results, nil +} + +// formatQueryOutput builds the standard step output map for query results. +// mode "single" returns {row, found}; any other mode returns {rows, count}. +func formatQueryOutput(results []map[string]any, mode string) map[string]any { + output := make(map[string]any) + if mode == "single" { + if len(results) > 0 { + output["row"] = results[0] + output["found"] = true + } else { + output["row"] = map[string]any{} + output["found"] = false + } + } else { + if results == nil { + results = []map[string]any{} + } + output["rows"] = results + output["count"] = len(results) + } + return output +} diff --git a/module/pipeline_step_db_query.go b/module/pipeline_step_db_query.go index cd94ff64..859f54a3 100644 --- a/module/pipeline_step_db_query.go +++ b/module/pipeline_step_db_query.go @@ -93,7 +93,7 @@ func NewDBQueryStepFactory() StepFactory { func (s *DBQueryStep) Name() string { return s.name } -func (s *DBQueryStep) Execute(_ context.Context, pc *PipelineContext) (*StepResult, error) { +func (s *DBQueryStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) { // Resolve template expressions in the query early (before any DB access) when // dynamic SQL is enabled. This validates resolved identifiers against an // allowlist before any database interaction. @@ -167,64 +167,16 @@ func (s *DBQueryStep) Execute(_ context.Context, pc *PipelineContext) (*StepResu query = normalizePlaceholders(query, driver) // Execute query - rows, err := db.Query(query, resolvedParams...) + rows, err := db.QueryContext(ctx, query, resolvedParams...) if err != nil { return nil, fmt.Errorf("db_query step %q: query failed: %w", s.name, err) } defer rows.Close() - columns, err := rows.Columns() + results, err := scanSQLRows(rows) if err != nil { - return nil, fmt.Errorf("db_query step %q: failed to get columns: %w", s.name, err) + return nil, fmt.Errorf("db_query step %q: %w", s.name, err) } - var results []map[string]any - for rows.Next() { - values := make([]any, len(columns)) - valuePtrs := make([]any, len(columns)) - for i := range values { - valuePtrs[i] = &values[i] - } - - if err := rows.Scan(valuePtrs...); err != nil { - return nil, fmt.Errorf("db_query step %q: scan failed: %w", s.name, err) - } - - row := make(map[string]any, len(columns)) - for i, col := range columns { - val := values[i] - // Convert []byte: try JSON parse first (handles PostgreSQL json/jsonb - // column types returned by the pgx driver as raw JSON bytes), then - // fall back to string conversion for non-JSON byte data (e.g. bytea). - if b, ok := val.([]byte); ok { - row[col] = parseJSONBytesOrString(b) - } else { - row[col] = val - } - } - results = append(results, row) - } - - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("db_query step %q: row iteration error: %w", s.name, err) - } - - output := make(map[string]any) - if s.mode == "single" { - if len(results) > 0 { - output["row"] = results[0] - output["found"] = true - } else { - output["row"] = map[string]any{} - output["found"] = false - } - } else { - if results == nil { - results = []map[string]any{} - } - output["rows"] = results - output["count"] = len(results) - } - - return &StepResult{Output: output}, nil + return &StepResult{Output: formatQueryOutput(results, s.mode)}, nil }