diff --git a/module/openapi.go b/module/openapi.go index 73c8fb81..aa941eef 100644 --- a/module/openapi.go +++ b/module/openapi.go @@ -11,6 +11,7 @@ import ( "log/slog" "math" "net/http" + "net/url" "os" "regexp" "sort" @@ -413,7 +414,7 @@ func (h *openAPIRouteHandler) validate(r *http.Request) []string { continue } if val != "" && p.Schema != nil { - if schemaErrs := validateScalarValue(val, p.Name, p.Schema); len(schemaErrs) > 0 { + if schemaErrs := validateScalarValue(val, p.Name, "parameter", p.Schema); len(schemaErrs) > 0 { errs = append(errs, schemaErrs...) } } @@ -473,11 +474,20 @@ func (h *openAPIRouteHandler) validate(r *http.Request) []string { if h.op.RequestBody.Required && len(bodyBytes) == 0 { errs = append(errs, "request body is required but missing") } else if mediaType != nil && mediaType.Schema != nil && len(bodyBytes) > 0 { - var bodyData any - if jsonErr := json.Unmarshal(bodyBytes, &bodyData); jsonErr != nil { - errs = append(errs, fmt.Sprintf("request body contains invalid JSON: %v", jsonErr)) - } else if bodyErrs := validateJSONValue(bodyData, "body", mediaType.Schema); len(bodyErrs) > 0 { - errs = append(errs, bodyErrs...) + if ct == "application/x-www-form-urlencoded" { + formValues, parseErr := url.ParseQuery(string(bodyBytes)) + if parseErr != nil { + errs = append(errs, fmt.Sprintf("request body contains invalid form data: %v", parseErr)) + } else if formErrs := validateFormBody(formValues, mediaType.Schema); len(formErrs) > 0 { + errs = append(errs, formErrs...) + } + } else { + var bodyData any + if jsonErr := json.Unmarshal(bodyBytes, &bodyData); jsonErr != nil { + errs = append(errs, fmt.Sprintf("request body contains invalid JSON: %v", jsonErr)) + } else if bodyErrs := validateJSONValue(bodyData, "body", mediaType.Schema); len(bodyErrs) > 0 { + errs = append(errs, bodyErrs...) + } } } } @@ -608,42 +618,43 @@ func validateStringConstraints(s, name, kind string, schema *openAPISchema) []st } // validateScalarValue validates a string value against a schema (type/format/enum checks). -func validateScalarValue(val, name string, schema *openAPISchema) []string { +// The kind parameter ("parameter" or "field") is used in error messages. +func validateScalarValue(val, name, kind string, schema *openAPISchema) []string { var errs []string switch schema.Type { case "integer": n, err := strconv.ParseInt(val, 10, 64) if err != nil { - errs = append(errs, fmt.Sprintf("parameter %q must be an integer, got %q", name, val)) + errs = append(errs, fmt.Sprintf("%s %q must be an integer, got %q", kind, name, val)) return errs } if schema.Minimum != nil && float64(n) < *schema.Minimum { - errs = append(errs, fmt.Sprintf("parameter %q must be >= %v", name, *schema.Minimum)) + errs = append(errs, fmt.Sprintf("%s %q must be >= %v", kind, name, *schema.Minimum)) } if schema.Maximum != nil && float64(n) > *schema.Maximum { - errs = append(errs, fmt.Sprintf("parameter %q must be <= %v", name, *schema.Maximum)) + errs = append(errs, fmt.Sprintf("%s %q must be <= %v", kind, name, *schema.Maximum)) } case "number": f, err := strconv.ParseFloat(val, 64) if err != nil { - errs = append(errs, fmt.Sprintf("parameter %q must be a number, got %q", name, val)) + errs = append(errs, fmt.Sprintf("%s %q must be a number, got %q", kind, name, val)) return errs } if schema.Minimum != nil && f < *schema.Minimum { - errs = append(errs, fmt.Sprintf("parameter %q must be >= %v", name, *schema.Minimum)) + errs = append(errs, fmt.Sprintf("%s %q must be >= %v", kind, name, *schema.Minimum)) } if schema.Maximum != nil && f > *schema.Maximum { - errs = append(errs, fmt.Sprintf("parameter %q must be <= %v", name, *schema.Maximum)) + errs = append(errs, fmt.Sprintf("%s %q must be <= %v", kind, name, *schema.Maximum)) } case "boolean": if val != "true" && val != "false" { - errs = append(errs, fmt.Sprintf("parameter %q must be 'true' or 'false', got %q", name, val)) + errs = append(errs, fmt.Sprintf("%s %q must be 'true' or 'false', got %q", kind, name, val)) } case "string": - errs = append(errs, validateStringConstraints(val, name, "parameter", schema)...) + errs = append(errs, validateStringConstraints(val, name, kind, schema)...) } - // Enum validation: query/path parameters are always strings, so compare the - // string form of each enum value against the string parameter value. + // Enum validation: scalar values are always strings, so compare the + // string form of each enum value against the string value. if len(schema.Enum) > 0 { found := false for _, e := range schema.Enum { @@ -656,7 +667,7 @@ func validateScalarValue(val, name string, schema *openAPISchema) []string { } } if !found { - errs = append(errs, fmt.Sprintf("parameter %q must be one of %v", name, schema.Enum)) + errs = append(errs, fmt.Sprintf("%s %q must be one of %v", kind, name, schema.Enum)) } } return errs @@ -691,6 +702,34 @@ func validateJSONBody(body any, schema *openAPISchema) []string { return errs } +// validateFormBody validates url.Values (from application/x-www-form-urlencoded) against an object schema. +// Form values are always strings, so each field is validated using validateScalarValue. +func validateFormBody(values url.Values, schema *openAPISchema) []string { + var errs []string + // Check required fields + for _, req := range schema.Required { + if _, present := values[req]; !present { + errs = append(errs, fmt.Sprintf("request body: required field %q is missing", req)) + } + } + // Validate individual properties: check presence (not empty-string) so that + // present-but-empty fields are still validated against constraints like minLength/pattern/enum. + for field, propSchema := range schema.Properties { + vals, present := values[field] + if !present { + continue + } + var val string + if len(vals) > 0 { + val = vals[0] + } + if fieldErrs := validateScalarValue(val, field, "field", propSchema); len(fieldErrs) > 0 { + errs = append(errs, fieldErrs...) + } + } + return errs +} + // validateJSONValue validates a decoded JSON value against a schema. func validateJSONValue(val any, name string, schema *openAPISchema) []string { var errs []string diff --git a/module/openapi_test.go b/module/openapi_test.go index ec9d9eb3..481a5ccc 100644 --- a/module/openapi_test.go +++ b/module/openapi_test.go @@ -564,6 +564,103 @@ func TestOpenAPIModule_RequestValidation_Body(t *testing.T) { }) } +const webhookFormYAML = ` +openapi: "3.0.0" +info: + title: Webhook API + version: "1.0.0" +paths: + /webhook: + post: + operationId: receiveWebhook + requestBody: + required: true + content: + application/x-www-form-urlencoded: + schema: + type: object + required: + - Body + properties: + Body: + type: string + minLength: 1 + From: + type: string + responses: + "200": + description: OK +` + +func TestOpenAPIModule_RequestValidation_FormEncoded(t *testing.T) { + specPath := writeTempSpec(t, ".yaml", webhookFormYAML) + + mod := NewOpenAPIModule("webhook", OpenAPIConfig{ + SpecFile: specPath, + Validation: OpenAPIValidationConfig{Request: true}, + }) + if err := mod.Init(nil); err != nil { + t.Fatalf("Init: %v", err) + } + + router := &testRouter{} + mod.RegisterRoutes(router) + + h := router.findHandler("POST", "/webhook") + if h == nil { + t.Fatal("POST /webhook handler not found") + } + + t.Run("valid form body", func(t *testing.T) { + body := "Body=Hello+World&From=%2B15551234567" + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/webhook", strings.NewReader(body)) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + h.Handle(w, r) + if w.Code != http.StatusNotImplemented { + t.Errorf("expected 501 stub (validation OK), got %d: %s", w.Code, w.Body.String()) + } + }) + + t.Run("missing required field", func(t *testing.T) { + body := "From=%2B15551234567" // missing required 'Body' + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/webhook", strings.NewReader(body)) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + h.Handle(w, r) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 validation error for missing required field, got %d: %s", w.Code, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "Body") { + t.Errorf("expected error mentioning 'Body', got: %s", w.Body.String()) + } + }) + + t.Run("empty body when required", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/webhook", strings.NewReader("")) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + h.Handle(w, r) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for empty required body, got %d: %s", w.Code, w.Body.String()) + } + }) + + t.Run("present-but-empty field violates minLength", func(t *testing.T) { + body := "Body=" // Body key present but empty value, violates minLength:1 + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/webhook", strings.NewReader(body)) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + h.Handle(w, r) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for empty field with minLength, got %d: %s", w.Code, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "minLength") { + t.Errorf("expected minLength error, got: %s", w.Body.String()) + } + }) +} + func TestOpenAPIModule_MaxBodySize(t *testing.T) { specPath := writeTempSpec(t, ".yaml", petstoreYAML) @@ -707,7 +804,7 @@ func TestValidateScalarValue(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - errs := validateScalarValue(tt.val, "param", tt.schema) + errs := validateScalarValue(tt.val, "param", "parameter", tt.schema) if tt.wantErr && len(errs) == 0 { t.Error("expected validation error, got none") } @@ -741,7 +838,7 @@ func TestHTMLEscape(t *testing.T) { func TestValidateScalarValue_Pattern(t *testing.T) { t.Run("valid pattern match", func(t *testing.T) { schema := &openAPISchema{Type: "string", Pattern: "^foo[0-9]+$"} - errs := validateScalarValue("foo123", "param", schema) + errs := validateScalarValue("foo123", "param", "parameter", schema) if len(errs) > 0 { t.Errorf("expected no errors, got %v", errs) } @@ -749,7 +846,7 @@ func TestValidateScalarValue_Pattern(t *testing.T) { t.Run("pattern mismatch", func(t *testing.T) { schema := &openAPISchema{Type: "string", Pattern: "^foo[0-9]+$"} - errs := validateScalarValue("bar", "param", schema) + errs := validateScalarValue("bar", "param", "parameter", schema) if len(errs) == 0 { t.Error("expected validation error for non-matching pattern, got none") } @@ -757,7 +854,7 @@ func TestValidateScalarValue_Pattern(t *testing.T) { t.Run("invalid regex pattern returns error", func(t *testing.T) { schema := &openAPISchema{Type: "string", Pattern: "["} - errs := validateScalarValue("anything", "param", schema) + errs := validateScalarValue("anything", "param", "parameter", schema) if len(errs) == 0 { t.Error("expected validation error for invalid regex pattern, got none") }