Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DOCUMENTATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ flowchart TD
| `step.delegate` | Delegates to a named service | pipelinesteps |
| `step.request_parse` | Extracts path params, query params, and request body from HTTP requests | pipelinesteps |
| `step.db_query` | Executes parameterized SQL SELECT queries against a named database | pipelinesteps |
| `step.db_exec` | Executes parameterized SQL INSERT/UPDATE/DELETE against a named database | pipelinesteps |
| `step.db_exec` | Executes parameterized SQL INSERT/UPDATE/DELETE against a named database. Supports `returning: true` with `mode: single` or `mode: list` to capture rows from a `RETURNING` clause | pipelinesteps |
| `step.db_query_cached` | Executes a cached SQL SELECT query | pipelinesteps |
| `step.db_create_partition` | Creates a time-based table partition | pipelinesteps |
| `step.db_sync_partitions` | Ensures future partitions exist for a partitioned table | pipelinesteps |
Expand Down
46 changes: 44 additions & 2 deletions module/pipeline_step_db_exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ type DBExecStep struct {
ignoreError bool
tenantKey string // dot-path to resolve tenant value for automatic scoping
allowDynamicSQL bool
returning bool // when true, uses Query() and returns rows (for RETURNING clause)
mode string // "list" or "single" — used only when returning is true
app modular.Application
tmpl *TemplateEngine
}
Expand Down Expand Up @@ -54,6 +56,17 @@ func NewDBExecStepFactory() StepFactory {

ignoreError, _ := config["ignore_error"].(bool)
tenantKey, _ := config["tenantKey"].(string)
returning, _ := config["returning"].(bool)

mode, _ := config["mode"].(string)
if returning {
if mode == "" {
mode = "list"
}
if mode != "list" && mode != "single" {
return nil, fmt.Errorf("db_exec step %q: mode must be 'list' or 'single', got %q", name, mode)
}
}

return &DBExecStep{
name: name,
Expand All @@ -63,6 +76,8 @@ func NewDBExecStepFactory() StepFactory {
ignoreError: ignoreError,
tenantKey: tenantKey,
allowDynamicSQL: allowDynamicSQL,
returning: returning,
mode: mode,
app: app,
tmpl: NewTemplateEngine(),
}, nil
Expand All @@ -71,7 +86,7 @@ func NewDBExecStepFactory() StepFactory {

func (s *DBExecStep) Name() string { return s.name }

func (s *DBExecStep) Execute(_ context.Context, pc *PipelineContext) (*StepResult, error) {
func (s *DBExecStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) {
// Resolve template expressions in the query early (before any DB access) when
// dynamic SQL is enabled. This validates resolved identifiers against an
// allowlist before any database interaction.
Expand Down Expand Up @@ -148,8 +163,35 @@ func (s *DBExecStep) Execute(_ context.Context, pc *PipelineContext) (*StepResul
// engine converts to ? for SQLite automatically.
query = normalizePlaceholders(query, driver)

// When returning is true, use QueryContext() so that RETURNING clause rows are available.
if s.returning {
rows, err := db.QueryContext(ctx, query, resolvedParams...)
if err != nil {
if s.ignoreError {
output := map[string]any{"ignored_error": err.Error()}
if s.mode == "single" {
output["row"] = map[string]any{}
output["found"] = false
} else {
output["rows"] = []map[string]any{}
output["count"] = 0
}
return &StepResult{Output: output}, nil
}
return nil, fmt.Errorf("db_exec step %q: query failed: %w", s.name, err)
}
defer rows.Close()

results, err := scanSQLRows(rows)
if err != nil {
return nil, fmt.Errorf("db_exec step %q: %w", s.name, err)
}

return &StepResult{Output: formatQueryOutput(results, s.mode)}, nil
}

// Execute statement
result, err := db.Exec(query, resolvedParams...)
result, err := db.ExecContext(ctx, query, resolvedParams...)
if err != nil {
if s.ignoreError {
return &StepResult{Output: map[string]any{
Expand Down
123 changes: 123 additions & 0 deletions module/pipeline_step_db_exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,126 @@ func TestDBExecStep_PostgresPlaceholdersOnSQLite(t *testing.T) {
t.Errorf("expected name='PostgresStyleWidget', got %q", name)
}
}

// TestDBExecStep_Returning_SingleMode verifies that returning:true with mode:single
// uses Query() and returns the first row via RETURNING clause.
func TestDBExecStep_Returning_SingleMode(t *testing.T) {
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("open db: %v", err)
}
defer db.Close()

_, err = db.Exec(`CREATE TABLE items (id TEXT PRIMARY KEY, name TEXT NOT NULL, created_at TEXT NOT NULL DEFAULT 'now')`)
if err != nil {
t.Fatalf("create table: %v", err)
}

app := mockAppWithDB("test-db", db)
factory := NewDBExecStepFactory()
step, err := factory("insert-returning", map[string]any{
"database": "test-db",
"query": "INSERT INTO items (id, name) VALUES (?, ?) RETURNING id, name",
"params": []any{"r1", "ReturnedItem"},
"returning": true,
"mode": "single",
}, app)
if err != nil {
t.Fatalf("factory error: %v", err)
}

pc := NewPipelineContext(nil, nil)
result, err := step.Execute(context.Background(), pc)
if err != nil {
t.Fatalf("execute error: %v", err)
}

found, _ := result.Output["found"].(bool)
if !found {
t.Error("expected found=true")
}
row, ok := result.Output["row"].(map[string]any)
if !ok {
t.Fatal("expected 'row' in output")
}
if row["id"] != "r1" {
t.Errorf("expected row.id='r1', got %v", row["id"])
}
if row["name"] != "ReturnedItem" {
t.Errorf("expected row.name='ReturnedItem', got %v", row["name"])
}
}

// TestDBExecStep_Returning_ListMode verifies that returning:true with mode:list (default)
// returns all affected rows via RETURNING clause.
func TestDBExecStep_Returning_ListMode(t *testing.T) {
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("open db: %v", err)
}
defer db.Close()

_, err = db.Exec(`
CREATE TABLE items (id TEXT PRIMARY KEY, name TEXT NOT NULL, status TEXT NOT NULL DEFAULT 'active');
INSERT INTO items (id, name) VALUES ('i1', 'Alpha');
INSERT INTO items (id, name) VALUES ('i2', 'Beta');
INSERT INTO items (id, name) VALUES ('i3', 'Gamma');
`)
if err != nil {
t.Fatalf("setup: %v", err)
}

app := mockAppWithDB("test-db", db)
factory := NewDBExecStepFactory()
step, err := factory("update-returning", map[string]any{
"database": "test-db",
"query": "UPDATE items SET status = ? RETURNING id, name, status",
"params": []any{"archived"},
"returning": true,
}, app)
if err != nil {
t.Fatalf("factory error: %v", err)
}

pc := NewPipelineContext(nil, nil)
result, err := step.Execute(context.Background(), pc)
if err != nil {
t.Fatalf("execute error: %v", err)
}

count, _ := result.Output["count"].(int)
if count != 3 {
t.Errorf("expected count=3, got %v", result.Output["count"])
}
rows, ok := result.Output["rows"].([]map[string]any)
if !ok {
t.Fatal("expected 'rows' in output")
}
if len(rows) != 3 {
t.Errorf("expected 3 rows, got %d", len(rows))
}
// Verify status was updated
for _, row := range rows {
if row["status"] != "archived" {
t.Errorf("expected status='archived', got %v", row["status"])
}
}
}

// TestDBExecStep_Returning_InvalidMode verifies that an invalid mode is rejected at factory time.
func TestDBExecStep_Returning_InvalidMode(t *testing.T) {
factory := NewDBExecStepFactory()
_, err := factory("bad-mode", map[string]any{
"database": "test-db",
"query": "INSERT INTO x VALUES (?) RETURNING id",
"params": []any{"1"},
"returning": true,
"mode": "invalid",
}, nil)
if err == nil {
t.Fatal("expected error for invalid mode")
}
if !strings.Contains(err.Error(), "mode must be") {
t.Errorf("expected 'mode must be' in error, got: %v", err)
}
}
73 changes: 73 additions & 0 deletions module/pipeline_step_db_helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package module

import (
"database/sql"
"fmt"
)

// scanSQLRows iterates over rows and returns a slice of column→value maps.
// []byte values are decoded via parseJSONBytesOrString, which transparently
// handles PostgreSQL json/jsonb columns (returned as raw JSON bytes by pgx)
// and falls back to string conversion for binary data (e.g. bytea). Callers
// are responsible for closing rows after this function returns.
func scanSQLRows(rows *sql.Rows) ([]map[string]any, error) {
columns, err := rows.Columns()
if err != nil {
return nil, fmt.Errorf("failed to get columns: %w", err)
}

var results []map[string]any
for rows.Next() {
values := make([]any, len(columns))
valuePtrs := make([]any, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}

if err := rows.Scan(valuePtrs...); err != nil {
return nil, fmt.Errorf("scan failed: %w", err)
}

row := make(map[string]any, len(columns))
for i, col := range columns {
val := values[i]
// Convert []byte: try JSON parse first (handles PostgreSQL json/jsonb
// column types returned by the pgx driver as raw JSON bytes), then
// fall back to string conversion for non-JSON byte data (e.g. bytea).
if b, ok := val.([]byte); ok {
row[col] = parseJSONBytesOrString(b)
} else {
row[col] = val
}
}
results = append(results, row)
}

if err := rows.Err(); err != nil {
return nil, fmt.Errorf("row iteration error: %w", err)
}

return results, nil
}

// formatQueryOutput builds the standard step output map for query results.
// mode "single" returns {row, found}; any other mode returns {rows, count}.
func formatQueryOutput(results []map[string]any, mode string) map[string]any {
output := make(map[string]any)
if mode == "single" {
if len(results) > 0 {
output["row"] = results[0]
output["found"] = true
} else {
output["row"] = map[string]any{}
output["found"] = false
}
} else {
if results == nil {
results = []map[string]any{}
}
output["rows"] = results
output["count"] = len(results)
}
return output
}
58 changes: 5 additions & 53 deletions module/pipeline_step_db_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func NewDBQueryStepFactory() StepFactory {

func (s *DBQueryStep) Name() string { return s.name }

func (s *DBQueryStep) Execute(_ context.Context, pc *PipelineContext) (*StepResult, error) {
func (s *DBQueryStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) {
// Resolve template expressions in the query early (before any DB access) when
// dynamic SQL is enabled. This validates resolved identifiers against an
// allowlist before any database interaction.
Expand Down Expand Up @@ -167,64 +167,16 @@ func (s *DBQueryStep) Execute(_ context.Context, pc *PipelineContext) (*StepResu
query = normalizePlaceholders(query, driver)

// Execute query
rows, err := db.Query(query, resolvedParams...)
rows, err := db.QueryContext(ctx, query, resolvedParams...)
if err != nil {
return nil, fmt.Errorf("db_query step %q: query failed: %w", s.name, err)
}
defer rows.Close()

columns, err := rows.Columns()
results, err := scanSQLRows(rows)
if err != nil {
return nil, fmt.Errorf("db_query step %q: failed to get columns: %w", s.name, err)
return nil, fmt.Errorf("db_query step %q: %w", s.name, err)
}

var results []map[string]any
for rows.Next() {
values := make([]any, len(columns))
valuePtrs := make([]any, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}

if err := rows.Scan(valuePtrs...); err != nil {
return nil, fmt.Errorf("db_query step %q: scan failed: %w", s.name, err)
}

row := make(map[string]any, len(columns))
for i, col := range columns {
val := values[i]
// Convert []byte: try JSON parse first (handles PostgreSQL json/jsonb
// column types returned by the pgx driver as raw JSON bytes), then
// fall back to string conversion for non-JSON byte data (e.g. bytea).
if b, ok := val.([]byte); ok {
row[col] = parseJSONBytesOrString(b)
} else {
row[col] = val
}
}
results = append(results, row)
}

if err := rows.Err(); err != nil {
return nil, fmt.Errorf("db_query step %q: row iteration error: %w", s.name, err)
}

output := make(map[string]any)
if s.mode == "single" {
if len(results) > 0 {
output["row"] = results[0]
output["found"] = true
} else {
output["row"] = map[string]any{}
output["found"] = false
}
} else {
if results == nil {
results = []map[string]any{}
}
output["rows"] = results
output["count"] = len(results)
}

return &StepResult{Output: output}, nil
return &StepResult{Output: formatQueryOutput(results, s.mode)}, nil
}
Loading