Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions module/pipeline_step_db_dynamic.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (

// validateSQLIdentifier checks that s is safe to interpolate directly into SQL as an
// identifier (e.g. a table name). Only ASCII letters (A-Z, a-z), ASCII digits (0-9),
// underscores (_) and hyphens (-) are permitted. This strict allowlist prevents SQL
// injection when dynamic values are embedded in queries via allow_dynamic_sql.
// and underscores (_) are permitted. This strict allowlist prevents SQL injection
// when dynamic values are embedded in queries via allow_dynamic_sql.
func validateSQLIdentifier(s string) error {
if s == "" {
return fmt.Errorf("dynamic SQL identifier must not be empty")
Expand All @@ -17,8 +17,8 @@ func validateSQLIdentifier(s string) error {
if (c < 'a' || c > 'z') &&
(c < 'A' || c > 'Z') &&
(c < '0' || c > '9') &&
c != '_' && c != '-' {
return fmt.Errorf("dynamic SQL identifier %q contains unsafe character %q (only ASCII letters, digits, underscores and hyphens are allowed)", s, string(c))
c != '_' {
return fmt.Errorf("dynamic SQL identifier %q contains unsafe character %q (only ASCII letters, digits, and underscores are allowed)", s, string(c))
}
}
return nil
Expand Down
74 changes: 63 additions & 11 deletions module/pipeline_step_db_query_cached.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type DBQueryCachedStep struct {
database string
query string
params []string
mode string // "single" or "list"
cacheKey string
cacheTTL time.Duration
scanFields []string
Expand Down Expand Up @@ -83,6 +84,15 @@ func NewDBQueryCachedStepFactory() StepFactory {
}
}

mode, _ := config["mode"].(string)
mode = strings.TrimSpace(mode)
// Backwards compatibility: if mode is omitted or empty, leave it blank.
// Execution logic treats a blank mode as the legacy single-row flat
// output, while explicit "single" uses the new row/found envelope.
if mode != "" && mode != "list" && mode != "single" {
return nil, fmt.Errorf("db_query_cached step %q: mode must be 'list' or 'single', got %q", name, mode)
}

var scanFields []string
if sf, ok := config["scan_fields"]; ok {
if list, ok := sf.([]any); ok {
Expand All @@ -99,6 +109,7 @@ func NewDBQueryCachedStepFactory() StepFactory {
database: database,
query: query,
params: params,
mode: mode,
cacheKey: cacheKey,
cacheTTL: cacheTTL,
scanFields: scanFields,
Expand Down Expand Up @@ -145,7 +156,7 @@ func (s *DBQueryCachedStep) Execute(ctx context.Context, pc *PipelineContext) (*
s.mu.RUnlock()

if found && time.Now().Before(entry.expiresAt) {
output := copyMap(entry.value)
output := deepCopyMap(entry.value)
output["cache_hit"] = true
return &StepResult{Output: output}, nil
}
Expand All @@ -155,7 +166,7 @@ func (s *DBQueryCachedStep) Execute(ctx context.Context, pc *PipelineContext) (*
entry, found = s.cache[key]
if found && time.Now().Before(entry.expiresAt) {
// Another goroutine populated the cache while we were waiting for the lock
output := copyMap(entry.value)
output := deepCopyMap(entry.value)
s.mu.Unlock()
output["cache_hit"] = true
return &StepResult{Output: output}, nil
Expand All @@ -175,7 +186,7 @@ func (s *DBQueryCachedStep) Execute(ctx context.Context, pc *PipelineContext) (*
// Store in cache (write lock)
s.mu.Lock()
s.cache[key] = dbQueryCacheEntry{
value: copyMap(result),
value: deepCopyMap(result),
expiresAt: time.Now().Add(s.cacheTTL),
}
s.mu.Unlock()
Expand Down Expand Up @@ -236,7 +247,7 @@ func (s *DBQueryCachedStep) runQuery(ctx context.Context, pc *PipelineContext, q
fieldSet[f] = true
}

output := make(map[string]any)
var results []map[string]any
for rows.Next() {
values := make([]any, len(columns))
valuePtrs := make([]any, len(columns))
Expand All @@ -248,33 +259,74 @@ func (s *DBQueryCachedStep) runQuery(ctx context.Context, pc *PipelineContext, q
return nil, fmt.Errorf("db_query_cached step %q: scan failed: %w", s.name, err)
}

row := make(map[string]any, len(columns))
for i, col := range columns {
if len(fieldSet) > 0 && !fieldSet[col] {
continue
}
val := values[i]
if b, ok := val.([]byte); ok {
output[col] = string(b)
row[col] = string(b)
} else {
output[col] = val
row[col] = val
}
}
// Only take the first row
break
results = append(results, row)
if s.mode != "list" {
// Only take the first row for single and legacy modes
break
}
}

if err := rows.Err(); err != nil {
return nil, fmt.Errorf("db_query_cached step %q: row iteration error: %w", s.name, err)
}

output := make(map[string]any)
switch s.mode {
case "list":
if results == nil {
results = []map[string]any{}
}
output["rows"] = results
output["count"] = len(results)
case "single":
if len(results) > 0 {
output["row"] = results[0]
output["found"] = true
} else {
output["row"] = map[string]any{}
output["found"] = false
}
default: // "" — legacy flat column map (backward compatible)
if len(results) > 0 {
for k, v := range results[0] {
output[k] = v
}
}
}

return output, nil
}

// copyMap creates a shallow copy of a map.
func copyMap(m map[string]any) map[string]any {
// deepCopyMap returns a copy of m that shares no nested containers with it.
// Nested map[string]any values, []map[string]any slices and []any slices are
// copied recursively so that a caller mutating the returned map (or anything
// reachable through those containers) cannot alter the source map. Every other
// value is copied by plain assignment.
//
// A nil input yields an empty, non-nil map.
func deepCopyMap(m map[string]any) map[string]any {
	cp := make(map[string]any, len(m))
	for k, v := range m {
		cp[k] = deepCopyValue(v)
	}
	return cp
}

// deepCopyValue copies a single map value, recursing into the container types
// that deepCopyMap knows how to duplicate and returning anything else as-is.
func deepCopyValue(v any) any {
	switch val := v.(type) {
	case map[string]any:
		return deepCopyMap(val)
	case []map[string]any:
		rows := make([]map[string]any, len(val))
		for i, row := range val {
			rows[i] = deepCopyMap(row)
		}
		return rows
	case []any:
		// Generic slices may themselves hold maps or slices, so recurse per element.
		items := make([]any, len(val))
		for i, item := range val {
			items[i] = deepCopyValue(item)
		}
		return items
	default:
		return v
	}
}
214 changes: 213 additions & 1 deletion module/pipeline_step_db_query_cached_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ func TestDBQueryCachedStep_NoRows(t *testing.T) {
t.Fatalf("execute error: %v", err)
}

// No rows means empty output (no id/name keys), cache_hit=false
// No rows in legacy mode means empty output (no column keys), cache_hit=false
if result.Output["cache_hit"] != false {
t.Errorf("expected cache_hit=false, got %v", result.Output["cache_hit"])
}
Expand Down Expand Up @@ -407,6 +407,218 @@ func TestDBQueryCachedStep_NegativeTTLRejected(t *testing.T) {
}
}

// TestDBQueryCachedStep_InvalidModeRejected checks that the step factory
// returns an error when the configured mode is an unrecognised value ("bulk").
func TestDBQueryCachedStep_InvalidModeRejected(t *testing.T) {
	newStep := NewDBQueryCachedStepFactory()

	cfg := map[string]any{
		"database":  "db",
		"query":     "SELECT 1",
		"cache_key": "k",
		"mode":      "bulk",
	}

	if _, err := newStep("bad", cfg, nil); err == nil {
		t.Fatal("expected error for invalid mode")
	}
}

// TestDBQueryCachedStep_ListMode verifies that mode: list returns the
// rows/count output format on a first (uncached) execution. It expects the
// query to yield two rows, ordered by name with "Acme Corp" first.
func TestDBQueryCachedStep_ListMode(t *testing.T) {
	db := setupTestDB(t)
	app := mockAppWithDB("test-db", db)

	factory := NewDBQueryCachedStepFactory()
	step, err := factory("list-companies", map[string]any{
		"database":  "test-db",
		"query":     "SELECT id, name, slug FROM companies WHERE parent_id IS NULL ORDER BY name",
		"mode":      "list",
		"cache_key": "companies:list",
		"cache_ttl": "5m",
	}, app)
	if err != nil {
		t.Fatalf("factory error: %v", err)
	}

	pc := NewPipelineContext(nil, nil)
	result, err := step.Execute(context.Background(), pc)
	if err != nil {
		t.Fatalf("execute error: %v", err)
	}

	if result.Output["cache_hit"] != false {
		t.Errorf("expected cache_hit=false on first call, got %v", result.Output["cache_hit"])
	}
	rows, ok := result.Output["rows"].([]map[string]any)
	if !ok {
		t.Fatal("expected rows in output for list mode")
	}
	count, ok := result.Output["count"].(int)
	if !ok {
		t.Fatal("expected count in output for list mode")
	}
	if count != 2 {
		t.Errorf("expected count=2, got %d", count)
	}
	// Fatal, not Error: the rows[0] access below would panic with an index
	// out of range if the slice were empty, hiding the real failure.
	if len(rows) != 2 {
		t.Fatalf("expected 2 rows, got %d", len(rows))
	}
	if rows[0]["name"] != "Acme Corp" {
		t.Errorf("expected first row name='Acme Corp', got %v", rows[0]["name"])
	}
}

// TestDBQueryCachedStep_ListModeCacheHit runs the same list-mode step twice and
// checks that the first execution reports a cache miss while the second reports
// a cache hit and still carries the two expected rows.
func TestDBQueryCachedStep_ListModeCacheHit(t *testing.T) {
	conn := setupTestDB(t)
	mockApp := mockAppWithDB("test-db", conn)

	newStep := NewDBQueryCachedStepFactory()
	cfg := map[string]any{
		"database":  "test-db",
		"query":     "SELECT id, name FROM companies WHERE parent_id IS NULL ORDER BY name",
		"mode":      "list",
		"cache_key": "companies:list",
		"cache_ttl": "5m",
	}
	step, err := newStep("list-companies", cfg, mockApp)
	if err != nil {
		t.Fatalf("factory error: %v", err)
	}

	pc := NewPipelineContext(nil, nil)

	// Initial execution: nothing cached yet.
	miss, err := step.Execute(context.Background(), pc)
	if err != nil {
		t.Fatalf("first execute error: %v", err)
	}
	if miss.Output["cache_hit"] != false {
		t.Errorf("expected cache_hit=false on first call")
	}

	// Repeat execution: should be answered from the cache.
	hit, err := step.Execute(context.Background(), pc)
	if err != nil {
		t.Fatalf("second execute error: %v", err)
	}
	if flag := hit.Output["cache_hit"]; flag != true {
		t.Errorf("expected cache_hit=true on second call, got %v", flag)
	}

	cachedRows, isRows := hit.Output["rows"].([]map[string]any)
	if !isRows {
		t.Fatal("expected rows in cached output")
	}
	if n := len(cachedRows); n != 2 {
		t.Errorf("expected 2 rows from cache, got %d", n)
	}
}

// TestDBQueryCachedStep_ListModeEmpty verifies that list mode still emits the
// rows/count keys when the query matches nothing: rows must be a typed, empty
// slice and count must be an int equal to zero.
func TestDBQueryCachedStep_ListModeEmpty(t *testing.T) {
	db := setupTestDB(t)
	app := mockAppWithDB("test-db", db)

	factory := NewDBQueryCachedStepFactory()
	step, err := factory("list-empty", map[string]any{
		"database":  "test-db",
		"query":     "SELECT id, name FROM companies WHERE id = ?",
		"params":    []any{"nonexistent"},
		"mode":      "list",
		"cache_key": "companies:empty",
	}, app)
	if err != nil {
		t.Fatalf("factory error: %v", err)
	}

	pc := NewPipelineContext(nil, nil)
	result, err := step.Execute(context.Background(), pc)
	if err != nil {
		t.Fatalf("execute error: %v", err)
	}

	rows, ok := result.Output["rows"].([]map[string]any)
	if !ok {
		t.Fatal("expected rows in output for list mode even when empty")
	}
	if len(rows) != 0 {
		t.Errorf("expected 0 rows, got %d", len(rows))
	}
	// Check the type assertion explicitly: discarding ok would let a missing
	// or wrongly typed count silently read as 0 and pass the test.
	count, ok := result.Output["count"].(int)
	if !ok {
		t.Fatalf("expected int count in output for list mode even when empty, got %T", result.Output["count"])
	}
	if count != 0 {
		t.Errorf("expected count=0, got %d", count)
	}
}

// TestDBQueryCachedStep_SingleModeFound exercises mode: single with a query
// expected to match one company and checks the row/found output format:
// found must be true and row must hold the matched record's columns.
func TestDBQueryCachedStep_SingleModeFound(t *testing.T) {
	conn := setupTestDB(t)
	mockApp := mockAppWithDB("test-db", conn)

	newStep := NewDBQueryCachedStepFactory()
	cfg := map[string]any{
		"database":  "test-db",
		"query":     "SELECT id, name FROM companies WHERE id = ?",
		"params":    []any{"c1"},
		"mode":      "single",
		"cache_key": "company:c1",
	}
	step, err := newStep("get-company", cfg, mockApp)
	if err != nil {
		t.Fatalf("factory error: %v", err)
	}

	res, err := step.Execute(context.Background(), NewPipelineContext(nil, nil))
	if err != nil {
		t.Fatalf("execute error: %v", err)
	}

	// A missing or non-bool "found" collapses to false and is reported below.
	if isFound, _ := res.Output["found"].(bool); !isFound {
		t.Error("expected found=true")
	}

	record, isMap := res.Output["row"].(map[string]any)
	if !isMap {
		t.Fatal("expected row in output")
	}
	if got := record["name"]; got != "Acme Corp" {
		t.Errorf("expected name='Acme Corp', got %v", got)
	}
}

// TestDBQueryCachedStep_SingleModeNotFound verifies that mode: single keeps the
// row/found output format when no row matches: found must be present and
// false, and row must be an empty map rather than absent.
func TestDBQueryCachedStep_SingleModeNotFound(t *testing.T) {
	db := setupTestDB(t)
	app := mockAppWithDB("test-db", db)

	factory := NewDBQueryCachedStepFactory()
	step, err := factory("get-missing", map[string]any{
		"database":  "test-db",
		"query":     "SELECT id, name FROM companies WHERE id = ?",
		"params":    []any{"nonexistent"},
		"mode":      "single",
		"cache_key": "company:nonexistent",
	}, app)
	if err != nil {
		t.Fatalf("factory error: %v", err)
	}

	pc := NewPipelineContext(nil, nil)
	result, err := step.Execute(context.Background(), pc)
	if err != nil {
		t.Fatalf("execute error: %v", err)
	}

	// Check the type assertion explicitly: discarding ok would make an absent
	// "found" key indistinguishable from found=false and pass vacuously.
	found, ok := result.Output["found"].(bool)
	if !ok {
		t.Fatalf("expected bool found in output, got %T", result.Output["found"])
	}
	if found {
		t.Error("expected found=false")
	}
	row, ok := result.Output["row"].(map[string]any)
	if !ok {
		t.Fatal("expected row in output even when not found")
	}
	if len(row) != 0 {
		t.Errorf("expected empty row map, got %v", row)
	}
}

func TestDBQueryCachedStep_DynamicTableName(t *testing.T) {
db := setupTestDB(t)
_, err := db.Exec(`CREATE TABLE companies_beta (id TEXT PRIMARY KEY, name TEXT NOT NULL)`)
Expand Down
Loading
Loading