diff --git a/cmd/lambda/alert-dispatcher/main.go b/cmd/lambda/alert-dispatcher/main.go index 7e54802..c358b84 100644 --- a/cmd/lambda/alert-dispatcher/main.go +++ b/cmd/lambda/alert-dispatcher/main.go @@ -22,6 +22,11 @@ import ( func main() { logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + if err := ilambda.ValidateEnv("alert-dispatcher"); err != nil { + logger.Error("environment validation failed", "error", err) + os.Exit(1) + } + cfg, err := config.LoadDefaultConfig(context.Background()) if err != nil { logger.Error("load AWS config", "error", err) diff --git a/cmd/lambda/event-sink/main.go b/cmd/lambda/event-sink/main.go index 953b20a..7bf9ec0 100644 --- a/cmd/lambda/event-sink/main.go +++ b/cmd/lambda/event-sink/main.go @@ -19,6 +19,11 @@ import ( func main() { logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + if err := ilambda.ValidateEnv("event-sink"); err != nil { + logger.Error("environment validation failed", "error", err) + os.Exit(1) + } + cfg, err := awsconfig.LoadDefaultConfig(context.Background()) if err != nil { logger.Error("failed to load AWS config", "error", err) diff --git a/cmd/lambda/orchestrator/main.go b/cmd/lambda/orchestrator/main.go index 05d221f..54423ab 100644 --- a/cmd/lambda/orchestrator/main.go +++ b/cmd/lambda/orchestrator/main.go @@ -41,6 +41,11 @@ func (a *statusCheckerAdapter) CheckStatus(ctx context.Context, triggerType type func main() { logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + if err := ilambda.ValidateEnv("orchestrator"); err != nil { + logger.Error("environment validation failed", "error", err) + os.Exit(1) + } + cfg, err := awsconfig.LoadDefaultConfig(context.Background()) if err != nil { logger.Error("failed to load AWS config", "error", err) diff --git a/cmd/lambda/sla-monitor/main.go b/cmd/lambda/sla-monitor/main.go index 0fd3a5d..2a1ee13 100644 --- a/cmd/lambda/sla-monitor/main.go +++ b/cmd/lambda/sla-monitor/main.go @@ -21,6 +21,11 @@ import ( func main() { logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + if err := ilambda.ValidateEnv("sla-monitor"); err != nil { + logger.Error("environment validation failed", "error", err) + os.Exit(1) + } + cfg, err := awsconfig.LoadDefaultConfig(context.Background()) if err != nil { logger.Error("failed to load AWS config", "error", err) diff --git a/cmd/lambda/stream-router/main.go b/cmd/lambda/stream-router/main.go index e1312fb..1354881 100644 --- a/cmd/lambda/stream-router/main.go +++ b/cmd/lambda/stream-router/main.go @@ -23,6 +23,11 @@ import ( func main() { logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + if err := ilambda.ValidateEnv("stream-router"); err != nil { + logger.Error("environment validation failed", "error", err) + os.Exit(1) + } + cfg, err := awsconfig.LoadDefaultConfig(context.Background()) if err != nil { logger.Error("failed to load AWS config", "error", err) diff --git a/cmd/lambda/watchdog/main.go b/cmd/lambda/watchdog/main.go index fc729a3..ba607ad 100644 --- a/cmd/lambda/watchdog/main.go +++ b/cmd/lambda/watchdog/main.go @@ -24,6 +24,11 @@ import ( func main() { logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + if err := ilambda.ValidateEnv("watchdog"); err != nil { + logger.Error("environment validation failed", "error", err) + os.Exit(1) + } + cfg, err := awsconfig.LoadDefaultConfig(context.Background()) if err != nil { logger.Error("failed to load AWS config", "error", err) diff --git a/internal/calendar/registry_test.go b/internal/calendar/registry_test.go index daba396..7d82c0c 100644 --- a/internal/calendar/registry_test.go +++ b/internal/calendar/registry_test.go @@ -5,6 +5,7 @@ import ( "path/filepath" "testing" + "github.com/dwsmith1983/interlock/pkg/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -69,3 +70,57 @@ func TestRegistry_LoadDir_MissingDir(t *testing.T) { err := reg.LoadDir("/nonexistent/path") assert.Error(t, err) } + +func TestRegistry_Register_Success(t *testing.T) { + reg := NewRegistry() + cal := &types.Calendar{Name: "test-cal", Dates: []string{"2025-12-25"}} + require.NoError(t, reg.Register(cal)) + + retrieved := reg.Get("test-cal") + require.NotNil(t, retrieved) + assert.Equal(t, "test-cal", retrieved.Name) + assert.Contains(t, retrieved.Dates, "2025-12-25") +} + +func TestRegistry_Register_NoName(t *testing.T) { + reg := NewRegistry() + err := reg.Register(&types.Calendar{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no name") +} + +func TestRegistry_LoadFile_InvalidYAML(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "bad.yaml") + require.NoError(t, os.WriteFile(path, []byte("{{invalid yaml"), 0o644)) + + reg := NewRegistry() + err := reg.LoadFile(path) + assert.Error(t, err) + assert.Contains(t, err.Error(), "parsing YAML") +} + +func TestRegistry_LoadFile_FileNotFound(t *testing.T) { + reg := NewRegistry() + err := reg.LoadFile("/nonexistent/calendar.yaml") + assert.Error(t, err) + assert.Contains(t, err.Error(), "reading file") +} + +func TestRegistry_LoadDir_IgnoresSubdirs(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.MkdirAll(filepath.Join(dir, "subdir"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "subdir", "hidden.yaml"), []byte(` +name: hidden +dates: ["2025-01-01"] +`), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "main.yaml"), []byte(` +name: main-cal +dates: ["2025-12-25"] +`), 0o644)) + + reg := NewRegistry() + require.NoError(t, reg.LoadDir(dir)) + assert.NotNil(t, reg.Get("main-cal"), "top-level calendar should be loaded") + assert.Nil(t, reg.Get("hidden"), "subdirectory calendar should NOT be loaded") +} diff --git a/internal/lambda/defaults.go b/internal/lambda/defaults.go new file mode 100644 index 0000000..4a25442 --- /dev/null +++ b/internal/lambda/defaults.go @@ -0,0 +1,29 @@ +package lambda + +import "time" + +// Default timing constants for Step Function evaluation and job polling. +const ( + // DefaultEvalIntervalSec is the default evaluation loop interval (5 minutes). + DefaultEvalIntervalSec = 300 + + // DefaultEvalWindowSec is the default evaluation window (1 hour). + DefaultEvalWindowSec = 3600 + + // DefaultJobCheckIntervalSec is the default job status check interval (1 minute). + DefaultJobCheckIntervalSec = 60 + + // DefaultJobPollWindowSec is the default job poll window (1 hour). + DefaultJobPollWindowSec = 3600 + + // DefaultTriggerLockTTL is the default trigger lock duration when + // SFN_TIMEOUT_SECONDS is not set. + DefaultTriggerLockTTL = 4*time.Hour + 30*time.Minute + + // TriggerLockBuffer is the padding added to the SFN timeout to + // derive the trigger lock TTL. + TriggerLockBuffer = 30 * time.Minute + + // SFNExecNameMaxLen is the AWS limit for Step Function execution names. + SFNExecNameMaxLen = 80 +) diff --git a/internal/lambda/doc.go b/internal/lambda/doc.go new file mode 100644 index 0000000..b67ca6e --- /dev/null +++ b/internal/lambda/doc.go @@ -0,0 +1,18 @@ +// Package lambda implements the core business logic for the six AWS Lambda +// handlers that power the Interlock pipeline safety framework: +// +// - stream-router: processes DynamoDB stream events, evaluates trigger +// conditions, manages reruns, and detects post-run data drift +// - orchestrator: evaluates validation rules, executes triggers, polls +// job status, and manages trigger lifecycle +// - sla-monitor: calculates SLA deadlines, schedules/cancels EventBridge +// Scheduler entries, and fires warning/breach alerts +// - watchdog: detects silently missed pipeline schedules by scanning for +// pipelines that should have run but have no trigger record +// - event-sink: persists EventBridge events to DynamoDB for audit trail +// - alert-dispatcher: routes SQS alert messages to Slack +// +// All handlers share a common [Deps] struct for dependency injection, +// making the package fully testable with mock implementations of AWS SDK +// interfaces ([SFNAPI], [EventBridgeAPI], [SchedulerAPI], etc.). +package lambda diff --git a/internal/lambda/dynstream.go b/internal/lambda/dynstream.go new file mode 100644 index 0000000..32d9a49 --- /dev/null +++ b/internal/lambda/dynstream.go @@ -0,0 +1,241 @@ +package lambda + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + "time" + + "github.com/aws/aws-lambda-go/events" + "github.com/aws/aws-sdk-go-v2/service/eventbridge" + ebTypes "github.com/aws/aws-sdk-go-v2/service/eventbridge/types" + "github.com/dwsmith1983/interlock/pkg/types" +) + +// extractKeys returns the PK and SK string values from a DynamoDB stream record. +func extractKeys(record events.DynamoDBEventRecord) (pk, sk string) { + keys := record.Change.Keys + if pkAttr, ok := keys["PK"]; ok && pkAttr.DataType() == events.DataTypeString { + pk = pkAttr.String() + } + if skAttr, ok := keys["SK"]; ok && skAttr.DataType() == events.DataTypeString { + sk = skAttr.String() + } + return pk, sk +} + +// extractSensorData converts a DynamoDB stream NewImage to a plain map +// suitable for validation rule evaluation. If the item uses the canonical +// ControlRecord format (sensor fields nested inside a "data" map attribute), +// the "data" map is unwrapped so fields are accessible at the top level. +func extractSensorData(newImage map[string]events.DynamoDBAttributeValue) map[string]interface{} { + if newImage == nil { + return nil + } + + skipKeys := map[string]bool{"PK": true, "SK": true, "ttl": true} + result := make(map[string]interface{}, len(newImage)) + + for k, av := range newImage { + if skipKeys[k] { + continue + } + result[k] = convertAttributeValue(av) + } + + // Unwrap the "data" map if present (canonical ControlRecord sensor format). + if dataMap, ok := result["data"].(map[string]interface{}); ok { + return dataMap + } + return result +} + +// convertAttributeValue converts a DynamoDB stream attribute value to a Go native type. +func convertAttributeValue(av events.DynamoDBAttributeValue) interface{} { + switch av.DataType() { + case events.DataTypeString: + return av.String() + case events.DataTypeNumber: + // Try int first, fall back to float. + if i, err := strconv.ParseInt(av.Number(), 10, 64); err == nil { + return float64(i) + } + if f, err := strconv.ParseFloat(av.Number(), 64); err == nil { + return f + } + return av.Number() + case events.DataTypeBoolean: + return av.Boolean() + case events.DataTypeNull: + return nil + case events.DataTypeMap: + m := av.Map() + out := make(map[string]interface{}, len(m)) + for k, v := range m { + out[k] = convertAttributeValue(v) + } + return out + case events.DataTypeList: + l := av.List() + out := make([]interface{}, len(l)) + for i, v := range l { + out[i] = convertAttributeValue(v) + } + return out + default: + return nil + } +} + +// ResolveExecutionDate builds the execution date from sensor data fields. +// If both "date" and "hour" are present, returns "YYYY-MM-DDThh". +// If only "date", returns "YYYY-MM-DD". Falls back to today's date. +func ResolveExecutionDate(sensorData map[string]interface{}) string { + dateStr, _ := sensorData["date"].(string) + hourStr, _ := sensorData["hour"].(string) + + if dateStr == "" { + return time.Now().Format("2006-01-02") + } + + normalized := normalizeDate(dateStr) + // Validate YYYY-MM-DD format. + if _, err := time.Parse("2006-01-02", normalized); err != nil { + return time.Now().Format("2006-01-02") + } + + if hourStr != "" { + // Validate hour is 2-digit 00-23. + if len(hourStr) == 2 { + if h, err := strconv.Atoi(hourStr); err == nil && h >= 0 && h <= 23 { + return normalized + "T" + hourStr + } + } + return normalized + } + return normalized +} + +// normalizeDate converts YYYYMMDD to YYYY-MM-DD. Already-dashed dates pass through. +func normalizeDate(s string) string { + if len(s) == 8 && !strings.Contains(s, "-") { + return s[:4] + "-" + s[4:6] + "-" + s[6:8] + } + return s +} + +// resolveScheduleID returns "cron" if the pipeline uses a cron schedule, +// otherwise returns "stream". +func resolveScheduleID(cfg *types.PipelineConfig) string { + if cfg.Schedule.Cron != "" { + return "cron" + } + return "stream" +} + +// publishEvent sends an event to EventBridge. It is safe to call when +// EventBridge is nil or EventBusName is empty (returns nil with no action). +func publishEvent(ctx context.Context, d *Deps, eventType, pipelineID, schedule, date, message string, detail ...map[string]interface{}) error { + if d.EventBridge == nil || d.EventBusName == "" { + return nil + } + + evt := types.InterlockEvent{ + PipelineID: pipelineID, + ScheduleID: schedule, + Date: date, + Message: message, + Timestamp: time.Now(), + } + if len(detail) > 0 && detail[0] != nil { + evt.Detail = detail[0] + } + detailJSON, err := json.Marshal(evt) + if err != nil { + return fmt.Errorf("marshal event detail: %w", err) + } + + source := types.EventSource + detailStr := string(detailJSON) + + _, err = d.EventBridge.PutEvents(ctx, &eventbridge.PutEventsInput{ + Entries: []ebTypes.PutEventsRequestEntry{ + { + Source: &source, + DetailType: &eventType, + Detail: &detailStr, + EventBusName: &d.EventBusName, + }, + }, + }) + if err != nil { + return fmt.Errorf("publish %s event: %w", eventType, err) + } + return nil +} + +// resolveTimezone loads the time.Location for the given timezone name. +// Returns time.UTC if tz is empty or cannot be loaded. +func resolveTimezone(tz string) *time.Location { + if tz == "" { + return time.UTC + } + if loc, err := time.LoadLocation(tz); err == nil { + return loc + } + return time.UTC +} + +// isExcludedTime is the core calendar exclusion check. It evaluates +// whether the given time falls on a weekend or a specifically excluded date. +func isExcludedTime(excl *types.ExclusionConfig, t time.Time) bool { + if excl == nil { + return false + } + if excl.Weekends { + day := t.Weekday() + if day == time.Saturday || day == time.Sunday { + return true + } + } + dateStr := t.Format("2006-01-02") + for _, d := range excl.Dates { + if d == dateStr { + return true + } + } + return false +} + +// isExcludedDate checks calendar exclusions against a job's execution date +// (not wall-clock time). dateStr supports "YYYY-MM-DD" and "YYYY-MM-DDTHH". +func isExcludedDate(cfg *types.PipelineConfig, dateStr string) bool { + excl := cfg.Schedule.Exclude + if excl == nil || len(dateStr) < 10 { + return false + } + loc := resolveTimezone(cfg.Schedule.Timezone) + t, err := time.ParseInLocation("2006-01-02", dateStr[:10], loc) + if err != nil { + return false + } + return isExcludedTime(excl, t) +} + +// isExcluded checks whether the pipeline should be excluded from running +// based on calendar exclusions (weekends and specific dates). +// When no timezone is configured, now is used as-is (preserving its +// original location, which is UTC in AWS Lambda). +func isExcluded(cfg *types.PipelineConfig, now time.Time) bool { + excl := cfg.Schedule.Exclude + if excl == nil { + return false + } + t := now + if cfg.Schedule.Timezone != "" { + t = now.In(resolveTimezone(cfg.Schedule.Timezone)) + } + return isExcludedTime(excl, t) +} diff --git a/internal/lambda/envcheck.go b/internal/lambda/envcheck.go new file mode 100644 index 0000000..70ae426 --- /dev/null +++ b/internal/lambda/envcheck.go @@ -0,0 +1,38 @@ +package lambda + +import ( + "fmt" + "os" + "strings" +) + +// requiredEnvVars maps each Lambda handler name to the environment variables +// it requires at startup. Missing vars cause a fail-fast with a clear message. +var requiredEnvVars = map[string][]string{ + "stream-router": {"CONTROL_TABLE", "JOBLOG_TABLE", "RERUN_TABLE", "STATE_MACHINE_ARN", "EVENT_BUS_NAME"}, + "orchestrator": {"CONTROL_TABLE", "JOBLOG_TABLE", "RERUN_TABLE", "STATE_MACHINE_ARN", "EVENT_BUS_NAME"}, + "watchdog": {"CONTROL_TABLE", "JOBLOG_TABLE", "RERUN_TABLE", "EVENT_BUS_NAME"}, + "sla-monitor": {"CONTROL_TABLE", "JOBLOG_TABLE", "RERUN_TABLE", "EVENT_BUS_NAME", "SLA_MONITOR_ARN", "SCHEDULER_ROLE_ARN", "SCHEDULER_GROUP_NAME"}, + "event-sink": {"EVENTS_TABLE"}, + "alert-dispatcher": {"SLACK_BOT_TOKEN", "SLACK_CHANNEL_ID"}, +} + +// ValidateEnv checks that all required environment variables for the named +// handler are set and non-empty. Returns an error listing all missing vars, +// or nil if all are present. +func ValidateEnv(handler string) error { + vars, ok := requiredEnvVars[handler] + if !ok { + return nil // unknown handler, skip validation + } + var missing []string + for _, v := range vars { + if os.Getenv(v) == "" { + missing = append(missing, v) + } + } + if len(missing) > 0 { + return fmt.Errorf("%s: missing required env vars: %s", handler, strings.Join(missing, ", ")) + } + return nil +} diff --git a/internal/lambda/envcheck_test.go b/internal/lambda/envcheck_test.go new file mode 100644 index 0000000..51bd36a --- /dev/null +++ b/internal/lambda/envcheck_test.go @@ -0,0 +1,39 @@ +package lambda + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidateEnv_AllPresent(t *testing.T) { + for _, v := range requiredEnvVars["event-sink"] { + t.Setenv(v, "test-value") + } + err := ValidateEnv("event-sink") + require.NoError(t, err) +} + +func TestValidateEnv_MissingVars(t *testing.T) { + // Clear all vars that stream-router needs + for _, v := range requiredEnvVars["stream-router"] { + t.Setenv(v, "") + } + err := ValidateEnv("stream-router") + assert.Error(t, err) + assert.Contains(t, err.Error(), "CONTROL_TABLE") + assert.Contains(t, err.Error(), "STATE_MACHINE_ARN") +} + +func TestValidateEnv_PartialMissing(t *testing.T) { + t.Setenv("EVENTS_TABLE", "") + err := ValidateEnv("event-sink") + assert.Error(t, err) + assert.Contains(t, err.Error(), "EVENTS_TABLE") +} + +func TestValidateEnv_UnknownHandler(t *testing.T) { + err := ValidateEnv("unknown-handler") + require.NoError(t, err) +} diff --git a/internal/lambda/orchestrator_unit_test.go b/internal/lambda/orchestrator_unit_test.go new file mode 100644 index 0000000..d6ce194 --- /dev/null +++ b/internal/lambda/orchestrator_unit_test.go @@ -0,0 +1,159 @@ +package lambda_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/dwsmith1983/interlock/internal/lambda" + "github.com/dwsmith1983/interlock/pkg/types" +) + +// --------------------------------------------------------------------------- +// ParseExecutionDate — table-driven +// --------------------------------------------------------------------------- + +func TestParseExecutionDate(t *testing.T) { + tests := []struct { + name string + date string + wantDate string + wantHour string + }{ + {"hourly", "2026-03-03T10", "2026-03-03", "10"}, + {"daily", "2026-03-03", "2026-03-03", ""}, + {"empty", "", "", ""}, + {"empty_hour", "2026-03-03T", "2026-03-03", ""}, + {"no_date", "T10", "", "10"}, + {"compact_hourly", "20260303T07", "20260303", "07"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + date, hour := lambda.ParseExecutionDate(tt.date) + assert.Equal(t, tt.wantDate, date) + assert.Equal(t, tt.wantHour, hour) + }) + } +} + +// --------------------------------------------------------------------------- +// InjectDateArgs — table-driven +// --------------------------------------------------------------------------- + +func TestInjectDateArgs(t *testing.T) { + t.Run("glue_daily", func(t *testing.T) { + tc := &types.TriggerConfig{Glue: &types.GlueTriggerConfig{}} + lambda.InjectDateArgs(tc, "2026-03-03") + assert.Equal(t, "20260303", tc.Glue.Arguments["--par_day"]) + assert.Empty(t, tc.Glue.Arguments["--par_hour"]) + }) + + t.Run("glue_hourly", func(t *testing.T) { + tc := &types.TriggerConfig{Glue: &types.GlueTriggerConfig{}} + lambda.InjectDateArgs(tc, "2026-03-03T10") + assert.Equal(t, "20260303", tc.Glue.Arguments["--par_day"]) + assert.Equal(t, "10", tc.Glue.Arguments["--par_hour"]) + }) + + t.Run("glue_preserves_existing_args", func(t *testing.T) { + tc := &types.TriggerConfig{Glue: &types.GlueTriggerConfig{ + Arguments: map[string]string{"--extra": "val"}, + }} + lambda.InjectDateArgs(tc, "2026-03-03") + assert.Equal(t, "20260303", tc.Glue.Arguments["--par_day"]) + assert.Equal(t, "val", tc.Glue.Arguments["--extra"]) + }) + + t.Run("http_empty_body_daily", func(t *testing.T) { + tc := &types.TriggerConfig{HTTP: &types.HTTPTriggerConfig{URL: "http://example.com"}} + lambda.InjectDateArgs(tc, "2026-03-03") + assert.Contains(t, tc.HTTP.Body, `"par_day":"20260303"`) + assert.NotContains(t, tc.HTTP.Body, "par_hour") + }) + + t.Run("http_empty_body_hourly", func(t *testing.T) { + tc := &types.TriggerConfig{HTTP: &types.HTTPTriggerConfig{URL: "http://example.com"}} + lambda.InjectDateArgs(tc, "2026-03-03T10") + assert.Contains(t, tc.HTTP.Body, `"par_day":"20260303"`) + assert.Contains(t, tc.HTTP.Body, `"par_hour":"10"`) + }) + + t.Run("http_existing_body_not_overwritten", func(t *testing.T) { + tc := &types.TriggerConfig{HTTP: &types.HTTPTriggerConfig{ + URL: "http://example.com", + Body: `{"custom":"data"}`, + }} + lambda.InjectDateArgs(tc, "2026-03-03") + assert.Equal(t, `{"custom":"data"}`, tc.HTTP.Body, "existing body should not be overwritten") + }) + + t.Run("nil_glue_no_panic", func(t *testing.T) { + tc := &types.TriggerConfig{} + assert.NotPanics(t, func() { lambda.InjectDateArgs(tc, "2026-03-03") }) + }) +} + +// --------------------------------------------------------------------------- +// RemapPerPeriodSensors — table-driven +// --------------------------------------------------------------------------- + +func TestRemapPerPeriodSensors(t *testing.T) { + t.Run("normalized_date_suffix", func(t *testing.T) { + sensors := map[string]map[string]interface{}{ + "hourly-status#2026-03-03T07": {"count": 100.0}, + } + lambda.RemapPerPeriodSensors(sensors, "2026-03-03T07") + assert.NotNil(t, sensors["hourly-status"], "base key should be added") + assert.Equal(t, 100.0, sensors["hourly-status"]["count"]) + }) + + t.Run("compact_date_suffix", func(t *testing.T) { + sensors := map[string]map[string]interface{}{ + "hourly-status#20260303T07": {"count": 200.0}, + } + lambda.RemapPerPeriodSensors(sensors, "2026-03-03T07") + assert.NotNil(t, sensors["hourly-status"], "compact suffix should match") + assert.Equal(t, 200.0, sensors["hourly-status"]["count"]) + }) + + t.Run("no_matching_suffix", func(t *testing.T) { + sensors := map[string]map[string]interface{}{ + "upstream-complete": {"complete": true}, + } + lambda.RemapPerPeriodSensors(sensors, "2026-03-03") + _, exists := sensors["upstream-complete"] + assert.True(t, exists, "existing key should remain") + // No base key alias should be created since there's no suffix match. + assert.Len(t, sensors, 1) + }) + + t.Run("empty_date_noop", func(t *testing.T) { + sensors := map[string]map[string]interface{}{ + "hourly-status#2026-03-03T07": {"count": 100.0}, + } + lambda.RemapPerPeriodSensors(sensors, "") + assert.Len(t, sensors, 1, "no remapping should occur with empty date") + }) + + t.Run("multiple_sensors", func(t *testing.T) { + sensors := map[string]map[string]interface{}{ + "hourly-status#2026-03-03T07": {"count": 100.0}, + "daily-check#2026-03-03": {"passed": true}, + "other-sensor": {"val": 42.0}, + } + lambda.RemapPerPeriodSensors(sensors, "2026-03-03T07") + assert.NotNil(t, sensors["hourly-status"]) + // daily-check#2026-03-03 doesn't match #2026-03-03T07 + _, hasDailyBase := sensors["daily-check"] + assert.False(t, hasDailyBase, "daily-check should not be remapped with hourly date") + }) + + t.Run("daily_date_remaps", func(t *testing.T) { + sensors := map[string]map[string]interface{}{ + "daily-check#2026-03-03": {"passed": true}, + } + lambda.RemapPerPeriodSensors(sensors, "2026-03-03") + assert.NotNil(t, sensors["daily-check"]) + assert.Equal(t, true, sensors["daily-check"]["passed"]) + }) +} diff --git a/internal/lambda/postrun.go b/internal/lambda/postrun.go new file mode 100644 index 0000000..e22e91c --- /dev/null +++ b/internal/lambda/postrun.go @@ -0,0 +1,166 @@ +package lambda + +import ( + "context" + "fmt" + "math" + "strings" + "time" + + "github.com/dwsmith1983/interlock/internal/validation" + "github.com/dwsmith1983/interlock/pkg/types" +) + +// matchesPostRunRule returns true if the sensor key matches any post-run rule key +// (prefix match to support per-period sensor keys). +func matchesPostRunRule(sensorKey string, rules []types.ValidationRule) bool { + for _, rule := range rules { + if strings.HasPrefix(sensorKey, rule.Key) { + return true + } + } + return false +} + +// handlePostRunSensorEvent evaluates post-run rules reactively when a sensor +// arrives via DynamoDB Stream. Compares current sensor values against the +// date-scoped baseline captured at trigger completion. +func handlePostRunSensorEvent(ctx context.Context, d *Deps, cfg *types.PipelineConfig, pipelineID, sensorKey string, sensorData map[string]interface{}) error { + scheduleID := resolveScheduleID(cfg) + date := ResolveExecutionDate(sensorData) + + // Consistent read to handle race where sensor stream event arrives + // before SFN sets trigger to COMPLETED. + trigger, err := d.Store.GetTrigger(ctx, pipelineID, scheduleID, date) + if err != nil { + return fmt.Errorf("get trigger for post-run: %w", err) + } + if trigger == nil { + return nil // No trigger for this date — not a post-run event. + } + + switch trigger.Status { + case types.TriggerStatusRunning: + // Job still running — evaluate rules for informational drift detection. + return handlePostRunInflight(ctx, d, cfg, pipelineID, scheduleID, date, sensorKey, sensorData) + + case types.TriggerStatusCompleted: + // Job completed — full post-run evaluation with baseline comparison. + return handlePostRunCompleted(ctx, d, cfg, pipelineID, scheduleID, date, sensorData) + + default: + // FAILED_FINAL or unknown — skip. + return nil + } +} + +// handlePostRunInflight evaluates post-run rules while the job is still running. +// If drift is detected, publishes an informational event but does NOT trigger a rerun. +func handlePostRunInflight(ctx context.Context, d *Deps, cfg *types.PipelineConfig, pipelineID, scheduleID, date, sensorKey string, sensorData map[string]interface{}) error { + // Read baseline for comparison. + baselineKey := "postrun-baseline#" + date + baseline, err := d.Store.GetSensorData(ctx, pipelineID, baselineKey) + if err != nil { + return fmt.Errorf("get baseline for inflight check: %w", err) + } + if baseline == nil { + return nil // No baseline yet — job hasn't completed once. + } + + prevCount := ExtractFloat(baseline, "sensor_count") + currCount := ExtractFloat(sensorData, "sensor_count") + threshold := 0.0 + if cfg.PostRun.DriftThreshold != nil { + threshold = *cfg.PostRun.DriftThreshold + } + if prevCount > 0 && currCount > 0 && math.Abs(currCount-prevCount) > threshold { + if err := publishEvent(ctx, d, string(types.EventPostRunDriftInflight), pipelineID, scheduleID, date, + fmt.Sprintf("inflight drift detected for %s: %.0f → %.0f (informational)", pipelineID, prevCount, currCount), + map[string]interface{}{ + "previousCount": prevCount, + "currentCount": currCount, + "driftThreshold": threshold, + "sensorKey": sensorKey, + "source": "post-run-stream", + }); err != nil { + d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventPostRunDriftInflight, "error", err) + } + } + return nil +} + +// handlePostRunCompleted evaluates post-run rules after the job has completed. +// Compares sensor values against the date-scoped baseline and triggers a rerun +// if drift is detected. +func handlePostRunCompleted(ctx context.Context, d *Deps, cfg *types.PipelineConfig, pipelineID, scheduleID, date string, sensorData map[string]interface{}) error { + // Read baseline captured at trigger completion. + baselineKey := "postrun-baseline#" + date + baseline, err := d.Store.GetSensorData(ctx, pipelineID, baselineKey) + if err != nil { + return fmt.Errorf("get baseline for post-run: %w", err) + } + + // Check for data drift if baseline exists. + if baseline != nil { + prevCount := ExtractFloat(baseline, "sensor_count") + currCount := ExtractFloat(sensorData, "sensor_count") + threshold := 0.0 + if cfg.PostRun.DriftThreshold != nil { + threshold = *cfg.PostRun.DriftThreshold + } + if prevCount > 0 && currCount > 0 && math.Abs(currCount-prevCount) > threshold { + delta := currCount - prevCount + if err := publishEvent(ctx, d, string(types.EventPostRunDrift), pipelineID, scheduleID, date, + fmt.Sprintf("post-run drift detected for %s: %.0f → %.0f records", pipelineID, prevCount, currCount), + map[string]interface{}{ + "previousCount": prevCount, + "currentCount": currCount, + "delta": delta, + "driftThreshold": threshold, + "source": "post-run-stream", + }); err != nil { + d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventPostRunDrift, "error", err) + } + + // Trigger rerun via the existing circuit breaker path only if the + // execution date is not excluded by the pipeline's calendar config. + if isExcludedDate(cfg, date) { + if pubErr := publishEvent(ctx, d, string(types.EventPipelineExcluded), pipelineID, scheduleID, date, + fmt.Sprintf("post-run drift rerun skipped for %s: execution date %s excluded by calendar", pipelineID, date)); pubErr != nil { + d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventPipelineExcluded, "error", pubErr) + } + d.Logger.InfoContext(ctx, "post-run drift rerun skipped: execution date excluded by calendar", + "pipelineId", pipelineID, "date", date) + } else { + if writeErr := d.Store.WriteRerunRequest(ctx, pipelineID, scheduleID, date, "data-drift"); writeErr != nil { + d.Logger.WarnContext(ctx, "failed to write rerun request on post-run drift", + "pipelineId", pipelineID, "error", writeErr) + } + } + return nil + } + } + + // Evaluate post-run validation rules. + sensors, err := d.Store.GetAllSensors(ctx, pipelineID) + if err != nil { + return fmt.Errorf("get sensors for post-run rules: %w", err) + } + RemapPerPeriodSensors(sensors, date) + + result := validation.EvaluateRules("ALL", cfg.PostRun.Rules, sensors, time.Now()) + + if result.Passed { + if err := publishEvent(ctx, d, string(types.EventPostRunPassed), pipelineID, scheduleID, date, + fmt.Sprintf("post-run validation passed for %s", pipelineID)); err != nil { + d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventPostRunPassed, "error", err) + } + } else { + if err := publishEvent(ctx, d, string(types.EventPostRunFailed), pipelineID, scheduleID, date, + fmt.Sprintf("post-run validation failed for %s", pipelineID)); err != nil { + d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventPostRunFailed, "error", err) + } + } + + return nil +} diff --git a/internal/lambda/rerun.go b/internal/lambda/rerun.go new file mode 100644 index 0000000..5cb4451 --- /dev/null +++ b/internal/lambda/rerun.go @@ -0,0 +1,384 @@ +package lambda + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "github.com/aws/aws-lambda-go/events" + "github.com/dwsmith1983/interlock/pkg/types" +) + +// handleRerunRequest processes a RERUN_REQUEST# stream record. It enforces +// per-source rerun limits (drift vs manual) and implements a circuit breaker +// that prevents unnecessary re-runs when the previous run succeeded and no +// sensor data has changed since. +func handleRerunRequest(ctx context.Context, d *Deps, pk, sk string, record events.DynamoDBEventRecord) error { + pipelineID := strings.TrimPrefix(pk, "PIPELINE#") + if pipelineID == pk { + return fmt.Errorf("unexpected PK format: %q", pk) + } + + schedule, date, err := parseRerunRequestSK(sk) + if err != nil { + return err + } + + cfg, err := getValidatedConfig(ctx, d, pipelineID) + if err != nil { + return fmt.Errorf("load config for %q: %w", pipelineID, err) + } + if cfg == nil { + d.Logger.Warn("no config found for pipeline, skipping rerun request", "pipelineId", pipelineID) + return nil + } + + // --- Calendar exclusion check (execution date) --- + if isExcludedDate(cfg, date) { + _ = d.Store.WriteJobEvent(ctx, pipelineID, schedule, date, types.JobEventRerunRejected, "", 0, "excluded by calendar") + if pubErr := publishEvent(ctx, d, string(types.EventPipelineExcluded), pipelineID, schedule, date, + fmt.Sprintf("rerun blocked for %s: execution date %s excluded by calendar", pipelineID, date)); pubErr != nil { + d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventPipelineExcluded, "error", pubErr) + } + return nil + } + + // Extract reason from stream record NewImage. Default to "manual". + reason := "manual" + if img := record.Change.NewImage; img != nil { + if r, ok := img["reason"]; ok && r.DataType() == events.DataTypeString { + if v := r.String(); v != "" { + reason = v + } + } + } + + // --- Rerun limit check --- + var budget int + var sources []string + var limitLabel string + switch reason { + case "data-drift", "late-data": + budget = types.IntOrDefault(cfg.Job.MaxDriftReruns, 1) + sources = []string{"data-drift", "late-data"} + limitLabel = "drift rerun limit exceeded" + default: + budget = types.IntOrDefault(cfg.Job.MaxManualReruns, 1) + sources = []string{reason} + limitLabel = "manual rerun limit exceeded" + } + + count, err := d.Store.CountRerunsBySource(ctx, pipelineID, schedule, date, sources) + if err != nil { + return fmt.Errorf("count reruns by source for %q: %w", pipelineID, err) + } + + if count >= budget { + _ = d.Store.WriteJobEvent(ctx, pipelineID, schedule, date, + types.JobEventRerunRejected, "", 0, limitLabel) + if err := publishEvent(ctx, d, string(types.EventRerunRejected), pipelineID, schedule, date, + fmt.Sprintf("rerun rejected for %s: %s", pipelineID, limitLabel)); err != nil { + d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventRerunRejected, "error", err) + } + d.Logger.Info("rerun request rejected (limit exceeded)", + "pipelineId", pipelineID, "schedule", schedule, "date", date, + "reason", reason, "count", count, "budget", budget) + return nil + } + + // --- Circuit breaker (sensor freshness) --- + job, err := d.Store.GetLatestJobEvent(ctx, pipelineID, schedule, date) + if err != nil { + return fmt.Errorf("get latest job event for %q/%s/%s: %w", pipelineID, schedule, date, err) + } + + allowed := true + rejectReason := "" + if job != nil && job.Event == types.JobEventSuccess { + fresh, err := checkSensorFreshness(ctx, d, pipelineID, job.SK) + if err != nil { + return fmt.Errorf("check sensor freshness for %q: %w", pipelineID, err) + } + if !fresh { + allowed = false + rejectReason = "previous run succeeded and no sensor data has changed" + } + } + + if !allowed { + _ = d.Store.WriteJobEvent(ctx, pipelineID, schedule, date, + types.JobEventRerunRejected, "", 0, rejectReason) + if err := publishEvent(ctx, d, string(types.EventRerunRejected), pipelineID, schedule, date, + fmt.Sprintf("rerun rejected for %s: %s", pipelineID, rejectReason)); err != nil { + d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventRerunRejected, "error", err) + } + d.Logger.Info("rerun request rejected", + "pipelineId", pipelineID, "schedule", schedule, "date", date, + "reason", rejectReason) + return nil + } + + // --- Acceptance: write rerun record FIRST (before lock reset) --- + if _, err := d.Store.WriteRerun(ctx, pipelineID, schedule, date, reason, ""); err != nil { + return fmt.Errorf("write rerun for %q: %w", pipelineID, err) + } + + // Delete date-scoped postrun-baseline so re-run captures fresh baseline. + if cfg.PostRun != nil { + _ = d.Store.DeleteSensor(ctx, pipelineID, "postrun-baseline#"+date) + } + + // Atomically reset the trigger lock for the new execution. + acquired, err := d.Store.ResetTriggerLock(ctx, pipelineID, schedule, date, ResolveTriggerLockTTL()) + if err != nil { + return fmt.Errorf("reset trigger lock for %q: %w", pipelineID, err) + } + if !acquired { + if pubErr := publishEvent(ctx, d, string(types.EventInfraFailure), pipelineID, schedule, date, + fmt.Sprintf("lock reset failed for rerun of %s, orphaned rerun record", pipelineID)); pubErr != nil { + d.Logger.WarnContext(ctx, "failed to publish event", "error", pubErr) + } + d.Logger.Warn("failed to reset trigger lock, orphaned rerun record", + "pipelineId", pipelineID, "schedule", schedule, "date", date) + return nil + } + + // Publish acceptance event only after lock atomicity is confirmed. + _ = d.Store.WriteJobEvent(ctx, pipelineID, schedule, date, + types.JobEventRerunAccepted, "", 0, "") + + if pubErr := publishEvent(ctx, d, string(types.EventRerunAccepted), pipelineID, schedule, date, + fmt.Sprintf("rerun accepted for %s (reason: %s)", pipelineID, reason)); pubErr != nil { + d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventRerunAccepted, "error", pubErr) + } + + execName := truncateExecName(fmt.Sprintf("%s-%s-%s-%s-rerun-%d", pipelineID, schedule, date, reason, time.Now().Unix())) + if err := startSFNWithName(ctx, d, cfg, pipelineID, schedule, date, execName); err != nil { + if relErr := d.Store.ReleaseTriggerLock(ctx, pipelineID, schedule, date); relErr != nil { + d.Logger.Warn("failed to release lock after SFN start failure", "error", relErr) + } + return fmt.Errorf("start SFN rerun for %q: %w", pipelineID, err) + } + + d.Logger.Info("started rerun", + "pipelineId", pipelineID, "schedule", schedule, "date", date, "reason", reason) + return nil +} + +// parseRerunRequestSK extracts schedule and date from a RERUN_REQUEST# sort key. +// Expected format: RERUN_REQUEST## +func parseRerunRequestSK(sk string) (schedule, date string, err error) { + trimmed := strings.TrimPrefix(sk, "RERUN_REQUEST#") + parts := strings.SplitN(trimmed, "#", 2) + if len(parts) < 2 { + return "", "", fmt.Errorf("invalid RERUN_REQUEST SK format: %q", sk) + } + return parts[0], parts[1], nil +} + +// handleJobFailure processes a job failure or timeout by either re-running +// the pipeline (if under the retry limit) or marking it as permanently failed. +func handleJobFailure(ctx context.Context, d *Deps, pipelineID, schedule, date, jobEvent string) error { + cfg, err := getValidatedConfig(ctx, d, pipelineID) + if err != nil { + return fmt.Errorf("load config for %q: %w", pipelineID, err) + } + if cfg == nil { + d.Logger.Warn("no config found for pipeline, skipping rerun", "pipelineId", pipelineID) + return nil + } + + maxRetries := cfg.Job.MaxRetries + + // Check if the latest failure has a category for budget selection. + latestJob, jobErr := d.Store.GetLatestJobEvent(ctx, pipelineID, schedule, date) + if jobErr != nil { + d.Logger.Warn("could not read latest job event for failure category", + "pipelineId", pipelineID, "error", jobErr) + } + if latestJob != nil { + if types.FailureCategory(latestJob.Category) == types.FailurePermanent { + maxRetries = types.IntOrDefault(cfg.Job.MaxCodeRetries, 1) + } + // TRANSIENT, TIMEOUT, or empty → use cfg.Job.MaxRetries (already set). + } + + rerunCount, err := d.Store.CountRerunsBySource(ctx, pipelineID, schedule, date, []string{"job-fail-retry"}) + if err != nil { + return fmt.Errorf("count reruns for %q/%s/%s: %w", pipelineID, schedule, date, err) + } + + if rerunCount >= maxRetries { + // Retry limit reached — publish exhaustion event and mark as final failure. + if err := publishEvent(ctx, d, string(types.EventRetryExhausted), pipelineID, schedule, date, + fmt.Sprintf("retry limit reached (%d/%d) for %s", rerunCount, maxRetries, pipelineID)); err != nil { + d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventRetryExhausted, "error", err) + } + + if err := d.Store.SetTriggerStatus(ctx, pipelineID, schedule, date, types.TriggerStatusFailedFinal); err != nil { + return fmt.Errorf("set trigger status FAILED_FINAL for %q: %w", pipelineID, err) + } + + d.Logger.Info("retry limit reached", + "pipelineId", pipelineID, + "schedule", schedule, + "date", date, + "reruns", rerunCount, + "maxRetries", maxRetries, + ) + return nil + } + + // Calendar exclusion check: skip retry if the execution date is excluded. + // Mark trigger as terminal so the lock doesn't silently expire via TTL. + if isExcludedDate(cfg, date) { + if err := d.Store.SetTriggerStatus(ctx, pipelineID, schedule, date, types.TriggerStatusFailedFinal); err != nil { + d.Logger.WarnContext(ctx, "failed to set trigger status after calendar exclusion", "error", err) + } + if pubErr := publishEvent(ctx, d, string(types.EventPipelineExcluded), pipelineID, schedule, date, + fmt.Sprintf("job failure retry skipped for %s: execution date %s excluded by calendar", pipelineID, date)); pubErr != nil { + d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventPipelineExcluded, "error", pubErr) + } + return nil + } + + // Under retry limit — write rerun record and restart the pipeline. + attempt, err := d.Store.WriteRerun(ctx, pipelineID, schedule, date, "job-fail-retry", jobEvent) + if err != nil { + return fmt.Errorf("write rerun for %q: %w", pipelineID, err) + } + + acquired, err := d.Store.ResetTriggerLock(ctx, pipelineID, schedule, date, ResolveTriggerLockTTL()) + if err != nil { + return fmt.Errorf("reset trigger lock for %q: %w", pipelineID, err) + } + if !acquired { + d.Logger.Warn("failed to reset trigger lock, skipping rerun", + "pipelineId", pipelineID, "schedule", schedule, "date", date) + return nil + } + + // Use a unique execution name that includes the rerun attempt number. + execName := truncateExecName(fmt.Sprintf("%s-%s-%s-rerun-%d", pipelineID, schedule, date, attempt)) + if err := startSFNWithName(ctx, d, cfg, pipelineID, schedule, date, execName); err != nil { + if relErr := d.Store.ReleaseTriggerLock(ctx, pipelineID, schedule, date); relErr != nil { + d.Logger.Warn("failed to release lock after SFN start failure", "error", relErr) + } + return fmt.Errorf("start SFN rerun for %q: %w", pipelineID, err) + } + + d.Logger.Info("started rerun", + "pipelineId", pipelineID, + "schedule", schedule, + "date", date, + "attempt", attempt, + ) + return nil +} + +// checkSensorFreshness determines whether any sensor data has been updated +// after the given job completed. The job timestamp is extracted from the job +// SK (format: JOB#schedule#date#). Returns true if data has +// changed (rerun should proceed) or if freshness cannot be determined. +func checkSensorFreshness(ctx context.Context, d *Deps, pipelineID, jobSK string) (bool, error) { + // Extract timestamp from the job SK. + parts := strings.Split(jobSK, "#") + if len(parts) < 4 { + // Can't parse timestamp — allow to be safe. + return true, nil + } + jobTimestamp, err := strconv.ParseInt(parts[len(parts)-1], 10, 64) + if err != nil { + // Can't parse timestamp — allow to be safe. + return true, nil + } + + sensors, err := d.Store.GetAllSensors(ctx, pipelineID) + if err != nil { + return false, fmt.Errorf("get sensors for %q: %w", pipelineID, err) + } + if len(sensors) == 0 { + // No sensors — can't prove unchanged, allow. + return true, nil + } + + hasAnyUpdatedAt := false + for _, data := range sensors { + updatedAt, ok := data["updatedAt"] + if !ok { + continue + } + hasAnyUpdatedAt = true + + var ts int64 + switch v := updatedAt.(type) { + case float64: + ts = int64(v) + case int64: + ts = v + case string: + ts, err = strconv.ParseInt(v, 10, 64) + if err != nil { + continue + } + default: + continue + } + + if ts > jobTimestamp { + return true, nil // Data changed after job — allow rerun. + } + } + + if !hasAnyUpdatedAt { + // No sensors have updatedAt — can't prove unchanged, allow. + return true, nil + } + + // All sensor timestamps are older than the job — data unchanged. + return false, nil +} + +// checkLateDataArrival detects sensor updates after a pipeline has completed +// successfully. If the trigger is in terminal COMPLETED state and the latest +// job event is success, this sensor write represents late data that arrived +// after post-job monitoring closed. Dual-writes a joblog entry and publishes +// a LATE_DATA_ARRIVAL event. +func checkLateDataArrival(ctx context.Context, d *Deps, pipelineID, schedule, date string) error { + trigger, err := d.Store.GetTrigger(ctx, pipelineID, schedule, date) + if err != nil || trigger == nil { + return err + } + + if trigger.Status != types.TriggerStatusCompleted { + return nil // still running or failed — not late data + } + + job, err := d.Store.GetLatestJobEvent(ctx, pipelineID, schedule, date) + if err != nil || job == nil { + return err + } + + if job.Event != types.JobEventSuccess { + return nil // job didn't succeed — not a "late data after success" scenario + } + + // Dual-write: joblog entry (audit) + EventBridge event (alerting). + _ = d.Store.WriteJobEvent(ctx, pipelineID, schedule, date, + types.JobEventLateDataArrival, "", 0, + "sensor updated after pipeline completed successfully") + + if err := publishEvent(ctx, d, string(types.EventLateDataArrival), pipelineID, schedule, date, + fmt.Sprintf("late data arrival for %s: sensor updated after job completion", pipelineID)); err != nil { + d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventLateDataArrival, "error", err) + } + + // Trigger a re-run — circuit breaker in handleRerunRequest will validate sensor freshness. + if writeErr := d.Store.WriteRerunRequest(ctx, pipelineID, schedule, date, "late-data"); writeErr != nil { + d.Logger.WarnContext(ctx, "failed to write rerun request on late data", "pipelineId", pipelineID, "error", writeErr) + } + + return nil +} diff --git a/internal/lambda/sfn.go b/internal/lambda/sfn.go new file mode 100644 index 0000000..c33243b --- /dev/null +++ b/internal/lambda/sfn.go @@ -0,0 +1,119 @@ +package lambda + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/aws/aws-sdk-go-v2/service/sfn" + "github.com/dwsmith1983/interlock/pkg/types" +) + +// sfnInput is the top-level input for the Step Function state machine. +// It includes pipeline identity fields and a config block used by Wait states. +type sfnInput struct { + PipelineID string `json:"pipelineId"` + ScheduleID string `json:"scheduleId"` + Date string `json:"date"` + Config sfnConfig `json:"config"` +} + +// sfnConfig holds timing parameters for the SFN evaluation loop and SLA branch. +type sfnConfig struct { + EvaluationIntervalSeconds int `json:"evaluationIntervalSeconds"` + EvaluationWindowSeconds int `json:"evaluationWindowSeconds"` + JobCheckIntervalSeconds int `json:"jobCheckIntervalSeconds"` + JobPollWindowSeconds int `json:"jobPollWindowSeconds"` + SLA *types.SLAConfig `json:"sla,omitempty"` +} + +// buildSFNConfig converts a PipelineConfig into the config block for the SFN input. +func buildSFNConfig(cfg *types.PipelineConfig) sfnConfig { + sc := sfnConfig{ + EvaluationIntervalSeconds: DefaultEvalIntervalSec, + EvaluationWindowSeconds: DefaultEvalWindowSec, + JobCheckIntervalSeconds: DefaultJobCheckIntervalSec, + JobPollWindowSeconds: DefaultJobPollWindowSec, + } + + if d, err := time.ParseDuration(cfg.Schedule.Evaluation.Interval); err == nil && d > 0 { + sc.EvaluationIntervalSeconds = int(d.Seconds()) + } + if d, err := time.ParseDuration(cfg.Schedule.Evaluation.Window); err == nil && d > 0 { + sc.EvaluationWindowSeconds = int(d.Seconds()) + } + + if cfg.Job.JobPollWindowSeconds != nil && *cfg.Job.JobPollWindowSeconds > 0 { + sc.JobPollWindowSeconds = *cfg.Job.JobPollWindowSeconds + } + + if cfg.SLA != nil { + sla := *cfg.SLA + if sla.Timezone == "" { + sla.Timezone = "UTC" + } + sc.SLA = &sla + } + + return sc +} + +// truncateExecName ensures an SFN execution name does not exceed the 80-character +// AWS limit. When truncation is needed the suffix (date + timestamp) is preserved +// by trimming characters from the beginning of the name. +func truncateExecName(name string) string { + if len(name) <= SFNExecNameMaxLen { + return name + } + return name[len(name)-SFNExecNameMaxLen:] +} + +// startSFN starts a Step Function execution with a unique execution name. +// The name includes a Unix timestamp suffix to avoid ExecutionAlreadyExists +// errors when a previous execution for the same pipeline/schedule/date failed. +func startSFN(ctx context.Context, d *Deps, cfg *types.PipelineConfig, pipelineID, scheduleID, date string) error { + name := truncateExecName(fmt.Sprintf("%s-%s-%s-%d", pipelineID, scheduleID, date, time.Now().Unix())) + return startSFNWithName(ctx, d, cfg, pipelineID, scheduleID, date, name) +} + +// startSFNWithName starts a Step Function execution with a custom execution name. +func startSFNWithName(ctx context.Context, d *Deps, cfg *types.PipelineConfig, pipelineID, scheduleID, date, name string) error { + sc := buildSFNConfig(cfg) + + // Warn if the sum of evaluation + poll windows exceeds the SFN timeout. + totalWindowSec := sc.EvaluationWindowSeconds + sc.JobPollWindowSeconds + sfnTimeout := ResolveTriggerLockTTL() - TriggerLockBuffer // strip the buffer to get raw SFN timeout + if sfnTimeout > 0 && time.Duration(totalWindowSec)*time.Second > sfnTimeout { + d.Logger.Warn("combined pipeline windows exceed SFN timeout", + "pipelineId", pipelineID, + "evalWindowSec", sc.EvaluationWindowSeconds, + "jobPollWindowSec", sc.JobPollWindowSeconds, + "totalWindowSec", totalWindowSec, + "sfnTimeoutSec", int(sfnTimeout.Seconds()), + ) + } + + input := sfnInput{ + PipelineID: pipelineID, + ScheduleID: scheduleID, + Date: date, + Config: sc, + } + payload, err := json.Marshal(input) + if err != nil { + return fmt.Errorf("marshal SFN input: %w", err) + } + + inputStr := string(payload) + + _, err = d.SFNClient.StartExecution(ctx, &sfn.StartExecutionInput{ + StateMachineArn: &d.StateMachineARN, + Name: &name, + Input: &inputStr, + }) + if err != nil { + return fmt.Errorf("StartExecution: %w", err) + } + return nil +} diff --git a/internal/lambda/stream_router.go b/internal/lambda/stream_router.go index 573d7a2..e7a6a9c 100644 --- a/internal/lambda/stream_router.go +++ b/internal/lambda/stream_router.go @@ -2,19 +2,13 @@ package lambda import ( "context" - "encoding/json" "fmt" - "math" "os" "strconv" "strings" "time" "github.com/aws/aws-lambda-go/events" - "github.com/aws/aws-sdk-go-v2/service/eventbridge" - ebTypes "github.com/aws/aws-sdk-go-v2/service/eventbridge/types" - "github.com/aws/aws-sdk-go-v2/service/sfn" - "github.com/dwsmith1983/interlock/internal/validation" "github.com/dwsmith1983/interlock/pkg/types" ) @@ -25,13 +19,13 @@ import ( func ResolveTriggerLockTTL() time.Duration { s := os.Getenv("SFN_TIMEOUT_SECONDS") if s == "" { - return 4*time.Hour + 30*time.Minute + return DefaultTriggerLockTTL } sec, err := strconv.Atoi(s) if err != nil || sec <= 0 { - return 4*time.Hour + 30*time.Minute + return DefaultTriggerLockTTL } - return time.Duration(sec)*time.Second + 30*time.Minute + return time.Duration(sec)*time.Second + TriggerLockBuffer } // getValidatedConfig loads a pipeline config and validates its retry/timeout @@ -137,342 +131,12 @@ func parseJobSK(sk string) (schedule, date string, err error) { return parts[0], parts[1], nil } -// handleJobFailure processes a job failure or timeout by either re-running -// the pipeline (if under the retry limit) or marking it as permanently failed. -func handleJobFailure(ctx context.Context, d *Deps, pipelineID, schedule, date, jobEvent string) error { - cfg, err := getValidatedConfig(ctx, d, pipelineID) - if err != nil { - return fmt.Errorf("load config for %q: %w", pipelineID, err) - } - if cfg == nil { - d.Logger.Warn("no config found for pipeline, skipping rerun", "pipelineId", pipelineID) - return nil - } - - maxRetries := cfg.Job.MaxRetries - - // Check if the latest failure has a category for budget selection. - latestJob, jobErr := d.Store.GetLatestJobEvent(ctx, pipelineID, schedule, date) - if jobErr != nil { - d.Logger.Warn("could not read latest job event for failure category", - "pipelineId", pipelineID, "error", jobErr) - } - if latestJob != nil { - if types.FailureCategory(latestJob.Category) == types.FailurePermanent { - maxRetries = types.IntOrDefault(cfg.Job.MaxCodeRetries, 1) - } - // TRANSIENT, TIMEOUT, or empty → use cfg.Job.MaxRetries (already set). - } - - rerunCount, err := d.Store.CountRerunsBySource(ctx, pipelineID, schedule, date, []string{"job-fail-retry"}) - if err != nil { - return fmt.Errorf("count reruns for %q/%s/%s: %w", pipelineID, schedule, date, err) - } - - if rerunCount >= maxRetries { - // Retry limit reached — publish exhaustion event and mark as final failure. - if err := publishEvent(ctx, d, string(types.EventRetryExhausted), pipelineID, schedule, date, - fmt.Sprintf("retry limit reached (%d/%d) for %s", rerunCount, maxRetries, pipelineID)); err != nil { - d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventRetryExhausted, "error", err) - } - - if err := d.Store.SetTriggerStatus(ctx, pipelineID, schedule, date, types.TriggerStatusFailedFinal); err != nil { - return fmt.Errorf("set trigger status FAILED_FINAL for %q: %w", pipelineID, err) - } - - d.Logger.Info("retry limit reached", - "pipelineId", pipelineID, - "schedule", schedule, - "date", date, - "reruns", rerunCount, - "maxRetries", maxRetries, - ) - return nil - } - - // Calendar exclusion check: skip retry if the execution date is excluded. - // Mark trigger as terminal so the lock doesn't silently expire via TTL. - if isExcludedDate(cfg, date) { - if err := d.Store.SetTriggerStatus(ctx, pipelineID, schedule, date, types.TriggerStatusFailedFinal); err != nil { - d.Logger.WarnContext(ctx, "failed to set trigger status after calendar exclusion", "error", err) - } - if pubErr := publishEvent(ctx, d, string(types.EventPipelineExcluded), pipelineID, schedule, date, - fmt.Sprintf("job failure retry skipped for %s: execution date %s excluded by calendar", pipelineID, date)); pubErr != nil { - d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventPipelineExcluded, "error", pubErr) - } - return nil - } - - // Under retry limit — write rerun record and restart the pipeline. - attempt, err := d.Store.WriteRerun(ctx, pipelineID, schedule, date, "job-fail-retry", jobEvent) - if err != nil { - return fmt.Errorf("write rerun for %q: %w", pipelineID, err) - } - - acquired, err := d.Store.ResetTriggerLock(ctx, pipelineID, schedule, date, ResolveTriggerLockTTL()) - if err != nil { - return fmt.Errorf("reset trigger lock for %q: %w", pipelineID, err) - } - if !acquired { - d.Logger.Warn("failed to reset trigger lock, skipping rerun", - "pipelineId", pipelineID, "schedule", schedule, "date", date) - return nil - } - - // Use a unique execution name that includes the rerun attempt number. - execName := truncateExecName(fmt.Sprintf("%s-%s-%s-rerun-%d", pipelineID, schedule, date, attempt)) - if err := startSFNWithName(ctx, d, cfg, pipelineID, schedule, date, execName); err != nil { - if relErr := d.Store.ReleaseTriggerLock(ctx, pipelineID, schedule, date); relErr != nil { - d.Logger.Warn("failed to release lock after SFN start failure", "error", relErr) - } - return fmt.Errorf("start SFN rerun for %q: %w", pipelineID, err) - } - - d.Logger.Info("started rerun", - "pipelineId", pipelineID, - "schedule", schedule, - "date", date, - "attempt", attempt, - ) - return nil -} - // handleJobSuccess publishes a job-completed event to EventBridge. func handleJobSuccess(ctx context.Context, d *Deps, pipelineID, schedule, date string) error { return publishEvent(ctx, d, string(types.EventJobCompleted), pipelineID, schedule, date, fmt.Sprintf("job completed for %s", pipelineID)) } -// handleRerunRequest processes a RERUN_REQUEST# stream record. It enforces -// per-source rerun limits (drift vs manual) and implements a circuit breaker -// that prevents unnecessary re-runs when the previous run succeeded and no -// sensor data has changed since. -func handleRerunRequest(ctx context.Context, d *Deps, pk, sk string, record events.DynamoDBEventRecord) error { - pipelineID := strings.TrimPrefix(pk, "PIPELINE#") - if pipelineID == pk { - return fmt.Errorf("unexpected PK format: %q", pk) - } - - schedule, date, err := parseRerunRequestSK(sk) - if err != nil { - return err - } - - cfg, err := getValidatedConfig(ctx, d, pipelineID) - if err != nil { - return fmt.Errorf("load config for %q: %w", pipelineID, err) - } - if cfg == nil { - d.Logger.Warn("no config found for pipeline, skipping rerun request", "pipelineId", pipelineID) - return nil - } - - // --- Calendar exclusion check (execution date) --- - if isExcludedDate(cfg, date) { - _ = d.Store.WriteJobEvent(ctx, pipelineID, schedule, date, types.JobEventRerunRejected, "", 0, "excluded by calendar") - if pubErr := publishEvent(ctx, d, string(types.EventPipelineExcluded), pipelineID, schedule, date, - fmt.Sprintf("rerun blocked for %s: execution date %s excluded by calendar", pipelineID, date)); pubErr != nil { - d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventPipelineExcluded, "error", pubErr) - } - return nil - } - - // Extract reason from stream record NewImage. Default to "manual". - reason := "manual" - if img := record.Change.NewImage; img != nil { - if r, ok := img["reason"]; ok && r.DataType() == events.DataTypeString { - if v := r.String(); v != "" { - reason = v - } - } - } - - // --- Rerun limit check --- - var budget int - var sources []string - var limitLabel string - switch reason { - case "data-drift", "late-data": - budget = types.IntOrDefault(cfg.Job.MaxDriftReruns, 1) - sources = []string{"data-drift", "late-data"} - limitLabel = "drift rerun limit exceeded" - default: - budget = types.IntOrDefault(cfg.Job.MaxManualReruns, 1) - sources = []string{reason} - limitLabel = "manual rerun limit exceeded" - } - - count, err := d.Store.CountRerunsBySource(ctx, pipelineID, schedule, date, sources) - if err != nil { - return fmt.Errorf("count reruns by source for %q: %w", pipelineID, err) - } - - if count >= budget { - _ = d.Store.WriteJobEvent(ctx, pipelineID, schedule, date, - types.JobEventRerunRejected, "", 0, limitLabel) - if err := publishEvent(ctx, d, string(types.EventRerunRejected), pipelineID, schedule, date, - fmt.Sprintf("rerun rejected for %s: %s", pipelineID, limitLabel)); err != nil { - d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventRerunRejected, "error", err) - } - d.Logger.Info("rerun request rejected (limit exceeded)", - "pipelineId", pipelineID, "schedule", schedule, "date", date, - "reason", reason, "count", count, "budget", budget) - return nil - } - - // --- Circuit breaker (sensor freshness) --- - job, err := d.Store.GetLatestJobEvent(ctx, pipelineID, schedule, date) - if err != nil { - return fmt.Errorf("get latest job event for %q/%s/%s: %w", pipelineID, schedule, date, err) - } - - allowed := true - rejectReason := "" - if job != nil && job.Event == types.JobEventSuccess { - fresh, err := checkSensorFreshness(ctx, d, pipelineID, job.SK) - if err != nil { - return fmt.Errorf("check sensor freshness for %q: %w", pipelineID, err) - } - if !fresh { - allowed = false - rejectReason = "previous run succeeded and no sensor data has changed" - } - } - - if !allowed { - _ = d.Store.WriteJobEvent(ctx, pipelineID, schedule, date, - types.JobEventRerunRejected, "", 0, rejectReason) - if err := publishEvent(ctx, d, string(types.EventRerunRejected), pipelineID, schedule, date, - fmt.Sprintf("rerun rejected for %s: %s", pipelineID, rejectReason)); err != nil { - d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventRerunRejected, "error", err) - } - d.Logger.Info("rerun request rejected", - "pipelineId", pipelineID, "schedule", schedule, "date", date, - "reason", rejectReason) - return nil - } - - // --- Acceptance: write rerun record FIRST (before lock reset) --- - if _, err := d.Store.WriteRerun(ctx, pipelineID, schedule, date, reason, ""); err != nil { - return fmt.Errorf("write rerun for %q: %w", pipelineID, err) - } - - // Delete date-scoped postrun-baseline so re-run captures fresh baseline. - if cfg.PostRun != nil { - _ = d.Store.DeleteSensor(ctx, pipelineID, "postrun-baseline#"+date) - } - - // Atomically reset the trigger lock for the new execution. - acquired, err := d.Store.ResetTriggerLock(ctx, pipelineID, schedule, date, ResolveTriggerLockTTL()) - if err != nil { - return fmt.Errorf("reset trigger lock for %q: %w", pipelineID, err) - } - if !acquired { - if pubErr := publishEvent(ctx, d, string(types.EventInfraFailure), pipelineID, schedule, date, - fmt.Sprintf("lock reset failed for rerun of %s, orphaned rerun record", pipelineID)); pubErr != nil { - d.Logger.WarnContext(ctx, "failed to publish event", "error", pubErr) - } - d.Logger.Warn("failed to reset trigger lock, orphaned rerun record", - "pipelineId", pipelineID, "schedule", schedule, "date", date) - return nil - } - - // Publish acceptance event only after lock atomicity is confirmed. - _ = d.Store.WriteJobEvent(ctx, pipelineID, schedule, date, - types.JobEventRerunAccepted, "", 0, "") - - if pubErr := publishEvent(ctx, d, string(types.EventRerunAccepted), pipelineID, schedule, date, - fmt.Sprintf("rerun accepted for %s (reason: %s)", pipelineID, reason)); pubErr != nil { - d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventRerunAccepted, "error", pubErr) - } - - execName := truncateExecName(fmt.Sprintf("%s-%s-%s-%s-rerun-%d", pipelineID, schedule, date, reason, time.Now().Unix())) - if err := startSFNWithName(ctx, d, cfg, pipelineID, schedule, date, execName); err != nil { - if relErr := d.Store.ReleaseTriggerLock(ctx, pipelineID, schedule, date); relErr != nil { - d.Logger.Warn("failed to release lock after SFN start failure", "error", relErr) - } - return fmt.Errorf("start SFN rerun for %q: %w", pipelineID, err) - } - - d.Logger.Info("started rerun", - "pipelineId", pipelineID, "schedule", schedule, "date", date, "reason", reason) - return nil -} - -// parseRerunRequestSK extracts schedule and date from a RERUN_REQUEST# sort key. -// Expected format: RERUN_REQUEST## -func parseRerunRequestSK(sk string) (schedule, date string, err error) { - trimmed := strings.TrimPrefix(sk, "RERUN_REQUEST#") - parts := strings.SplitN(trimmed, "#", 2) - if len(parts) < 2 { - return "", "", fmt.Errorf("invalid RERUN_REQUEST SK format: %q", sk) - } - return parts[0], parts[1], nil -} - -// checkSensorFreshness determines whether any sensor data has been updated -// after the given job completed. The job timestamp is extracted from the job -// SK (format: JOB#schedule#date#). Returns true if data has -// changed (rerun should proceed) or if freshness cannot be determined. -func checkSensorFreshness(ctx context.Context, d *Deps, pipelineID, jobSK string) (bool, error) { - // Extract timestamp from the job SK. - parts := strings.Split(jobSK, "#") - if len(parts) < 4 { - // Can't parse timestamp — allow to be safe. - return true, nil - } - jobTimestamp, err := strconv.ParseInt(parts[len(parts)-1], 10, 64) - if err != nil { - // Can't parse timestamp — allow to be safe. - return true, nil - } - - sensors, err := d.Store.GetAllSensors(ctx, pipelineID) - if err != nil { - return false, fmt.Errorf("get sensors for %q: %w", pipelineID, err) - } - if len(sensors) == 0 { - // No sensors — can't prove unchanged, allow. - return true, nil - } - - hasAnyUpdatedAt := false - for _, data := range sensors { - updatedAt, ok := data["updatedAt"] - if !ok { - continue - } - hasAnyUpdatedAt = true - - var ts int64 - switch v := updatedAt.(type) { - case float64: - ts = int64(v) - case int64: - ts = v - case string: - ts, err = strconv.ParseInt(v, 10, 64) - if err != nil { - continue - } - default: - continue - } - - if ts > jobTimestamp { - return true, nil // Data changed after job — allow rerun. - } - } - - if !hasAnyUpdatedAt { - // No sensors have updatedAt — can't prove unchanged, allow. - return true, nil - } - - // All sensor timestamps are older than the job — data unchanged. - return false, nil -} - // handleSensorEvent evaluates the trigger condition for a sensor write // and starts the Step Function execution if all conditions are met. func handleSensorEvent(ctx context.Context, d *Deps, pk, sk string, record events.DynamoDBEventRecord) error { @@ -583,547 +247,3 @@ func handleSensorEvent(ctx context.Context, d *Deps, pk, sk string, record event ) return nil } - -// checkLateDataArrival detects sensor updates after a pipeline has completed -// successfully. If the trigger is in terminal COMPLETED state and the latest -// job event is success, this sensor write represents late data that arrived -// after post-job monitoring closed. Dual-writes a joblog entry and publishes -// a LATE_DATA_ARRIVAL event. -func checkLateDataArrival(ctx context.Context, d *Deps, pipelineID, schedule, date string) error { - trigger, err := d.Store.GetTrigger(ctx, pipelineID, schedule, date) - if err != nil || trigger == nil { - return err - } - - if trigger.Status != types.TriggerStatusCompleted { - return nil // still running or failed — not late data - } - - job, err := d.Store.GetLatestJobEvent(ctx, pipelineID, schedule, date) - if err != nil || job == nil { - return err - } - - if job.Event != types.JobEventSuccess { - return nil // job didn't succeed — not a "late data after success" scenario - } - - // Dual-write: joblog entry (audit) + EventBridge event (alerting). - _ = d.Store.WriteJobEvent(ctx, pipelineID, schedule, date, - types.JobEventLateDataArrival, "", 0, - "sensor updated after pipeline completed successfully") - - if err := publishEvent(ctx, d, string(types.EventLateDataArrival), pipelineID, schedule, date, - fmt.Sprintf("late data arrival for %s: sensor updated after job completion", pipelineID)); err != nil { - d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventLateDataArrival, "error", err) - } - - // Trigger a re-run — circuit breaker in handleRerunRequest will validate sensor freshness. - if writeErr := d.Store.WriteRerunRequest(ctx, pipelineID, schedule, date, "late-data"); writeErr != nil { - d.Logger.WarnContext(ctx, "failed to write rerun request on late data", "pipelineId", pipelineID, "error", writeErr) - } - - return nil -} - -// matchesPostRunRule returns true if the sensor key matches any post-run rule key -// (prefix match to support per-period sensor keys). -func matchesPostRunRule(sensorKey string, rules []types.ValidationRule) bool { - for _, rule := range rules { - if strings.HasPrefix(sensorKey, rule.Key) { - return true - } - } - return false -} - -// handlePostRunSensorEvent evaluates post-run rules reactively when a sensor -// arrives via DynamoDB Stream. Compares current sensor values against the -// date-scoped baseline captured at trigger completion. -func handlePostRunSensorEvent(ctx context.Context, d *Deps, cfg *types.PipelineConfig, pipelineID, sensorKey string, sensorData map[string]interface{}) error { - scheduleID := resolveScheduleID(cfg) - date := ResolveExecutionDate(sensorData) - - // Consistent read to handle race where sensor stream event arrives - // before SFN sets trigger to COMPLETED. - trigger, err := d.Store.GetTrigger(ctx, pipelineID, scheduleID, date) - if err != nil { - return fmt.Errorf("get trigger for post-run: %w", err) - } - if trigger == nil { - return nil // No trigger for this date — not a post-run event. - } - - switch trigger.Status { - case types.TriggerStatusRunning: - // Job still running — evaluate rules for informational drift detection. - return handlePostRunInflight(ctx, d, cfg, pipelineID, scheduleID, date, sensorKey, sensorData) - - case types.TriggerStatusCompleted: - // Job completed — full post-run evaluation with baseline comparison. - return handlePostRunCompleted(ctx, d, cfg, pipelineID, scheduleID, date, sensorData) - - default: - // FAILED_FINAL or unknown — skip. - return nil - } -} - -// handlePostRunInflight evaluates post-run rules while the job is still running. -// If drift is detected, publishes an informational event but does NOT trigger a rerun. -func handlePostRunInflight(ctx context.Context, d *Deps, cfg *types.PipelineConfig, pipelineID, scheduleID, date, sensorKey string, sensorData map[string]interface{}) error { - // Read baseline for comparison. - baselineKey := "postrun-baseline#" + date - baseline, err := d.Store.GetSensorData(ctx, pipelineID, baselineKey) - if err != nil { - return fmt.Errorf("get baseline for inflight check: %w", err) - } - if baseline == nil { - return nil // No baseline yet — job hasn't completed once. - } - - prevCount := ExtractFloat(baseline, "sensor_count") - currCount := ExtractFloat(sensorData, "sensor_count") - threshold := 0.0 - if cfg.PostRun.DriftThreshold != nil { - threshold = *cfg.PostRun.DriftThreshold - } - if prevCount > 0 && currCount > 0 && math.Abs(currCount-prevCount) > threshold { - if err := publishEvent(ctx, d, string(types.EventPostRunDriftInflight), pipelineID, scheduleID, date, - fmt.Sprintf("inflight drift detected for %s: %.0f → %.0f (informational)", pipelineID, prevCount, currCount), - map[string]interface{}{ - "previousCount": prevCount, - "currentCount": currCount, - "driftThreshold": threshold, - "sensorKey": sensorKey, - "source": "post-run-stream", - }); err != nil { - d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventPostRunDriftInflight, "error", err) - } - } - return nil -} - -// handlePostRunCompleted evaluates post-run rules after the job has completed. -// Compares sensor values against the date-scoped baseline and triggers a rerun -// if drift is detected. -func handlePostRunCompleted(ctx context.Context, d *Deps, cfg *types.PipelineConfig, pipelineID, scheduleID, date string, sensorData map[string]interface{}) error { - // Read baseline captured at trigger completion. - baselineKey := "postrun-baseline#" + date - baseline, err := d.Store.GetSensorData(ctx, pipelineID, baselineKey) - if err != nil { - return fmt.Errorf("get baseline for post-run: %w", err) - } - - // Check for data drift if baseline exists. - if baseline != nil { - prevCount := ExtractFloat(baseline, "sensor_count") - currCount := ExtractFloat(sensorData, "sensor_count") - threshold := 0.0 - if cfg.PostRun.DriftThreshold != nil { - threshold = *cfg.PostRun.DriftThreshold - } - if prevCount > 0 && currCount > 0 && math.Abs(currCount-prevCount) > threshold { - delta := currCount - prevCount - if err := publishEvent(ctx, d, string(types.EventPostRunDrift), pipelineID, scheduleID, date, - fmt.Sprintf("post-run drift detected for %s: %.0f → %.0f records", pipelineID, prevCount, currCount), - map[string]interface{}{ - "previousCount": prevCount, - "currentCount": currCount, - "delta": delta, - "driftThreshold": threshold, - "source": "post-run-stream", - }); err != nil { - d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventPostRunDrift, "error", err) - } - - // Trigger rerun via the existing circuit breaker path only if the - // execution date is not excluded by the pipeline's calendar config. - if isExcludedDate(cfg, date) { - if pubErr := publishEvent(ctx, d, string(types.EventPipelineExcluded), pipelineID, scheduleID, date, - fmt.Sprintf("post-run drift rerun skipped for %s: execution date %s excluded by calendar", pipelineID, date)); pubErr != nil { - d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventPipelineExcluded, "error", pubErr) - } - d.Logger.InfoContext(ctx, "post-run drift rerun skipped: execution date excluded by calendar", - "pipelineId", pipelineID, "date", date) - } else { - if writeErr := d.Store.WriteRerunRequest(ctx, pipelineID, scheduleID, date, "data-drift"); writeErr != nil { - d.Logger.WarnContext(ctx, "failed to write rerun request on post-run drift", - "pipelineId", pipelineID, "error", writeErr) - } - } - return nil - } - } - - // Evaluate post-run validation rules. - sensors, err := d.Store.GetAllSensors(ctx, pipelineID) - if err != nil { - return fmt.Errorf("get sensors for post-run rules: %w", err) - } - RemapPerPeriodSensors(sensors, date) - - result := validation.EvaluateRules("ALL", cfg.PostRun.Rules, sensors, time.Now()) - - if result.Passed { - if err := publishEvent(ctx, d, string(types.EventPostRunPassed), pipelineID, scheduleID, date, - fmt.Sprintf("post-run validation passed for %s", pipelineID)); err != nil { - d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventPostRunPassed, "error", err) - } - } else { - if err := publishEvent(ctx, d, string(types.EventPostRunFailed), pipelineID, scheduleID, date, - fmt.Sprintf("post-run validation failed for %s", pipelineID)); err != nil { - d.Logger.WarnContext(ctx, "failed to publish event", "type", types.EventPostRunFailed, "error", err) - } - } - - return nil -} - -// sfnInput is the top-level input for the Step Function state machine. -// It includes pipeline identity fields and a config block used by Wait states. -type sfnInput struct { - PipelineID string `json:"pipelineId"` - ScheduleID string `json:"scheduleId"` - Date string `json:"date"` - Config sfnConfig `json:"config"` -} - -// sfnConfig holds timing parameters for the SFN evaluation loop and SLA branch. -type sfnConfig struct { - EvaluationIntervalSeconds int `json:"evaluationIntervalSeconds"` - EvaluationWindowSeconds int `json:"evaluationWindowSeconds"` - JobCheckIntervalSeconds int `json:"jobCheckIntervalSeconds"` - JobPollWindowSeconds int `json:"jobPollWindowSeconds"` - SLA *types.SLAConfig `json:"sla,omitempty"` -} - -// buildSFNConfig converts a PipelineConfig into the config block for the SFN input. -func buildSFNConfig(cfg *types.PipelineConfig) sfnConfig { - sc := sfnConfig{ - EvaluationIntervalSeconds: 300, // 5m default - EvaluationWindowSeconds: 3600, // 1h default - JobCheckIntervalSeconds: 60, // 1m default - JobPollWindowSeconds: 3600, // 1h default - } - - if d, err := time.ParseDuration(cfg.Schedule.Evaluation.Interval); err == nil && d > 0 { - sc.EvaluationIntervalSeconds = int(d.Seconds()) - } - if d, err := time.ParseDuration(cfg.Schedule.Evaluation.Window); err == nil && d > 0 { - sc.EvaluationWindowSeconds = int(d.Seconds()) - } - - if cfg.Job.JobPollWindowSeconds != nil && *cfg.Job.JobPollWindowSeconds > 0 { - sc.JobPollWindowSeconds = *cfg.Job.JobPollWindowSeconds - } - - if cfg.SLA != nil { - sla := *cfg.SLA - if sla.Timezone == "" { - sla.Timezone = "UTC" - } - sc.SLA = &sla - } - - return sc -} - -// truncateExecName ensures an SFN execution name does not exceed the 80-character -// AWS limit. When truncation is needed the suffix (date + timestamp) is preserved -// by trimming characters from the beginning of the name. -func truncateExecName(name string) string { - const maxLen = 80 - if len(name) <= maxLen { - return name - } - return name[len(name)-maxLen:] -} - -// startSFN starts a Step Function execution with a unique execution name. -// The name includes a Unix timestamp suffix to avoid ExecutionAlreadyExists -// errors when a previous execution for the same pipeline/schedule/date failed. -func startSFN(ctx context.Context, d *Deps, cfg *types.PipelineConfig, pipelineID, scheduleID, date string) error { - name := truncateExecName(fmt.Sprintf("%s-%s-%s-%d", pipelineID, scheduleID, date, time.Now().Unix())) - return startSFNWithName(ctx, d, cfg, pipelineID, scheduleID, date, name) -} - -// startSFNWithName starts a Step Function execution with a custom execution name. -func startSFNWithName(ctx context.Context, d *Deps, cfg *types.PipelineConfig, pipelineID, scheduleID, date, name string) error { - sc := buildSFNConfig(cfg) - - // Warn if the sum of evaluation + poll windows exceeds the SFN timeout. - totalWindowSec := sc.EvaluationWindowSeconds + sc.JobPollWindowSeconds - sfnTimeout := ResolveTriggerLockTTL() - 30*time.Minute // strip the buffer to get raw SFN timeout - if sfnTimeout > 0 && time.Duration(totalWindowSec)*time.Second > sfnTimeout { - d.Logger.Warn("combined pipeline windows exceed SFN timeout", - "pipelineId", pipelineID, - "evalWindowSec", sc.EvaluationWindowSeconds, - "jobPollWindowSec", sc.JobPollWindowSeconds, - "totalWindowSec", totalWindowSec, - "sfnTimeoutSec", int(sfnTimeout.Seconds()), - ) - } - - input := sfnInput{ - PipelineID: pipelineID, - ScheduleID: scheduleID, - Date: date, - Config: sc, - } - payload, err := json.Marshal(input) - if err != nil { - return fmt.Errorf("marshal SFN input: %w", err) - } - - inputStr := string(payload) - - _, err = d.SFNClient.StartExecution(ctx, &sfn.StartExecutionInput{ - StateMachineArn: &d.StateMachineARN, - Name: &name, - Input: &inputStr, - }) - if err != nil { - return fmt.Errorf("StartExecution: %w", err) - } - return nil -} - -// extractKeys returns the PK and SK string values from a DynamoDB stream record. -func extractKeys(record events.DynamoDBEventRecord) (pk, sk string) { - keys := record.Change.Keys - if pkAttr, ok := keys["PK"]; ok && pkAttr.DataType() == events.DataTypeString { - pk = pkAttr.String() - } - if skAttr, ok := keys["SK"]; ok && skAttr.DataType() == events.DataTypeString { - sk = skAttr.String() - } - return pk, sk -} - -// extractSensorData converts a DynamoDB stream NewImage to a plain map -// suitable for validation rule evaluation. If the item uses the canonical -// ControlRecord format (sensor fields nested inside a "data" map attribute), -// the "data" map is unwrapped so fields are accessible at the top level. -func extractSensorData(newImage map[string]events.DynamoDBAttributeValue) map[string]interface{} { - if newImage == nil { - return nil - } - - skipKeys := map[string]bool{"PK": true, "SK": true, "ttl": true} - result := make(map[string]interface{}, len(newImage)) - - for k, av := range newImage { - if skipKeys[k] { - continue - } - result[k] = convertAttributeValue(av) - } - - // Unwrap the "data" map if present (canonical ControlRecord sensor format). - if dataMap, ok := result["data"].(map[string]interface{}); ok { - return dataMap - } - return result -} - -// convertAttributeValue converts a DynamoDB stream attribute value to a Go native type. -func convertAttributeValue(av events.DynamoDBAttributeValue) interface{} { - switch av.DataType() { - case events.DataTypeString: - return av.String() - case events.DataTypeNumber: - // Try int first, fall back to float. - if i, err := strconv.ParseInt(av.Number(), 10, 64); err == nil { - return float64(i) - } - if f, err := strconv.ParseFloat(av.Number(), 64); err == nil { - return f - } - return av.Number() - case events.DataTypeBoolean: - return av.Boolean() - case events.DataTypeNull: - return nil - case events.DataTypeMap: - m := av.Map() - out := make(map[string]interface{}, len(m)) - for k, v := range m { - out[k] = convertAttributeValue(v) - } - return out - case events.DataTypeList: - l := av.List() - out := make([]interface{}, len(l)) - for i, v := range l { - out[i] = convertAttributeValue(v) - } - return out - default: - return nil - } -} - -// ResolveExecutionDate builds the execution date from sensor data fields. -// If both "date" and "hour" are present, returns "YYYY-MM-DDThh". -// If only "date", returns "YYYY-MM-DD". Falls back to today's date. -func ResolveExecutionDate(sensorData map[string]interface{}) string { - dateStr, _ := sensorData["date"].(string) - hourStr, _ := sensorData["hour"].(string) - - if dateStr == "" { - return time.Now().Format("2006-01-02") - } - - normalized := normalizeDate(dateStr) - // Validate YYYY-MM-DD format. - if _, err := time.Parse("2006-01-02", normalized); err != nil { - return time.Now().Format("2006-01-02") - } - - if hourStr != "" { - // Validate hour is 2-digit 00-23. - if len(hourStr) == 2 { - if h, err := strconv.Atoi(hourStr); err == nil && h >= 0 && h <= 23 { - return normalized + "T" + hourStr - } - } - return normalized - } - return normalized -} - -// normalizeDate converts YYYYMMDD to YYYY-MM-DD. Already-dashed dates pass through. -func normalizeDate(s string) string { - if len(s) == 8 && !strings.Contains(s, "-") { - return s[:4] + "-" + s[4:6] + "-" + s[6:8] - } - return s -} - -// resolveScheduleID returns "cron" if the pipeline uses a cron schedule, -// otherwise returns "stream". -func resolveScheduleID(cfg *types.PipelineConfig) string { - if cfg.Schedule.Cron != "" { - return "cron" - } - return "stream" -} - -// isExcludedDate checks calendar exclusions against a job's execution date -// (not wall-clock time). dateStr supports "YYYY-MM-DD" and "YYYY-MM-DDTHH". -func isExcludedDate(cfg *types.PipelineConfig, dateStr string) bool { - excl := cfg.Schedule.Exclude - if excl == nil { - return false - } - if len(dateStr) < 10 { - return false // unparseable, safe default - } - datePortion := dateStr[:10] - - // Resolve the location to interpret the execution date in. - loc := time.UTC - if cfg.Schedule.Timezone != "" { - if l, err := time.LoadLocation(cfg.Schedule.Timezone); err == nil { - loc = l - } - } - - // Parse the date as midnight in the configured timezone so that weekday - // and date-string comparisons reflect the local calendar date. - t, err := time.ParseInLocation("2006-01-02", datePortion, loc) - if err != nil { - return false // safe default - } - - if excl.Weekends { - day := t.Weekday() - if day == time.Saturday || day == time.Sunday { - return true - } - } - dateStr2 := t.Format("2006-01-02") - for _, d := range excl.Dates { - if d == dateStr2 { - return true - } - } - return false -} - -// isExcluded checks whether the pipeline should be excluded from running -// based on calendar exclusions (weekends and specific dates). -func isExcluded(cfg *types.PipelineConfig, now time.Time) bool { - excl := cfg.Schedule.Exclude - if excl == nil { - return false - } - - // Resolve timezone if configured. - t := now - if cfg.Schedule.Timezone != "" { - if loc, err := time.LoadLocation(cfg.Schedule.Timezone); err == nil { - t = now.In(loc) - } - } - - // Check weekends. - if excl.Weekends { - day := t.Weekday() - if day == time.Saturday || day == time.Sunday { - return true - } - } - - // Check specific dates. - dateStr := t.Format("2006-01-02") - for _, d := range excl.Dates { - if d == dateStr { - return true - } - } - - return false -} - -// publishEvent sends an event to EventBridge. It is safe to call when -// EventBridge is nil or EventBusName is empty (returns nil with no action). -func publishEvent(ctx context.Context, d *Deps, eventType, pipelineID, schedule, date, message string, detail ...map[string]interface{}) error { - if d.EventBridge == nil || d.EventBusName == "" { - return nil - } - - evt := types.InterlockEvent{ - PipelineID: pipelineID, - ScheduleID: schedule, - Date: date, - Message: message, - Timestamp: time.Now(), - } - if len(detail) > 0 && detail[0] != nil { - evt.Detail = detail[0] - } - detailJSON, err := json.Marshal(evt) - if err != nil { - return fmt.Errorf("marshal event detail: %w", err) - } - - source := types.EventSource - detailStr := string(detailJSON) - - _, err = d.EventBridge.PutEvents(ctx, &eventbridge.PutEventsInput{ - Entries: []ebTypes.PutEventsRequestEntry{ - { - Source: &source, - DetailType: &eventType, - Detail: &detailStr, - EventBusName: &d.EventBusName, - }, - }, - }) - if err != nil { - return fmt.Errorf("publish %s event: %w", eventType, err) - } - return nil -} diff --git a/internal/store/control.go b/internal/store/control.go index d4b4a8f..2c14ed5 100644 --- a/internal/store/control.go +++ b/internal/store/control.go @@ -20,9 +20,8 @@ import ( func (s *Store) ScanConfigs(ctx context.Context) (map[string]*types.PipelineConfig, error) { configs := make(map[string]*types.PipelineConfig) - var startKey map[string]ddbtypes.AttributeValue - for { - input := &dynamodb.ScanInput{ + err := ScanAll(ctx, s.Client, func(startKey map[string]ddbtypes.AttributeValue) *dynamodb.ScanInput { + return &dynamodb.ScanInput{ TableName: &s.ControlTable, FilterExpression: aws.String("SK = :sk"), ExpressionAttributeValues: map[string]ddbtypes.AttributeValue{ @@ -30,50 +29,43 @@ func (s *Store) ScanConfigs(ctx context.Context) (map[string]*types.PipelineConf }, ExclusiveStartKey: startKey, } - - out, err := s.Client.Scan(ctx, input) - if err != nil { - return nil, fmt.Errorf("scan configs: %w", err) - } - - for _, item := range out.Items { + }, func(items []map[string]ddbtypes.AttributeValue) error { + for _, item := range items { pkAttr, ok := item["PK"] if !ok { - return nil, fmt.Errorf("scan configs: row missing PK attribute") + return fmt.Errorf("row missing PK attribute") } pkStr, ok := pkAttr.(*ddbtypes.AttributeValueMemberS) if !ok { - return nil, fmt.Errorf("scan configs: PK is not a string") + return fmt.Errorf("PK is not a string") } const prefix = "PIPELINE#" if len(pkStr.Value) <= len(prefix) { - return nil, fmt.Errorf("scan configs: invalid PK %q", pkStr.Value) + return fmt.Errorf("invalid PK %q", pkStr.Value) } pipelineID := pkStr.Value[len(prefix):] configAttr, ok := item["config"] if !ok { - return nil, fmt.Errorf("scan configs: config attribute missing for %q", pipelineID) + return fmt.Errorf("config attribute missing for %q", pipelineID) } configStr, ok := configAttr.(*ddbtypes.AttributeValueMemberS) if !ok { - return nil, fmt.Errorf("scan configs: config is not a string for %q", pipelineID) + return fmt.Errorf("config is not a string for %q", pipelineID) } var cfg types.PipelineConfig if err := json.Unmarshal([]byte(configStr.Value), &cfg); err != nil { - return nil, fmt.Errorf("scan configs: unmarshal config for %q: %w", pipelineID, err) + return fmt.Errorf("unmarshal config for %q: %w", pipelineID, err) } configs[pipelineID] = &cfg } - - if out.LastEvaluatedKey == nil { - break - } - startKey = out.LastEvaluatedKey + return nil + }) + if err != nil { + return nil, fmt.Errorf("scan configs: %w", err) } - return configs, nil } @@ -162,9 +154,8 @@ func (s *Store) GetSensorData(ctx context.Context, pipelineID, sensorKey string) func (s *Store) GetAllSensors(ctx context.Context, pipelineID string) (map[string]map[string]interface{}, error) { result := make(map[string]map[string]interface{}) - var startKey map[string]ddbtypes.AttributeValue - for { - input := &dynamodb.QueryInput{ + err := QueryAll(ctx, s.Client, func(startKey map[string]ddbtypes.AttributeValue) *dynamodb.QueryInput { + return &dynamodb.QueryInput{ TableName: &s.ControlTable, ConsistentRead: aws.Bool(true), KeyConditionExpression: aws.String("PK = :pk AND begins_with(SK, :prefix)"), @@ -174,29 +165,22 @@ func (s *Store) GetAllSensors(ctx context.Context, pipelineID string) (map[strin }, ExclusiveStartKey: startKey, } - - out, err := s.Client.Query(ctx, input) - if err != nil { - return nil, fmt.Errorf("query sensors for %q: %w", pipelineID, err) - } - - for _, item := range out.Items { + }, func(items []map[string]ddbtypes.AttributeValue) error { + for _, item := range items { var rec types.ControlRecord if err := attributevalue.UnmarshalMap(item, &rec); err != nil { - return nil, fmt.Errorf("unmarshal sensor row: %w", err) + return fmt.Errorf("unmarshal sensor row: %w", err) } - // Extract the sensor key from the SK: "SENSOR#" const prefix = "SENSOR#" if len(rec.SK) > len(prefix) { key := rec.SK[len(prefix):] result[key] = rec.Data } } - - if out.LastEvaluatedKey == nil { - break - } - startKey = out.LastEvaluatedKey + return nil + }) + if err != nil { + return nil, fmt.Errorf("query sensors for %q: %w", pipelineID, err) } return result, nil } @@ -304,9 +288,8 @@ func (s *Store) ReleaseTriggerLock(ctx context.Context, pipelineID, schedule, da func (s *Store) ScanRunningTriggers(ctx context.Context) ([]types.ControlRecord, error) { var records []types.ControlRecord - var startKey map[string]ddbtypes.AttributeValue - for { - input := &dynamodb.ScanInput{ + err := ScanAll(ctx, s.Client, func(startKey map[string]ddbtypes.AttributeValue) *dynamodb.ScanInput { + return &dynamodb.ScanInput{ TableName: &s.ControlTable, FilterExpression: aws.String("begins_with(SK, :prefix) AND #status = :running"), ExpressionAttributeNames: map[string]string{ @@ -318,24 +301,18 @@ func (s *Store) ScanRunningTriggers(ctx context.Context) ([]types.ControlRecord, }, ExclusiveStartKey: startKey, } - - out, err := s.Client.Scan(ctx, input) - if err != nil { - return nil, fmt.Errorf("scan running triggers: %w", err) - } - - for _, item := range out.Items { + }, func(items []map[string]ddbtypes.AttributeValue) error { + for _, item := range items { var rec types.ControlRecord if err := attributevalue.UnmarshalMap(item, &rec); err != nil { - return nil, fmt.Errorf("unmarshal trigger row: %w", err) + return fmt.Errorf("unmarshal trigger row: %w", err) } records = append(records, rec) } - - if out.LastEvaluatedKey == nil { - break - } - startKey = out.LastEvaluatedKey + return nil + }) + if err != nil { + return nil, fmt.Errorf("scan running triggers: %w", err) } return records, nil } diff --git a/internal/store/pagination.go b/internal/store/pagination.go new file mode 100644 index 0000000..14e51e2 --- /dev/null +++ b/internal/store/pagination.go @@ -0,0 +1,69 @@ +package store + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + ddbtypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +// ScanAll runs a paginated DynamoDB Scan. buildInput is called for each page +// with the ExclusiveStartKey (nil for the first page). processPage is called +// for each page of results. +func ScanAll(ctx context.Context, client DynamoAPI, buildInput func(startKey map[string]ddbtypes.AttributeValue) *dynamodb.ScanInput, processPage func(items []map[string]ddbtypes.AttributeValue) error) error { + var startKey map[string]ddbtypes.AttributeValue + for { + input := buildInput(startKey) + out, err := client.Scan(ctx, input) + if err != nil { + return err + } + if err := processPage(out.Items); err != nil { + return err + } + if out.LastEvaluatedKey == nil { + return nil + } + startKey = out.LastEvaluatedKey + } +} + +// QueryAll runs a paginated DynamoDB Query. buildInput is called for each page +// with the ExclusiveStartKey (nil for the first page). processPage is called +// for each page of results. +func QueryAll(ctx context.Context, client DynamoAPI, buildInput func(startKey map[string]ddbtypes.AttributeValue) *dynamodb.QueryInput, processPage func(items []map[string]ddbtypes.AttributeValue) error) error { + var startKey map[string]ddbtypes.AttributeValue + for { + input := buildInput(startKey) + out, err := client.Query(ctx, input) + if err != nil { + return err + } + if err := processPage(out.Items); err != nil { + return err + } + if out.LastEvaluatedKey == nil { + return nil + } + startKey = out.LastEvaluatedKey + } +} + +// QueryCount runs a paginated DynamoDB Query with Select=COUNT and returns +// the total count across all pages. +func QueryCount(ctx context.Context, client DynamoAPI, buildInput func(startKey map[string]ddbtypes.AttributeValue) *dynamodb.QueryInput) (int, error) { + total := 0 + var startKey map[string]ddbtypes.AttributeValue + for { + input := buildInput(startKey) + out, err := client.Query(ctx, input) + if err != nil { + return 0, err + } + total += int(out.Count) + if out.LastEvaluatedKey == nil { + return total, nil + } + startKey = out.LastEvaluatedKey + } +} diff --git a/internal/store/rerun.go b/internal/store/rerun.go index 96b5622..96d7ad9 100644 --- a/internal/store/rerun.go +++ b/internal/store/rerun.go @@ -59,11 +59,8 @@ func (s *Store) WriteRerun(ctx context.Context, pipelineID, schedule, date, reas // It uses a count-only query (no item data returned) and handles pagination. func (s *Store) CountReruns(ctx context.Context, pipelineID, schedule, date string) (int, error) { prefix := fmt.Sprintf("RERUN#%s#%s#", schedule, date) - total := 0 - - var startKey map[string]ddbtypes.AttributeValue - for { - out, err := s.Client.Query(ctx, &dynamodb.QueryInput{ + total, err := QueryCount(ctx, s.Client, func(startKey map[string]ddbtypes.AttributeValue) *dynamodb.QueryInput { + return &dynamodb.QueryInput{ TableName: &s.RerunTable, Select: ddbtypes.SelectCount, KeyConditionExpression: aws.String("PK = :pk AND begins_with(SK, :prefix)"), @@ -72,19 +69,11 @@ func (s *Store) CountReruns(ctx context.Context, pipelineID, schedule, date stri ":prefix": &ddbtypes.AttributeValueMemberS{Value: prefix}, }, ExclusiveStartKey: startKey, - }) - if err != nil { - return 0, fmt.Errorf("count reruns %q/%s/%s: %w", pipelineID, schedule, date, err) } - - total += int(out.Count) - - if out.LastEvaluatedKey == nil { - break - } - startKey = out.LastEvaluatedKey + }) + if err != nil { + return 0, fmt.Errorf("count reruns %q/%s/%s: %w", pipelineID, schedule, date, err) } - return total, nil } @@ -111,25 +100,18 @@ func (s *Store) CountRerunsBySource(ctx context.Context, pipelineID, schedule, d } filterExpr := fmt.Sprintf("reason IN (%s)", strings.Join(placeholders, ", ")) - total := 0 - var startKey map[string]ddbtypes.AttributeValue - for { - out, err := s.Client.Query(ctx, &dynamodb.QueryInput{ + total, err := QueryCount(ctx, s.Client, func(startKey map[string]ddbtypes.AttributeValue) *dynamodb.QueryInput { + return &dynamodb.QueryInput{ TableName: &s.RerunTable, Select: ddbtypes.SelectCount, KeyConditionExpression: aws.String("PK = :pk AND begins_with(SK, :prefix)"), FilterExpression: aws.String(filterExpr), ExpressionAttributeValues: filterValues, ExclusiveStartKey: startKey, - }) - if err != nil { - return 0, fmt.Errorf("count reruns by source %q/%s/%s: %w", pipelineID, schedule, date, err) - } - total += int(out.Count) - if out.LastEvaluatedKey == nil { - break } - startKey = out.LastEvaluatedKey + }) + if err != nil { + return 0, fmt.Errorf("count reruns by source %q/%s/%s: %w", pipelineID, schedule, date, err) } return total, nil } diff --git a/internal/trigger/airflow_test.go b/internal/trigger/airflow_test.go index 02efd0b..ddd7197 100644 --- a/internal/trigger/airflow_test.go +++ b/internal/trigger/airflow_test.go @@ -123,3 +123,135 @@ func TestCheckAirflowStatus_Failed(t *testing.T) { require.NoError(t, err) assert.Equal(t, "failed", state) } + +func TestExecuteAirflow_MissingURL(t *testing.T) { + cfg := &types.AirflowTriggerConfig{DagID: "my_dag"} + _, err := ExecuteAirflow(context.Background(), cfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "url is required") +} + +func TestExecuteAirflow_MissingDagID(t *testing.T) { + cfg := &types.AirflowTriggerConfig{URL: "http://example.com"} + _, err := ExecuteAirflow(context.Background(), cfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "dagID is required") +} + +func TestExecuteAirflow_WithBody(t *testing.T) { + var receivedConf interface{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var payload map[string]interface{} + _ = json.NewDecoder(r.Body).Decode(&payload) + receivedConf = payload["conf"] + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "dag_run_id": "run-with-body", + }) + })) + defer srv.Close() + + cfg := &types.AirflowTriggerConfig{ + URL: srv.URL, + DagID: "my_dag", + Body: `{"key": "value"}`, + } + + meta, err := ExecuteAirflow(context.Background(), cfg) + require.NoError(t, err) + assert.Equal(t, "run-with-body", meta["airflow_dag_run_id"]) + assert.NotNil(t, receivedConf) +} + +func TestExecuteAirflow_InvalidBodyJSON(t *testing.T) { + cfg := &types.AirflowTriggerConfig{ + URL: "http://example.com", + DagID: "my_dag", + Body: `{invalid json`, + } + _, err := ExecuteAirflow(context.Background(), cfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid body JSON") +} + +func TestExecuteAirflow_MissingDagRunIDInResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "queued", + }) + })) + defer srv.Close() + + cfg := &types.AirflowTriggerConfig{ + URL: srv.URL, + DagID: "my_dag", + } + _, err := ExecuteAirflow(context.Background(), cfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "response missing dag_run_id") +} + +func TestExecuteAirflow_CustomTimeout(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "dag_run_id": "run-timeout", + }) + })) + defer srv.Close() + + cfg := &types.AirflowTriggerConfig{ + URL: srv.URL, + DagID: "my_dag", + Timeout: 60, // Different from defaultTriggerTimeout (30s) + } + meta, err := ExecuteAirflow(context.Background(), cfg) + require.NoError(t, err) + assert.Equal(t, "run-timeout", meta["airflow_dag_run_id"]) +} + +func TestCheckAirflowStatus_ServerError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("server error")) + })) + defer srv.Close() + + _, err := CheckAirflowStatus(context.Background(), srv.URL, "my_dag", "run-1", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "status 500") +} + +func TestCheckAirflowStatus_MissingStateField(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "dag_run_id": "run-no-state", + }) + })) + defer srv.Close() + + _, err := CheckAirflowStatus(context.Background(), srv.URL, "my_dag", "run-1", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "response missing state field") +} + +func TestCheckAirflowStatus_WithHeaders(t *testing.T) { + var receivedAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "running", + }) + })) + defer srv.Close() + + state, err := CheckAirflowStatus(context.Background(), srv.URL, "my_dag", "run-1", map[string]string{ + "Authorization": "Bearer test-token", + }) + require.NoError(t, err) + assert.Equal(t, "running", state) + assert.Equal(t, "Bearer test-token", receivedAuth) +} diff --git a/internal/trigger/databricks_test.go b/internal/trigger/databricks_test.go index 889bc1e..c255545 100644 --- a/internal/trigger/databricks_test.go +++ b/internal/trigger/databricks_test.go @@ -145,3 +145,62 @@ func TestCheckDatabricksStatus_MissingMetadata(t *testing.T) { require.NoError(t, err) assert.Equal(t, RunCheckRunning, result.State) } + +func TestCheckDatabricksStatus_ServerError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("internal error")) + })) + defer srv.Close() + + r := NewRunner(WithHTTPClient(srv.Client())) + _, err := r.checkDatabricksStatus(context.Background(), map[string]interface{}{ + "databricks_workspace_url": srv.URL, + "databricks_run_id": "123", + }, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "status 500") +} + +func TestCheckDatabricksStatus_MissingWorkspaceURL(t *testing.T) { + r := NewRunner() + result, err := r.checkDatabricksStatus(context.Background(), map[string]interface{}{ + "databricks_run_id": "123", + }, nil) + require.NoError(t, err) + assert.Equal(t, RunCheckRunning, result.State) + assert.Equal(t, "missing databricks metadata", result.Message) +} + +func TestCheckDatabricksStatus_MissingRunID(t *testing.T) { + r := NewRunner() + result, err := r.checkDatabricksStatus(context.Background(), map[string]interface{}{ + "databricks_workspace_url": "https://example.com", + }, nil) + require.NoError(t, err) + assert.Equal(t, RunCheckRunning, result.State) + assert.Equal(t, "missing databricks metadata", result.Message) +} + +func TestCheckDatabricksStatus_WithHeaders(t *testing.T) { + var receivedAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "state": map[string]interface{}{ + "life_cycle_state": "RUNNING", + }, + }) + })) + defer srv.Close() + + r := NewRunner(WithHTTPClient(srv.Client())) + result, err := r.checkDatabricksStatus(context.Background(), map[string]interface{}{ + "databricks_workspace_url": srv.URL, + "databricks_run_id": "123", + }, map[string]string{"Authorization": "Bearer test-token"}) + require.NoError(t, err) + assert.Equal(t, RunCheckRunning, result.State) + assert.Equal(t, "Bearer test-token", receivedAuth) +} diff --git a/internal/trigger/doc.go b/internal/trigger/doc.go new file mode 100644 index 0000000..6a76526 --- /dev/null +++ b/internal/trigger/doc.go @@ -0,0 +1,23 @@ +// Package trigger implements pipeline trigger execution and status polling +// for nine trigger types: +// +// - command: local shell command execution +// - http: HTTP/webhook trigger with configurable method, headers, and body +// - airflow: Apache Airflow DAG trigger via REST API +// - glue: AWS Glue job trigger via SDK +// - emr: Amazon EMR step submission via SDK +// - emr-serverless: Amazon EMR Serverless job run via SDK +// - step-function: AWS Step Functions execution via SDK +// - databricks: Databricks job run via REST API +// - lambda: AWS Lambda direct invocation via SDK +// +// The [Runner] struct provides dependency injection for AWS SDK clients +// using functional options ([WithGlueClient], [WithEMRClient], etc.). +// Status polling is handled by [Runner.CheckStatus], which normalizes +// provider-specific states into [RunCheckRunning], [RunCheckSucceeded], +// or [RunCheckFailed]. +// +// Errors from trigger execution are wrapped in [TriggerError] with a +// [types.FailureCategory] for retry decisions. [ClassifyFailure] inspects +// any error and returns the appropriate category. +package trigger diff --git a/internal/trigger/emr_serverless_test.go b/internal/trigger/emr_serverless_test.go index e57b72b..4872fba 100644 --- a/internal/trigger/emr_serverless_test.go +++ b/internal/trigger/emr_serverless_test.go @@ -115,3 +115,37 @@ func TestCheckEMRServerlessStatus_MissingMetadata(t *testing.T) { require.NoError(t, err) assert.Equal(t, RunCheckRunning, result.State) } + +func TestCheckEMRServerlessStatus_APIError(t *testing.T) { + client := &mockEMRServerlessClient{getErr: assert.AnError} + r := NewRunner(WithEMRServerlessClient(client)) + _, err := r.checkEMRServerlessStatus(context.Background(), map[string]interface{}{ + "emr_sl_application_id": "app-1", + "emr_sl_job_run_id": "run-1", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "GetJobRun failed") +} + +func TestCheckEMRServerlessStatus_NilJobRun(t *testing.T) { + client := &mockEMRServerlessClient{ + getOut: &emrserverless.GetJobRunOutput{JobRun: nil}, + } + r := NewRunner(WithEMRServerlessClient(client)) + _, err := r.checkEMRServerlessStatus(context.Background(), map[string]interface{}{ + "emr_sl_application_id": "app-1", + "emr_sl_job_run_id": "run-1", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "returned nil JobRun") +} + +func TestCheckEMRServerlessStatus_MissingAppID(t *testing.T) { + r := NewRunner() + result, err := r.checkEMRServerlessStatus(context.Background(), map[string]interface{}{ + "emr_sl_job_run_id": "run-1", + }) + require.NoError(t, err) + assert.Equal(t, RunCheckRunning, result.State) + assert.Equal(t, "missing emr-serverless metadata", result.Message) +} diff --git a/internal/trigger/emr_test.go b/internal/trigger/emr_test.go index baf96ef..fc6a199 100644 --- a/internal/trigger/emr_test.go +++ b/internal/trigger/emr_test.go @@ -125,3 +125,57 @@ func TestCheckEMRStatus_MissingMetadata(t *testing.T) { require.NoError(t, err) assert.Equal(t, RunCheckRunning, result.State) } + +func TestCheckEMRStatus_APIError(t *testing.T) { + client := &mockEMRClient{describeErr: assert.AnError} + r := NewRunner(WithEMRClient(client)) + _, err := r.checkEMRStatus(context.Background(), map[string]interface{}{ + "emr_cluster_id": "j-1", + "emr_step_id": "s-1", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "DescribeStep failed") +} + +func TestCheckEMRStatus_NilStep(t *testing.T) { + client := &mockEMRClient{ + describeOut: &emr.DescribeStepOutput{Step: nil}, + } + r := NewRunner(WithEMRClient(client)) + _, err := r.checkEMRStatus(context.Background(), map[string]interface{}{ + "emr_cluster_id": "j-1", + "emr_step_id": "s-1", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "returned nil Step") +} + +func TestCheckEMRStatus_NilStepStatus(t *testing.T) { + client := &mockEMRClient{ + describeOut: &emr.DescribeStepOutput{ + Step: &emrtypes.Step{Status: nil}, + }, + } + r := NewRunner(WithEMRClient(client)) + _, err := r.checkEMRStatus(context.Background(), map[string]interface{}{ + "emr_cluster_id": "j-1", + "emr_step_id": "s-1", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "returned nil Step.Status") +} + +func TestExecuteEMR_WithArguments(t *testing.T) { + client := &mockEMRClient{ + addOut: &emr.AddJobFlowStepsOutput{StepIds: []string{"s-XYZ"}}, + } + cfg := &types.EMRTriggerConfig{ + ClusterID: "j-1", + StepName: "step-with-args", + Command: "s3://bucket/jar.jar", + Arguments: map[string]string{"--input": "s3://data"}, + } + meta, err := ExecuteEMR(context.Background(), cfg, client) + require.NoError(t, err) + assert.Equal(t, "s-XYZ", meta["emr_step_id"]) +} diff --git a/internal/trigger/runner_test.go b/internal/trigger/runner_test.go index 1f3aaad..34b12a4 100644 --- a/internal/trigger/runner_test.go +++ b/internal/trigger/runner_test.go @@ -2,12 +2,18 @@ package trigger import ( "context" + "encoding/json" + "net/http" + "net/http/httptest" "testing" "github.com/aws/aws-sdk-go-v2/service/emr" emrtypes "github.com/aws/aws-sdk-go-v2/service/emr/types" + "github.com/aws/aws-sdk-go-v2/service/emrserverless" + emrsltypes "github.com/aws/aws-sdk-go-v2/service/emrserverless/types" "github.com/aws/aws-sdk-go-v2/service/glue" gluetypes "github.com/aws/aws-sdk-go-v2/service/glue/types" + "github.com/aws/aws-sdk-go-v2/service/lambda" "github.com/aws/aws-sdk-go-v2/service/sfn" sfntypes "github.com/aws/aws-sdk-go-v2/service/sfn/types" "github.com/dwsmith1983/interlock/pkg/types" @@ -111,3 +117,393 @@ func TestPackageLevelCheckStatus_NonPolling(t *testing.T) { require.NoError(t, err) assert.Equal(t, RunCheckRunning, result.State) } + +// --- Runner.Execute dispatch tests --- + +func TestRunner_Execute_CommandType(t *testing.T) { + r := NewRunner() + _, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerCommand, + Command: &types.CommandTriggerConfig{Command: "echo test"}, + }) + require.NoError(t, err) +} + +func TestRunner_Execute_CommandType_NilConfig(t *testing.T) { + r := NewRunner() + _, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerCommand, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "command trigger config is nil") +} + +func TestRunner_Execute_HTTPType(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + r := NewRunner() + _, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerHTTP, + HTTP: &types.HTTPTriggerConfig{Method: "POST", URL: srv.URL}, + }) + require.NoError(t, err) +} + +func TestRunner_Execute_HTTPType_NilConfig(t *testing.T) { + r := NewRunner() + _, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerHTTP, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "http trigger config is nil") +} + +func TestRunner_Execute_AirflowType(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "dag_run_id": "manual__run-1", + }) + })) + defer srv.Close() + + r := NewRunner() + meta, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerAirflow, + Airflow: &types.AirflowTriggerConfig{ + URL: srv.URL, + DagID: "my_dag", + }, + }) + require.NoError(t, err) + assert.Equal(t, "manual__run-1", meta["airflow_dag_run_id"]) +} + +func TestRunner_Execute_AirflowType_NilConfig(t *testing.T) { + r := NewRunner() + _, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerAirflow, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "airflow trigger config is nil") +} + +func TestRunner_Execute_GlueType(t *testing.T) { + runID := "jr_123" + client := &mockGlueClient{ + startOut: &glue.StartJobRunOutput{JobRunId: &runID}, + } + r := NewRunner(WithGlueClient(client)) + meta, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerGlue, + Glue: &types.GlueTriggerConfig{JobName: "test-job"}, + }) + require.NoError(t, err) + assert.Equal(t, "jr_123", meta["glue_job_run_id"]) + assert.Equal(t, "test-job", meta["glue_job_name"]) +} + +func TestRunner_Execute_GlueType_NilConfig(t *testing.T) { + r := NewRunner() + _, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerGlue, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "glue trigger config is nil") +} + +func TestRunner_Execute_EMRType(t *testing.T) { + client := &mockEMRClient{ + addOut: &emr.AddJobFlowStepsOutput{StepIds: []string{"s-abc"}}, + } + r := NewRunner(WithEMRClient(client)) + meta, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerEMR, + EMR: &types.EMRTriggerConfig{ + ClusterID: "j-1", + StepName: "my-step", + Command: "s3://bucket/step.jar", + }, + }) + require.NoError(t, err) + assert.Equal(t, "s-abc", meta["emr_step_id"]) + assert.Equal(t, "j-1", meta["emr_cluster_id"]) +} + +func TestRunner_Execute_EMRType_NilConfig(t *testing.T) { + r := NewRunner() + _, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerEMR, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "emr trigger config is nil") +} + +func TestRunner_Execute_EMRServerlessType(t *testing.T) { + runID := "run-abc" + client := &mockEMRServerlessClient{ + startOut: &emrserverless.StartJobRunOutput{JobRunId: &runID}, + } + r := NewRunner(WithEMRServerlessClient(client)) + meta, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerEMRServerless, + EMRServerless: &types.EMRServerlessTriggerConfig{ + ApplicationID: "app-1", + JobName: "spark-job", + }, + }) + require.NoError(t, err) + assert.Equal(t, "run-abc", meta["emr_sl_job_run_id"]) + assert.Equal(t, "app-1", meta["emr_sl_application_id"]) +} + +func TestRunner_Execute_EMRServerlessType_NilConfig(t *testing.T) { + r := NewRunner() + _, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerEMRServerless, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "emr-serverless trigger config is nil") +} + +func TestRunner_Execute_SFNType(t *testing.T) { + execArn := "arn:aws:states:us-east-1:123:execution:my-sfn:run-1" + client := &mockSFNClient{ + startOut: &sfn.StartExecutionOutput{ExecutionArn: &execArn}, + } + r := NewRunner(WithSFNClient(client)) + meta, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerStepFunction, + StepFunction: &types.StepFunctionTriggerConfig{ + StateMachineARN: "arn:aws:states:us-east-1:123:stateMachine:my-sfn", + }, + }) + require.NoError(t, err) + assert.Equal(t, execArn, meta["sfn_execution_arn"]) +} + +func TestRunner_Execute_SFNType_NilConfig(t *testing.T) { + r := NewRunner() + _, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerStepFunction, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "step-function trigger config is nil") +} + +func TestRunner_Execute_DatabricksType(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{"run_id": 12345}) + })) + defer srv.Close() + + r := NewRunner(WithHTTPClient(srv.Client())) + meta, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerDatabricks, + Databricks: &types.DatabricksTriggerConfig{ + WorkspaceURL: srv.URL, + JobID: "my-job", + }, + }) + require.NoError(t, err) + assert.Equal(t, "12345", meta["databricks_run_id"]) +} + +func TestRunner_Execute_DatabricksType_NilConfig(t *testing.T) { + r := NewRunner() + _, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerDatabricks, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "databricks trigger config is nil") +} + +func TestRunner_Execute_LambdaType(t *testing.T) { + client := &mockLambdaClient{ + invokeOut: &lambda.InvokeOutput{StatusCode: 200, Payload: []byte(`{"ok":true}`)}, + } + r := NewRunner(WithLambdaClient(client)) + _, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerLambda, + Lambda: &types.LambdaTriggerConfig{FunctionName: "my-func"}, + }) + require.NoError(t, err) +} + +func TestRunner_Execute_LambdaType_NilConfig(t *testing.T) { + r := NewRunner() + _, err := r.Execute(context.Background(), &types.TriggerConfig{ + Type: types.TriggerLambda, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "lambda trigger config is nil") +} + +// --- Runner.checkAirflowStatus tests --- + +func TestRunner_CheckAirflowStatus_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "success", + }) + })) + defer srv.Close() + + // Temporarily replace defaultHTTPClient to route requests to the test server. + origClient := defaultHTTPClient + defaultHTTPClient = srv.Client() + defer func() { defaultHTTPClient = origClient }() + + r := NewRunner() + result, err := r.checkAirflowStatus(context.Background(), map[string]interface{}{ + "airflow_url": srv.URL, + "airflow_dag_id": "my_dag", + "airflow_dag_run_id": "run_1", + }, nil) + require.NoError(t, err) + assert.Equal(t, RunCheckSucceeded, result.State) + assert.Equal(t, "success", result.Message) +} + +func TestRunner_CheckAirflowStatus_Failed(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "failed", + }) + })) + defer srv.Close() + + origClient := defaultHTTPClient + defaultHTTPClient = srv.Client() + defer func() { defaultHTTPClient = origClient }() + + r := NewRunner() + result, err := r.checkAirflowStatus(context.Background(), map[string]interface{}{ + "airflow_url": srv.URL, + "airflow_dag_id": "my_dag", + "airflow_dag_run_id": "run_1", + }, nil) + require.NoError(t, err) + assert.Equal(t, RunCheckFailed, result.State) + assert.Equal(t, "failed", result.Message) +} + +func TestRunner_CheckAirflowStatus_Running(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "running", + }) + })) + defer srv.Close() + + origClient := defaultHTTPClient + defaultHTTPClient = srv.Client() + defer func() { defaultHTTPClient = origClient }() + + r := NewRunner() + result, err := r.checkAirflowStatus(context.Background(), map[string]interface{}{ + "airflow_url": srv.URL, + "airflow_dag_id": "my_dag", + "airflow_dag_run_id": "run_1", + }, nil) + require.NoError(t, err) + assert.Equal(t, RunCheckRunning, result.State) + assert.Equal(t, "running", result.Message) +} + +func TestRunner_CheckAirflowStatus_MissingMetadata(t *testing.T) { + r := NewRunner() + + // Missing all fields + result, err := r.checkAirflowStatus(context.Background(), map[string]interface{}{}, nil) + require.NoError(t, err) + assert.Equal(t, RunCheckRunning, result.State) + assert.Equal(t, "missing airflow metadata", result.Message) + + // Missing dag_id and dag_run_id + result, err = r.checkAirflowStatus(context.Background(), map[string]interface{}{ + "airflow_url": "http://example.com", + }, nil) + require.NoError(t, err) + assert.Equal(t, RunCheckRunning, result.State) + assert.Equal(t, "missing airflow metadata", result.Message) +} + +// --- Runner.CheckStatus dispatch tests for remaining types --- + +func TestRunner_CheckStatus_EMRServerlessDispatch(t *testing.T) { + client := &mockEMRServerlessClient{ + getOut: &emrserverless.GetJobRunOutput{ + JobRun: &emrsltypes.JobRun{ + State: emrsltypes.JobRunStateSuccess, + }, + }, + } + r := NewRunner(WithEMRServerlessClient(client)) + + result, err := r.CheckStatus(context.Background(), types.TriggerEMRServerless, map[string]interface{}{ + "emr_sl_application_id": "app-1", + "emr_sl_job_run_id": "run-1", + }, nil) + require.NoError(t, err) + assert.Equal(t, RunCheckSucceeded, result.State) +} + +func TestRunner_CheckStatus_DatabricksDispatch(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "state": map[string]interface{}{ + "life_cycle_state": "TERMINATED", + "result_state": "SUCCESS", + }, + }) + })) + defer srv.Close() + + r := NewRunner(WithHTTPClient(srv.Client())) + result, err := r.CheckStatus(context.Background(), types.TriggerDatabricks, map[string]interface{}{ + "databricks_workspace_url": srv.URL, + "databricks_run_id": "999", + }, nil) + require.NoError(t, err) + assert.Equal(t, RunCheckSucceeded, result.State) +} + +func TestRunner_CheckStatus_AirflowDispatch(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "success", + }) + })) + defer srv.Close() + + origClient := defaultHTTPClient + defaultHTTPClient = srv.Client() + defer func() { defaultHTTPClient = origClient }() + + r := NewRunner() + result, err := r.CheckStatus(context.Background(), types.TriggerAirflow, map[string]interface{}{ + "airflow_url": srv.URL, + "airflow_dag_id": "my_dag", + "airflow_dag_run_id": "run_1", + }, nil) + require.NoError(t, err) + assert.Equal(t, RunCheckSucceeded, result.State) +} + +func TestRunner_CheckStatus_LambdaDispatch(t *testing.T) { + r := NewRunner() + result, err := r.CheckStatus(context.Background(), types.TriggerLambda, nil, nil) + require.NoError(t, err) + assert.Equal(t, RunCheckRunning, result.State) + assert.Equal(t, "non-polling trigger type", result.Message) +} diff --git a/internal/trigger/status_test.go b/internal/trigger/status_test.go new file mode 100644 index 0000000..b5f4ecc --- /dev/null +++ b/internal/trigger/status_test.go @@ -0,0 +1,23 @@ +package trigger + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRunCheckState_Constants(t *testing.T) { + assert.Equal(t, RunCheckState("running"), RunCheckRunning) + assert.Equal(t, RunCheckState("succeeded"), RunCheckSucceeded) + assert.Equal(t, RunCheckState("failed"), RunCheckFailed) +} + +func TestStatusResult_ConstructAndCompare(t *testing.T) { + result := StatusResult{ + State: RunCheckSucceeded, + Message: "job completed", + } + assert.Equal(t, RunCheckSucceeded, result.State) + assert.Equal(t, "job completed", result.Message) + assert.Empty(t, result.FailureCategory) +} diff --git a/internal/trigger/trigger.go b/internal/trigger/trigger.go index f8cd59b..545a70d 100644 --- a/internal/trigger/trigger.go +++ b/internal/trigger/trigger.go @@ -1,4 +1,3 @@ -// Package trigger implements pipeline trigger execution. package trigger import ( diff --git a/internal/trigger/trigger_test.go b/internal/trigger/trigger_test.go index 94063a9..9b1066d 100644 --- a/internal/trigger/trigger_test.go +++ b/internal/trigger/trigger_test.go @@ -261,3 +261,9 @@ func TestExecuteHTTP_Returns_TriggerError_On5xx(t *testing.T) { assert.Equal(t, types.FailureTransient, te.Category) assert.Contains(t, te.Message, "status 503") } + +func TestExecuteCommand_EmptyCommand(t *testing.T) { + err := ExecuteCommand(context.Background(), "") + assert.Error(t, err) + assert.Contains(t, err.Error(), "command is empty") +}