Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 57 additions & 18 deletions module/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"log/slog"
"math"
"net/http"
"net/url"
"os"
"regexp"
"sort"
Expand Down Expand Up @@ -413,7 +414,7 @@ func (h *openAPIRouteHandler) validate(r *http.Request) []string {
continue
}
if val != "" && p.Schema != nil {
if schemaErrs := validateScalarValue(val, p.Name, p.Schema); len(schemaErrs) > 0 {
if schemaErrs := validateScalarValue(val, p.Name, "parameter", p.Schema); len(schemaErrs) > 0 {
errs = append(errs, schemaErrs...)
}
}
Expand Down Expand Up @@ -473,11 +474,20 @@ func (h *openAPIRouteHandler) validate(r *http.Request) []string {
if h.op.RequestBody.Required && len(bodyBytes) == 0 {
errs = append(errs, "request body is required but missing")
} else if mediaType != nil && mediaType.Schema != nil && len(bodyBytes) > 0 {
var bodyData any
if jsonErr := json.Unmarshal(bodyBytes, &bodyData); jsonErr != nil {
errs = append(errs, fmt.Sprintf("request body contains invalid JSON: %v", jsonErr))
} else if bodyErrs := validateJSONValue(bodyData, "body", mediaType.Schema); len(bodyErrs) > 0 {
errs = append(errs, bodyErrs...)
if ct == "application/x-www-form-urlencoded" {
formValues, parseErr := url.ParseQuery(string(bodyBytes))
if parseErr != nil {
errs = append(errs, fmt.Sprintf("request body contains invalid form data: %v", parseErr))
} else if formErrs := validateFormBody(formValues, mediaType.Schema); len(formErrs) > 0 {
errs = append(errs, formErrs...)
}
} else {
var bodyData any
if jsonErr := json.Unmarshal(bodyBytes, &bodyData); jsonErr != nil {
errs = append(errs, fmt.Sprintf("request body contains invalid JSON: %v", jsonErr))
} else if bodyErrs := validateJSONValue(bodyData, "body", mediaType.Schema); len(bodyErrs) > 0 {
errs = append(errs, bodyErrs...)
}
}
}
}
Expand Down Expand Up @@ -608,42 +618,43 @@ func validateStringConstraints(s, name, kind string, schema *openAPISchema) []st
}

// validateScalarValue validates a string value against a schema (type/format/enum checks).
func validateScalarValue(val, name string, schema *openAPISchema) []string {
// The kind parameter ("parameter" or "field") is used in error messages.
func validateScalarValue(val, name, kind string, schema *openAPISchema) []string {
var errs []string
switch schema.Type {
case "integer":
n, err := strconv.ParseInt(val, 10, 64)
if err != nil {
errs = append(errs, fmt.Sprintf("parameter %q must be an integer, got %q", name, val))
errs = append(errs, fmt.Sprintf("%s %q must be an integer, got %q", kind, name, val))
return errs
}
if schema.Minimum != nil && float64(n) < *schema.Minimum {
errs = append(errs, fmt.Sprintf("parameter %q must be >= %v", name, *schema.Minimum))
errs = append(errs, fmt.Sprintf("%s %q must be >= %v", kind, name, *schema.Minimum))
}
if schema.Maximum != nil && float64(n) > *schema.Maximum {
errs = append(errs, fmt.Sprintf("parameter %q must be <= %v", name, *schema.Maximum))
errs = append(errs, fmt.Sprintf("%s %q must be <= %v", kind, name, *schema.Maximum))
}
case "number":
f, err := strconv.ParseFloat(val, 64)
if err != nil {
errs = append(errs, fmt.Sprintf("parameter %q must be a number, got %q", name, val))
errs = append(errs, fmt.Sprintf("%s %q must be a number, got %q", kind, name, val))
return errs
}
if schema.Minimum != nil && f < *schema.Minimum {
errs = append(errs, fmt.Sprintf("parameter %q must be >= %v", name, *schema.Minimum))
errs = append(errs, fmt.Sprintf("%s %q must be >= %v", kind, name, *schema.Minimum))
}
if schema.Maximum != nil && f > *schema.Maximum {
errs = append(errs, fmt.Sprintf("parameter %q must be <= %v", name, *schema.Maximum))
errs = append(errs, fmt.Sprintf("%s %q must be <= %v", kind, name, *schema.Maximum))
}
case "boolean":
if val != "true" && val != "false" {
errs = append(errs, fmt.Sprintf("parameter %q must be 'true' or 'false', got %q", name, val))
errs = append(errs, fmt.Sprintf("%s %q must be 'true' or 'false', got %q", kind, name, val))
}
case "string":
errs = append(errs, validateStringConstraints(val, name, "parameter", schema)...)
errs = append(errs, validateStringConstraints(val, name, kind, schema)...)
}
// Enum validation: query/path parameters are always strings, so compare the
// string form of each enum value against the string parameter value.
// Enum validation: scalar values are always strings, so compare the
// string form of each enum value against the string value.
if len(schema.Enum) > 0 {
found := false
for _, e := range schema.Enum {
Expand All @@ -656,7 +667,7 @@ func validateScalarValue(val, name string, schema *openAPISchema) []string {
}
}
if !found {
errs = append(errs, fmt.Sprintf("parameter %q must be one of %v", name, schema.Enum))
errs = append(errs, fmt.Sprintf("%s %q must be one of %v", kind, name, schema.Enum))
}
}
return errs
Expand Down Expand Up @@ -691,6 +702,34 @@ func validateJSONBody(body any, schema *openAPISchema) []string {
return errs
}

// validateFormBody validates url.Values (from application/x-www-form-urlencoded) against an object schema.
// Form values are always strings, so each field is validated using validateScalarValue.
func validateFormBody(values url.Values, schema *openAPISchema) []string {
var errs []string
// Check required fields
for _, req := range schema.Required {
if _, present := values[req]; !present {
errs = append(errs, fmt.Sprintf("request body: required field %q is missing", req))
}
}
// Validate individual properties: check presence (not empty-string) so that
// present-but-empty fields are still validated against constraints like minLength/pattern/enum.
for field, propSchema := range schema.Properties {
vals, present := values[field]
if !present {
continue
}
var val string
if len(vals) > 0 {
val = vals[0]
}
if fieldErrs := validateScalarValue(val, field, "field", propSchema); len(fieldErrs) > 0 {
errs = append(errs, fieldErrs...)
}
}
return errs
}

// validateJSONValue validates a decoded JSON value against a schema.
func validateJSONValue(val any, name string, schema *openAPISchema) []string {
var errs []string
Expand Down
105 changes: 101 additions & 4 deletions module/openapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,103 @@ func TestOpenAPIModule_RequestValidation_Body(t *testing.T) {
})
}

const webhookFormYAML = `
openapi: "3.0.0"
info:
title: Webhook API
version: "1.0.0"
paths:
/webhook:
post:
operationId: receiveWebhook
requestBody:
required: true
content:
application/x-www-form-urlencoded:
schema:
type: object
required:
- Body
properties:
Body:
type: string
minLength: 1
From:
type: string
responses:
"200":
description: OK
`

func TestOpenAPIModule_RequestValidation_FormEncoded(t *testing.T) {
specPath := writeTempSpec(t, ".yaml", webhookFormYAML)

mod := NewOpenAPIModule("webhook", OpenAPIConfig{
SpecFile: specPath,
Validation: OpenAPIValidationConfig{Request: true},
})
if err := mod.Init(nil); err != nil {
t.Fatalf("Init: %v", err)
}

router := &testRouter{}
mod.RegisterRoutes(router)

h := router.findHandler("POST", "/webhook")
if h == nil {
t.Fatal("POST /webhook handler not found")
}

t.Run("valid form body", func(t *testing.T) {
body := "Body=Hello+World&From=%2B15551234567"
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/webhook", strings.NewReader(body))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
h.Handle(w, r)
if w.Code != http.StatusNotImplemented {
t.Errorf("expected 501 stub (validation OK), got %d: %s", w.Code, w.Body.String())
}
})

t.Run("missing required field", func(t *testing.T) {
body := "From=%2B15551234567" // missing required 'Body'
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/webhook", strings.NewReader(body))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
h.Handle(w, r)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 validation error for missing required field, got %d: %s", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), "Body") {
t.Errorf("expected error mentioning 'Body', got: %s", w.Body.String())
}
})

t.Run("empty body when required", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/webhook", strings.NewReader(""))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
h.Handle(w, r)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 for empty required body, got %d: %s", w.Code, w.Body.String())
}
})

t.Run("present-but-empty field violates minLength", func(t *testing.T) {
body := "Body=" // Body key present but empty value, violates minLength:1
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/webhook", strings.NewReader(body))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
h.Handle(w, r)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 for empty field with minLength, got %d: %s", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), "minLength") {
t.Errorf("expected minLength error, got: %s", w.Body.String())
}
})
}

func TestOpenAPIModule_MaxBodySize(t *testing.T) {
specPath := writeTempSpec(t, ".yaml", petstoreYAML)

Expand Down Expand Up @@ -707,7 +804,7 @@ func TestValidateScalarValue(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
errs := validateScalarValue(tt.val, "param", tt.schema)
errs := validateScalarValue(tt.val, "param", "parameter", tt.schema)
if tt.wantErr && len(errs) == 0 {
t.Error("expected validation error, got none")
}
Expand Down Expand Up @@ -741,23 +838,23 @@ func TestHTMLEscape(t *testing.T) {
func TestValidateScalarValue_Pattern(t *testing.T) {
t.Run("valid pattern match", func(t *testing.T) {
schema := &openAPISchema{Type: "string", Pattern: "^foo[0-9]+$"}
errs := validateScalarValue("foo123", "param", schema)
errs := validateScalarValue("foo123", "param", "parameter", schema)
if len(errs) > 0 {
t.Errorf("expected no errors, got %v", errs)
}
})

t.Run("pattern mismatch", func(t *testing.T) {
schema := &openAPISchema{Type: "string", Pattern: "^foo[0-9]+$"}
errs := validateScalarValue("bar", "param", schema)
errs := validateScalarValue("bar", "param", "parameter", schema)
if len(errs) == 0 {
t.Error("expected validation error for non-matching pattern, got none")
}
})

t.Run("invalid regex pattern returns error", func(t *testing.T) {
schema := &openAPISchema{Type: "string", Pattern: "["}
errs := validateScalarValue("anything", "param", schema)
errs := validateScalarValue("anything", "param", "parameter", schema)
if len(errs) == 0 {
t.Error("expected validation error for invalid regex pattern, got none")
}
Expand Down
Loading