diff --git a/module/pipeline_step_auth_validate.go b/module/pipeline_step_auth_validate.go new file mode 100644 index 00000000..f1d46918 --- /dev/null +++ b/module/pipeline_step_auth_validate.go @@ -0,0 +1,125 @@ +package module + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/CrisisTextLine/modular" +) + +// AuthValidateStep validates a Bearer token against a registered AuthProvider +// module and outputs the claims returned by the provider into the pipeline context. +type AuthValidateStep struct { + name string + authModule string // service name of the AuthProvider module + tokenSource string // dot-path to the token in pipeline context + subjectField string // output field name for the subject claim + app modular.Application +} + +// NewAuthValidateStepFactory returns a StepFactory that creates AuthValidateStep instances. +func NewAuthValidateStepFactory() StepFactory { + return func(name string, config map[string]any, app modular.Application) (PipelineStep, error) { + authModule, _ := config["auth_module"].(string) + if authModule == "" { + return nil, fmt.Errorf("auth_validate step %q: 'auth_module' is required", name) + } + + tokenSource, _ := config["token_source"].(string) + if tokenSource == "" { + return nil, fmt.Errorf("auth_validate step %q: 'token_source' is required", name) + } + + subjectField, _ := config["subject_field"].(string) + if subjectField == "" { + subjectField = "auth_user_id" + } + + return &AuthValidateStep{ + name: name, + authModule: authModule, + tokenSource: tokenSource, + subjectField: subjectField, + app: app, + }, nil + } +} + +// Name returns the step name. +func (s *AuthValidateStep) Name() string { return s.name } + +// Execute validates the Bearer token and outputs claims from the AuthProvider. +func (s *AuthValidateStep) Execute(_ context.Context, pc *PipelineContext) (*StepResult, error) { + if s.app == nil { + return nil, fmt.Errorf("auth_validate step %q: no application context", s.name) + } + + // 1. Extract the token value from the pipeline context using the configured dot-path. + rawToken := resolveBodyFrom(s.tokenSource, pc) + tokenStr, _ := rawToken.(string) + if tokenStr == "" { + return s.unauthorizedResponse(pc, "missing or empty authorization header") + } + + // 2. Strip "Bearer " prefix. + if !strings.HasPrefix(tokenStr, "Bearer ") { + return s.unauthorizedResponse(pc, "malformed authorization header") + } + token := strings.TrimPrefix(tokenStr, "Bearer ") + if token == "" { + return s.unauthorizedResponse(pc, "empty bearer token") + } + + // 3. Resolve the AuthProvider from the service registry. + var provider AuthProvider + if err := s.app.GetService(s.authModule, &provider); err != nil { + return nil, fmt.Errorf("auth_validate step %q: auth module %q not found: %w", s.name, s.authModule, err) + } + + // 4. Authenticate the token. + valid, claims, err := provider.Authenticate(token) + if err != nil { + return s.unauthorizedResponse(pc, "authentication error") + } + if !valid { + return s.unauthorizedResponse(pc, "invalid token") + } + + // 5. Build output: all claims as flat keys + configured subject_field from "sub". + output := make(map[string]any, len(claims)+1) + for k, v := range claims { + output[k] = v + } + if sub, ok := claims["sub"]; ok { + output[s.subjectField] = sub + } + + return &StepResult{Output: output}, nil +} + +// unauthorizedResponse writes a 401 JSON error response and stops the pipeline. +func (s *AuthValidateStep) unauthorizedResponse(pc *PipelineContext, message string) (*StepResult, error) { + errorBody := map[string]any{ + "error": "unauthorized", + "message": message, + } + + if w, ok := pc.Metadata["_http_response_writer"].(http.ResponseWriter); ok { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(w).Encode(errorBody) + pc.Metadata["_response_handled"] = true + } + + return &StepResult{ + Output: map[string]any{ + "status": http.StatusUnauthorized, + "error": "unauthorized", + "message": message, + }, + Stop: true, + }, nil +} diff --git a/module/pipeline_step_auth_validate_test.go b/module/pipeline_step_auth_validate_test.go new file mode 100644 index 00000000..d316b345 --- /dev/null +++ b/module/pipeline_step_auth_validate_test.go @@ -0,0 +1,394 @@ +package module + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// testAuthProvider implements AuthProvider for auth_validate step tests. +type testAuthProvider struct { + validTokens map[string]map[string]any + returnErr error +} + +func (m *testAuthProvider) Authenticate(token string) (bool, map[string]any, error) { + if m.returnErr != nil { + return false, nil, m.returnErr + } + if claims, ok := m.validTokens[token]; ok { + return true, claims, nil + } + return false, nil, nil +} + +func newTestAuthApp(authModule string, provider AuthProvider) *MockApplication { + app := NewMockApplication() + app.Services[authModule] = provider + return app +} + +func TestAuthValidateStep_SuccessfulAuth(t *testing.T) { + factory := NewAuthValidateStepFactory() + provider := &testAuthProvider{ + validTokens: map[string]map[string]any{ + "valid-token-123": { + "sub": "user-42", + "scope": "read write", + "iss": "test-issuer", + }, + }, + } + app := newTestAuthApp("m2m-auth", provider) + + step, err := factory("auth", map[string]any{ + "auth_module": "m2m-auth", + "token_source": "steps.parse.headers.Authorization", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("parse", map[string]any{ + "headers": map[string]any{ + "Authorization": "Bearer valid-token-123", + }, + }) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Stop { + t.Error("expected Stop=false on successful auth") + } + if result.Output["sub"] != "user-42" { + t.Errorf("expected sub=user-42, got %v", result.Output["sub"]) + } + if result.Output["scope"] != "read write" { + t.Errorf("expected scope='read write', got %v", result.Output["scope"]) + } + if result.Output["iss"] != "test-issuer" { + t.Errorf("expected iss=test-issuer, got %v", result.Output["iss"]) + } + if result.Output["auth_user_id"] != "user-42" { + t.Errorf("expected auth_user_id=user-42, got %v", result.Output["auth_user_id"]) + } +} + +func TestAuthValidateStep_CustomSubjectField(t *testing.T) { + factory := NewAuthValidateStepFactory() + provider := &testAuthProvider{ + validTokens: map[string]map[string]any{ + "tok": {"sub": "svc-1"}, + }, + } + app := newTestAuthApp("auth", provider) + + step, err := factory("auth", map[string]any{ + "auth_module": "auth", + "token_source": "steps.parse.headers.Authorization", + "subject_field": "service_id", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("parse", map[string]any{ + "headers": map[string]any{"Authorization": "Bearer tok"}, + }) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Output["service_id"] != "svc-1" { + t.Errorf("expected service_id=svc-1, got %v", result.Output["service_id"]) + } +} + +func TestAuthValidateStep_CustomTokenSource(t *testing.T) { + factory := NewAuthValidateStepFactory() + provider := &testAuthProvider{ + validTokens: map[string]map[string]any{ + "my-token": {"sub": "u1"}, + }, + } + app := newTestAuthApp("auth", provider) + + step, err := factory("auth", map[string]any{ + "auth_module": "auth", + "token_source": "steps.headers.auth_header", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("headers", map[string]any{ + "auth_header": "Bearer my-token", + }) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Stop { + t.Error("expected Stop=false on success") + } + if result.Output["sub"] != "u1" { + t.Errorf("expected sub=u1, got %v", result.Output["sub"]) + } +} + +func TestAuthValidateStep_MissingToken(t *testing.T) { + factory := NewAuthValidateStepFactory() + app := newTestAuthApp("auth", &testAuthProvider{}) + + step, err := factory("auth", map[string]any{ + "auth_module": "auth", + "token_source": "steps.parse.headers.Authorization", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + // No step output for "parse" — token is missing. + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.Stop { + t.Error("expected Stop=true for missing token") + } + if result.Output["status"] != http.StatusUnauthorized { + t.Errorf("expected status=401, got %v", result.Output["status"]) + } +} + +func TestAuthValidateStep_MalformedHeader(t *testing.T) { + factory := NewAuthValidateStepFactory() + app := newTestAuthApp("auth", &testAuthProvider{}) + + step, err := factory("auth", map[string]any{ + "auth_module": "auth", + "token_source": "steps.parse.headers.Authorization", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("parse", map[string]any{ + "headers": map[string]any{"Authorization": "Basic dXNlcjpwYXNz"}, + }) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.Stop { + t.Error("expected Stop=true for malformed header") + } + if result.Output["error"] != "unauthorized" { + t.Errorf("expected error=unauthorized, got %v", result.Output["error"]) + } +} + +func TestAuthValidateStep_InvalidToken(t *testing.T) { + factory := NewAuthValidateStepFactory() + provider := &testAuthProvider{ + validTokens: map[string]map[string]any{}, // no valid tokens + } + app := newTestAuthApp("auth", provider) + + step, err := factory("auth", map[string]any{ + "auth_module": "auth", + "token_source": "steps.parse.headers.Authorization", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("parse", map[string]any{ + "headers": map[string]any{"Authorization": "Bearer bad-token"}, + }) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.Stop { + t.Error("expected Stop=true for invalid token") + } + if result.Output["status"] != http.StatusUnauthorized { + t.Errorf("expected status=401, got %v", result.Output["status"]) + } +} + +func TestAuthValidateStep_AuthError(t *testing.T) { + factory := NewAuthValidateStepFactory() + provider := &testAuthProvider{ + returnErr: fmt.Errorf("provider error"), + } + app := newTestAuthApp("auth", provider) + + step, err := factory("auth", map[string]any{ + "auth_module": "auth", + "token_source": "steps.parse.headers.Authorization", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("parse", map[string]any{ + "headers": map[string]any{"Authorization": "Bearer some-token"}, + }) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.Stop { + t.Error("expected Stop=true when provider returns error") + } +} + +func TestAuthValidateStep_WritesHTTPResponse(t *testing.T) { + factory := NewAuthValidateStepFactory() + app := newTestAuthApp("auth", &testAuthProvider{}) + + step, err := factory("auth", map[string]any{ + "auth_module": "auth", + "token_source": "steps.parse.headers.Authorization", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + w := httptest.NewRecorder() + pc := NewPipelineContext(nil, map[string]any{ + "_http_response_writer": w, + }) + // No auth header → 401 + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.Stop { + t.Error("expected Stop=true") + } + if w.Code != http.StatusUnauthorized { + t.Errorf("expected HTTP 401, got %d", w.Code) + } + if ct := w.Header().Get("Content-Type"); ct != "application/json" { + t.Errorf("expected Content-Type application/json, got %q", ct) + } + if !strings.Contains(w.Body.String(), "unauthorized") { + t.Errorf("expected 'unauthorized' in response body, got %q", w.Body.String()) + } + if pc.Metadata["_response_handled"] != true { + t.Error("expected _response_handled=true in metadata") + } +} + +func TestAuthValidateStep_FactoryRequiresAuthModule(t *testing.T) { + factory := NewAuthValidateStepFactory() + + _, err := factory("auth", map[string]any{}, nil) + if err == nil { + t.Fatal("expected error for missing auth_module") + } + if !strings.Contains(err.Error(), "'auth_module' is required") { + t.Errorf("expected auth_module error, got: %v", err) + } +} + +func TestAuthValidateStep_Name(t *testing.T) { + factory := NewAuthValidateStepFactory() + app := newTestAuthApp("auth", &testAuthProvider{}) + + step, err := factory("my-auth-step", map[string]any{ + "auth_module": "auth", + "token_source": "steps.parse.headers.Authorization", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + if step.Name() != "my-auth-step" { + t.Errorf("expected name 'my-auth-step', got %q", step.Name()) + } +} + +func TestAuthValidateStep_FactoryRequiresTokenSource(t *testing.T) { + factory := NewAuthValidateStepFactory() + + _, err := factory("auth", map[string]any{"auth_module": "auth"}, nil) + if err == nil { + t.Fatal("expected error for missing token_source") + } + if !strings.Contains(err.Error(), "'token_source' is required") { + t.Errorf("expected token_source error, got: %v", err) + } +} + +func TestAuthValidateStep_NilApp(t *testing.T) { + factory := NewAuthValidateStepFactory() + + step, err := factory("auth", map[string]any{ + "auth_module": "auth", + "token_source": "steps.parse.headers.Authorization", + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("parse", map[string]any{ + "headers": map[string]any{"Authorization": "Bearer some-token"}, + }) + + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error when app is nil") + } + if !strings.Contains(err.Error(), "no application context") { + t.Errorf("expected 'no application context' error, got: %v", err) + } +} + +func TestAuthValidateStep_EmptyBearerToken(t *testing.T) { + factory := NewAuthValidateStepFactory() + app := newTestAuthApp("auth", &testAuthProvider{}) + + step, err := factory("auth", map[string]any{ + "auth_module": "auth", + "token_source": "steps.parse.headers.Authorization", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + pc.MergeStepOutput("parse", map[string]any{ + "headers": map[string]any{"Authorization": "Bearer "}, + }) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.Stop { + t.Error("expected Stop=true for empty bearer token") + } +} diff --git a/plugins/pipelinesteps/plugin.go b/plugins/pipelinesteps/plugin.go index fe739edd..4c62acf9 100644 --- a/plugins/pipelinesteps/plugin.go +++ b/plugins/pipelinesteps/plugin.go @@ -3,7 +3,8 @@ // http_call, request_parse, db_query, db_exec, json_response, // validate_path_param, validate_pagination, validate_request_body, // foreach, webhook_verify, ui_scaffold, ui_scaffold_analyze, -// dlq_send, dlq_replay, retry_with_backoff, circuit_breaker (wrapping). +// dlq_send, dlq_replay, retry_with_backoff, circuit_breaker (wrapping), +// auth_validate. // It also provides the PipelineWorkflowHandler for composable pipelines. package pipelinesteps @@ -80,6 +81,7 @@ func New() *Plugin { "step.dlq_replay", "step.retry_with_backoff", "step.resilient_circuit_breaker", + "step.auth_validate", }, WorkflowTypes: []string{"pipeline"}, Capabilities: []plugin.CapabilityDecl{ @@ -111,7 +113,7 @@ func (p *Plugin) StepFactories() map[string]plugin.StepFactory { "step.delegate": wrapStepFactory(module.NewDelegateStepFactory()), "step.jq": wrapStepFactory(module.NewJQStepFactory()), "step.publish": wrapStepFactory(module.NewPublishStepFactory()), - "step.event_publish": wrapStepFactory(module.NewEventPublishStepFactory()), + "step.event_publish": wrapStepFactory(module.NewEventPublishStepFactory()), "step.http_call": wrapStepFactory(module.NewHTTPCallStepFactory()), "step.request_parse": wrapStepFactory(module.NewRequestParseStepFactory()), "step.db_query": wrapStepFactory(module.NewDBQueryStepFactory()), @@ -125,7 +127,7 @@ func (p *Plugin) StepFactories() map[string]plugin.StepFactory { "step.foreach": wrapStepFactory(module.NewForEachStepFactory(func() *module.StepRegistry { return p.concreteStepRegistry })), - "step.webhook_verify": wrapStepFactory(module.NewWebhookVerifyStepFactory()), + "step.webhook_verify": wrapStepFactory(module.NewWebhookVerifyStepFactory()), "step.cache_get": wrapStepFactory(module.NewCacheGetStepFactory()), "step.cache_set": wrapStepFactory(module.NewCacheSetStepFactory()), "step.cache_delete": wrapStepFactory(module.NewCacheDeleteStepFactory()), @@ -140,6 +142,7 @@ func (p *Plugin) StepFactories() map[string]plugin.StepFactory { "step.resilient_circuit_breaker": wrapStepFactory(module.NewResilienceCircuitBreakerStepFactory(func() *module.StepRegistry { return p.concreteStepRegistry })), + "step.auth_validate": wrapStepFactory(module.NewAuthValidateStepFactory()), } } diff --git a/plugins/pipelinesteps/plugin_test.go b/plugins/pipelinesteps/plugin_test.go index df833648..322ce76c 100644 --- a/plugins/pipelinesteps/plugin_test.go +++ b/plugins/pipelinesteps/plugin_test.go @@ -59,6 +59,7 @@ func TestStepFactories(t *testing.T) { "step.dlq_replay", "step.retry_with_backoff", "step.resilient_circuit_breaker", + "step.auth_validate", } for _, stepType := range expectedSteps { @@ -80,7 +81,7 @@ func TestPluginLoads(t *testing.T) { } steps := loader.StepFactories() - if len(steps) != 28 { - t.Fatalf("expected 28 step factories after load, got %d", len(steps)) + if len(steps) != 29 { + t.Fatalf("expected 29 step factories after load, got %d", len(steps)) } }