diff --git a/module/pipeline_step_authz_check.go b/module/pipeline_step_authz_check.go new file mode 100644 index 00000000..0bebca29 --- /dev/null +++ b/module/pipeline_step_authz_check.go @@ -0,0 +1,148 @@ +package module + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/CrisisTextLine/modular" +) + +// AuthzCheckStep evaluates a policy engine decision for the current pipeline +// subject. On denial it writes a 403 Forbidden JSON response to the HTTP +// response writer (when present) and stops the pipeline, matching the +// pattern used by step.auth_validate for 401 responses. +type AuthzCheckStep struct { + name string + engineName string // service name of the PolicyEngineModule + subjectField string // field in pc.Current that holds the subject + inputFrom string // optional: field in pc.Current to use as policy input + app modular.Application +} + +// NewAuthzCheckStepFactory returns a StepFactory that creates AuthzCheckStep instances. +func NewAuthzCheckStepFactory() StepFactory { + return func(name string, config map[string]any, app modular.Application) (PipelineStep, error) { + engineName, _ := config["policy_engine"].(string) + if engineName == "" { + return nil, fmt.Errorf("authz_check step %q: 'policy_engine' is required", name) + } + + subjectField, _ := config["subject_field"].(string) + if subjectField == "" { + subjectField = "subject" + } + + inputFrom, _ := config["input_from"].(string) + + return &AuthzCheckStep{ + name: name, + engineName: engineName, + subjectField: subjectField, + inputFrom: inputFrom, + app: app, + }, nil + } +} + +// Name returns the step name. +func (s *AuthzCheckStep) Name() string { return s.name } + +// Execute evaluates the policy engine and writes a 403 response on denial. +func (s *AuthzCheckStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) { + if s.app == nil { + return nil, fmt.Errorf("authz_check step %q: no application context", s.name) + } + + // Resolve the PolicyEngineModule from the service registry. + eng, err := resolvePolicyEngine(s.app, s.engineName, s.name) + if err != nil { + return nil, err + } + + // Build the policy input: use a named field if configured, otherwise use + // the full pipeline context (same strategy as step.policy_evaluate). + // Track whether the input shares the same backing data as pc.Current so we + // can clone before adding the subject key. + var input map[string]any + inputIsShared := false + if s.inputFrom != "" { + if raw, ok := pc.Current[s.inputFrom]; ok { + if m, ok := raw.(map[string]any); ok { + input = m + } + } + } + if input == nil { + input = pc.Current + inputIsShared = true + } + + // Map the configured subject field into the policy input so that + // authorization decisions can depend on it. We read the subject from + // pc.Current[s.subjectField] and expose it under the canonical "subject" + // key in the input. Clone the input first when it shares data with + // pc.Current to avoid side effects on the pipeline context. + if s.subjectField != "" && s.subjectField != "subject" { + if subj, ok := pc.Current[s.subjectField]; ok { + if inputIsShared { + cloned := make(map[string]any, len(input)+1) + for k, v := range input { + cloned[k] = v + } + input = cloned + } + input["subject"] = subj + } + } + + // Evaluate the policy. + decision, err := eng.Engine().Evaluate(ctx, input) + if err != nil { + return nil, fmt.Errorf("authz_check step %q: evaluate: %w", s.name, err) + } + + if !decision.Allowed { + reason := "authorization denied" + if len(decision.Reasons) > 0 { + reason = decision.Reasons[0] + } + return s.forbiddenResponse(pc, reason) + } + + return &StepResult{Output: map[string]any{ + "allowed": true, + "reasons": decision.Reasons, + "metadata": decision.Metadata, + }}, nil +} + +// forbiddenResponse writes a 403 JSON error response to the HTTP response +// writer (when present) and stops the pipeline. The response body format +// matches the expected {"error":"forbidden: ..."} shape described in the issue. +func (s *AuthzCheckStep) forbiddenResponse(pc *PipelineContext, message string) (*StepResult, error) { + errorMsg := fmt.Sprintf("forbidden: %s", message) + errorBody := map[string]any{ + "error": errorMsg, + } + + if w, ok := pc.Metadata["_http_response_writer"].(http.ResponseWriter); ok { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + _ = json.NewEncoder(w).Encode(errorBody) + pc.Metadata["_response_handled"] = true + } + + return &StepResult{ + Output: map[string]any{ + "response_status": http.StatusForbidden, + "response_body": fmt.Sprintf(`{"error":%q}`, errorMsg), + "response_headers": map[string]string{ + "Content-Type": "application/json", + }, + "error": errorMsg, + }, + Stop: true, + }, nil +} diff --git a/module/pipeline_step_authz_check_test.go b/module/pipeline_step_authz_check_test.go new file mode 100644 index 00000000..9fb9c40a --- /dev/null +++ b/module/pipeline_step_authz_check_test.go @@ -0,0 +1,326 @@ +package module + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// newTestPolicyApp creates a MockApplication with a PolicyEngineModule registered. +func newTestPolicyApp(engineName string, eng PolicyEngine) *MockApplication { + app := NewMockApplication() + mod := &PolicyEngineModule{ + name: engineName, + engine: eng, + } + app.Services[engineName] = mod + return app +} + +func TestAuthzCheckStep_Allowed(t *testing.T) { + factory := NewAuthzCheckStepFactory() + app := newTestPolicyApp("policy", newMockPolicyEngine()) + + step, err := factory("authz", map[string]any{ + "policy_engine": "policy", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(map[string]any{"subject": "user-1"}, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Stop { + t.Error("expected Stop=false when policy allows") + } + if result.Output["allowed"] != true { + t.Errorf("expected allowed=true, got %v", result.Output["allowed"]) + } +} + +func TestAuthzCheckStep_Denied(t *testing.T) { + factory := NewAuthzCheckStepFactory() + eng := newMockPolicyEngine() + // Load a deny policy so the mock engine denies the request. + _ = eng.LoadPolicy("deny-all", "deny") + app := newTestPolicyApp("policy", eng) + + step, err := factory("authz", map[string]any{ + "policy_engine": "policy", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(map[string]any{"subject": "user-1"}, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.Stop { + t.Error("expected Stop=true when policy denies") + } + if result.Output["response_status"] != http.StatusForbidden { + t.Errorf("expected response_status=403, got %v", result.Output["response_status"]) + } + errMsg, _ := result.Output["error"].(string) + if !strings.Contains(errMsg, "forbidden") { + t.Errorf("expected error to contain 'forbidden', got %q", errMsg) + } +} + +func TestAuthzCheckStep_WritesHTTPResponse(t *testing.T) { + factory := NewAuthzCheckStepFactory() + eng := newMockPolicyEngine() + _ = eng.LoadPolicy("deny-all", "deny") + app := newTestPolicyApp("policy", eng) + + step, err := factory("authz", map[string]any{ + "policy_engine": "policy", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + w := httptest.NewRecorder() + pc := NewPipelineContext(map[string]any{"subject": "user-1"}, map[string]any{ + "_http_response_writer": w, + }) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.Stop { + t.Error("expected Stop=true") + } + if w.Code != http.StatusForbidden { + t.Errorf("expected HTTP 403, got %d", w.Code) + } + if ct := w.Header().Get("Content-Type"); ct != "application/json" { + t.Errorf("expected Content-Type application/json, got %q", ct) + } + if !strings.Contains(w.Body.String(), "forbidden") { + t.Errorf("expected 'forbidden' in response body, got %q", w.Body.String()) + } + if pc.Metadata["_response_handled"] != true { + t.Error("expected _response_handled=true in metadata") + } +} + +func TestAuthzCheckStep_WritesHTTPResponse_NoResponseWriter(t *testing.T) { + factory := NewAuthzCheckStepFactory() + eng := newMockPolicyEngine() + _ = eng.LoadPolicy("deny-all", "deny") + app := newTestPolicyApp("policy", eng) + + step, err := factory("authz", map[string]any{ + "policy_engine": "policy", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + // No response writer in metadata — should still stop pipeline with output. + pc := NewPipelineContext(map[string]any{"subject": "user-1"}, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.Stop { + t.Error("expected Stop=true even without response writer") + } + if result.Output["response_status"] != http.StatusForbidden { + t.Errorf("expected response_status=403, got %v", result.Output["response_status"]) + } + headers, _ := result.Output["response_headers"].(map[string]string) + if headers["Content-Type"] != "application/json" { + t.Errorf("expected response_headers Content-Type=application/json, got %v", headers) + } +} + +func TestAuthzCheckStep_InputFrom(t *testing.T) { + factory := NewAuthzCheckStepFactory() + eng := newMockPolicyEngine() + _ = eng.LoadPolicy("deny-all", "deny") + app := newTestPolicyApp("policy", eng) + + step, err := factory("authz", map[string]any{ + "policy_engine": "policy", + "input_from": "authz_input", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + // The input_from field contains a sub-map; the deny policy still triggers. + pc := NewPipelineContext(map[string]any{ + "authz_input": map[string]any{"subject": "user-1", "action": "read"}, + }, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.Stop { + t.Error("expected Stop=true") + } +} + +func TestAuthzCheckStep_FactoryRequiresPolicyEngine(t *testing.T) { + factory := NewAuthzCheckStepFactory() + + _, err := factory("authz", map[string]any{}, nil) + if err == nil { + t.Fatal("expected error for missing policy_engine") + } + if !strings.Contains(err.Error(), "'policy_engine' is required") { + t.Errorf("expected policy_engine error, got: %v", err) + } +} + +func TestAuthzCheckStep_NilApp(t *testing.T) { + factory := NewAuthzCheckStepFactory() + + step, err := factory("authz", map[string]any{ + "policy_engine": "policy", + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error when app is nil") + } + if !strings.Contains(err.Error(), "no application context") { + t.Errorf("expected 'no application context' error, got: %v", err) + } +} + +func TestAuthzCheckStep_Name(t *testing.T) { + factory := NewAuthzCheckStepFactory() + app := newTestPolicyApp("policy", newMockPolicyEngine()) + + step, err := factory("my-authz-step", map[string]any{ + "policy_engine": "policy", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + if step.Name() != "my-authz-step" { + t.Errorf("expected name 'my-authz-step', got %q", step.Name()) + } +} + +func TestAuthzCheckStep_DefaultSubjectField(t *testing.T) { + factory := NewAuthzCheckStepFactory() + app := newTestPolicyApp("policy", newMockPolicyEngine()) + + step, err := factory("authz", map[string]any{ + "policy_engine": "policy", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + s := step.(*AuthzCheckStep) + if s.subjectField != "subject" { + t.Errorf("expected default subject_field='subject', got %q", s.subjectField) + } +} + +func TestAuthzCheckStep_CustomSubjectField(t *testing.T) { + factory := NewAuthzCheckStepFactory() + app := newTestPolicyApp("policy", newMockPolicyEngine()) + + step, err := factory("authz", map[string]any{ + "policy_engine": "policy", + "subject_field": "user_id", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + s := step.(*AuthzCheckStep) + if s.subjectField != "user_id" { + t.Errorf("expected subject_field='user_id', got %q", s.subjectField) + } +} + +// capturingPolicyEngine records the last input passed to Evaluate. +type capturingPolicyEngine struct { + lastInput map[string]any +} + +func (e *capturingPolicyEngine) Evaluate(_ context.Context, input map[string]any) (*PolicyDecision, error) { + e.lastInput = input + return &PolicyDecision{Allowed: true, Reasons: []string{"allow"}, Metadata: nil}, nil +} +func (e *capturingPolicyEngine) LoadPolicy(_, _ string) error { return nil } +func (e *capturingPolicyEngine) ListPolicies() []PolicyInfo { return nil } + +func TestAuthzCheckStep_SubjectFieldMappedToInput(t *testing.T) { + eng := &capturingPolicyEngine{} + app := newTestPolicyApp("policy", eng) + + factory := NewAuthzCheckStepFactory() + step, err := factory("authz", map[string]any{ + "policy_engine": "policy", + "subject_field": "auth_user_id", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(map[string]any{ + "auth_user_id": "user-99", + "other_field": "value", + }, nil) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Stop { + t.Error("expected Stop=false when policy allows") + } + // The input passed to the engine should have "subject" mapped from auth_user_id. + if eng.lastInput["subject"] != "user-99" { + t.Errorf("expected input[subject]=user-99, got %v", eng.lastInput["subject"]) + } + // Original field should still be present. + if eng.lastInput["auth_user_id"] != "user-99" { + t.Errorf("expected input[auth_user_id]=user-99, got %v", eng.lastInput["auth_user_id"]) + } +} + +func TestAuthzCheckStep_SubjectFieldMappingDoesNotMutatePipelineContext(t *testing.T) { + eng := &capturingPolicyEngine{} + app := newTestPolicyApp("policy", eng) + + factory := NewAuthzCheckStepFactory() + step, err := factory("authz", map[string]any{ + "policy_engine": "policy", + "subject_field": "auth_user_id", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(map[string]any{ + "auth_user_id": "user-99", + }, nil) + + _, err = step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // pc.Current should NOT have had "subject" injected. + if _, ok := pc.Current["subject"]; ok { + t.Error("expected pc.Current to not be mutated with 'subject' key") + } +} diff --git a/plugins/pipelinesteps/plugin.go b/plugins/pipelinesteps/plugin.go index 1eec412c..4e36c8a7 100644 --- a/plugins/pipelinesteps/plugin.go +++ b/plugins/pipelinesteps/plugin.go @@ -4,7 +4,7 @@ // raw_response, static_file, validate_path_param, validate_pagination, validate_request_body, // foreach, webhook_verify, base64_decode, ui_scaffold, ui_scaffold_analyze, // dlq_send, dlq_replay, retry_with_backoff, circuit_breaker (wrapping), -// s3_upload, auth_validate, token_revoke, sandbox_exec. +// s3_upload, auth_validate, authz_check, token_revoke, sandbox_exec. // It also provides the PipelineWorkflowHandler for composable pipelines. package pipelinesteps @@ -89,6 +89,7 @@ func New() *Plugin { "step.resilient_circuit_breaker", "step.s3_upload", "step.auth_validate", + "step.authz_check", "step.token_revoke", "step.field_reencrypt", "step.sandbox_exec", @@ -163,6 +164,7 @@ func (p *Plugin) StepFactories() map[string]plugin.StepFactory { })), "step.s3_upload": wrapStepFactory(module.NewS3UploadStepFactory()), "step.auth_validate": wrapStepFactory(module.NewAuthValidateStepFactory()), + "step.authz_check": wrapStepFactory(module.NewAuthzCheckStepFactory()), "step.token_revoke": wrapStepFactory(module.NewTokenRevokeStepFactory()), "step.field_reencrypt": wrapStepFactory(module.NewFieldReencryptStepFactory()), "step.sandbox_exec": wrapStepFactory(module.NewSandboxExecStepFactory()), diff --git a/plugins/pipelinesteps/plugin_test.go b/plugins/pipelinesteps/plugin_test.go index 30693340..eb0a64de 100644 --- a/plugins/pipelinesteps/plugin_test.go +++ b/plugins/pipelinesteps/plugin_test.go @@ -66,6 +66,7 @@ func TestStepFactories(t *testing.T) { "step.resilient_circuit_breaker", "step.s3_upload", "step.auth_validate", + "step.authz_check", "step.token_revoke", "step.base64_decode", "step.field_reencrypt",