Skip to content

Commit ef2128d

Browse files
Copilotintel352
andcommitted
fix: address PR review comments for OpenAPI response validation
- responseCapturingWriter uses own header map to prevent header leakage when validation fails and a 500 is returned instead of the pipeline response - validateJSONBody accepts bodyLabel parameter so error messages correctly say 'request body' or 'response body' depending on context (was always hardcoded to 'request body') - Implement additionalProperties validation in validateJSONBody — keys not in Properties are now validated against AdditionalProperties schema when defined - Rename TestOpenAPIModule_ResponseValidation_DefaultFallback_Valid to TestOpenAPIModule_ResponseValidation_DefaultFallback_InvalidFallback (test asserts 500 for schema mismatch, not a valid response) - Add nil guards for router.findHandler() in all 4 array constraint and warn-action subtests Co-authored-by: intel352 <77607+intel352@users.noreply.github.com>
1 parent 9882960 commit ef2128d

2 files changed

Lines changed: 47 additions & 11 deletions

File tree

module/openapi.go

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ func (h *openAPIRouteHandler) validate(r *http.Request) []string {
569569
var bodyData any
570570
if jsonErr := json.Unmarshal(bodyBytes, &bodyData); jsonErr != nil {
571571
errs = append(errs, fmt.Sprintf("request body contains invalid JSON: %v", jsonErr))
572-
} else if bodyErrs := validateJSONValue(bodyData, "body", mediaType.Schema); len(bodyErrs) > 0 {
572+
} else if bodyErrs := validateJSONValue(bodyData, "request body", mediaType.Schema); len(bodyErrs) > 0 {
573573
errs = append(errs, bodyErrs...)
574574
}
575575
}
@@ -584,8 +584,11 @@ func (h *openAPIRouteHandler) validate(r *http.Request) []string {
584584

585585
// responseCapturingWriter buffers the response body, status code, and headers
586586
// so we can validate them against the OpenAPI spec before sending to the client.
587+
// It uses its own header map to prevent leaked headers reaching the client when
588+
// validation fails and a different (500) response needs to be sent.
587589
type responseCapturingWriter struct {
588590
underlying http.ResponseWriter
591+
headers http.Header // own header map; copied to underlying only on flush
589592
body bytes.Buffer
590593
statusCode int
591594
headerSent bool
@@ -594,14 +597,15 @@ type responseCapturingWriter struct {
594597
func newResponseCapturingWriter(w http.ResponseWriter) *responseCapturingWriter {
595598
return &responseCapturingWriter{
596599
underlying: w,
600+
headers: make(http.Header),
597601
statusCode: http.StatusOK,
598602
}
599603
}
600604

601-
// Header returns the underlying response writer's header map so that callers
602-
// can set headers which will be flushed later.
605+
// Header returns this writer's own header map so that callers can set headers
606+
// which are only forwarded to the underlying writer when flush() is called.
603607
func (c *responseCapturingWriter) Header() http.Header {
604-
return c.underlying.Header()
608+
return c.headers
605609
}
606610

607611
// Write captures the response body into an internal buffer.
@@ -614,12 +618,18 @@ func (c *responseCapturingWriter) WriteHeader(code int) {
614618
c.statusCode = code
615619
}
616620

617-
// flush sends the buffered status code, headers, and body to the underlying writer.
621+
// flush copies captured headers and sends the buffered status code and body to the underlying writer.
618622
func (c *responseCapturingWriter) flush() {
619623
if c.headerSent {
620624
return
621625
}
622626
c.headerSent = true
627+
// Copy captured headers to the underlying writer before sending the status code.
628+
for k, vals := range c.headers {
629+
for _, v := range vals {
630+
c.underlying.Header().Add(k, v)
631+
}
632+
}
623633
c.underlying.WriteHeader(c.statusCode)
624634
_, _ = c.underlying.Write(c.body.Bytes()) //nolint:gosec // G705: body is pipeline output, written back to same response
625635
}
@@ -694,7 +704,7 @@ func (h *openAPIRouteHandler) validateResponse(statusCode int, headers http.Head
694704
return errs
695705
}
696706

697-
if bodyErrs := validateJSONValue(bodyData, "response", mediaType.Schema); len(bodyErrs) > 0 {
707+
if bodyErrs := validateJSONValue(bodyData, "response body", mediaType.Schema); len(bodyErrs) > 0 {
698708
errs = append(errs, bodyErrs...)
699709
}
700710

@@ -942,19 +952,21 @@ func validateScalarValue(val, name, kind string, schema *openAPISchema) []string
942952
}
943953

944954
// validateJSONBody validates a decoded JSON body against an object schema.
945-
func validateJSONBody(body any, schema *openAPISchema) []string {
955+
// The bodyLabel parameter (e.g. "request body" or "response body") is used in
956+
// error messages to distinguish validation context.
957+
func validateJSONBody(body any, schema *openAPISchema, bodyLabel string) []string {
946958
var errs []string
947959
obj, ok := body.(map[string]any)
948960
if !ok {
949961
if schema.Type == "object" {
950-
return []string{"request body must be a JSON object"}
962+
return []string{bodyLabel + " must be a JSON object"}
951963
}
952964
return nil
953965
}
954966
// Check required fields
955967
for _, req := range schema.Required {
956968
if _, present := obj[req]; !present {
957-
errs = append(errs, fmt.Sprintf("request body: required field %q is missing", req))
969+
errs = append(errs, fmt.Sprintf("%s: required field %q is missing", bodyLabel, req))
958970
}
959971
}
960972
// Validate individual properties
@@ -967,6 +979,18 @@ func validateJSONBody(body any, schema *openAPISchema) []string {
967979
errs = append(errs, fieldErrs...)
968980
}
969981
}
982+
// Validate additionalProperties: keys not declared in Properties are checked
983+
// against the additionalProperties schema when it is specified.
984+
if schema.AdditionalProperties != nil {
985+
for key, val := range obj {
986+
if _, defined := schema.Properties[key]; defined {
987+
continue
988+
}
989+
if fieldErrs := validateJSONValue(val, key, schema.AdditionalProperties); len(fieldErrs) > 0 {
990+
errs = append(errs, fieldErrs...)
991+
}
992+
}
993+
}
970994
return errs
971995
}
972996

@@ -1038,7 +1062,7 @@ func validateJSONValue(val any, name string, schema *openAPISchema) []string {
10381062
errs = append(errs, fmt.Sprintf("field %q must be a boolean, got %T", name, val))
10391063
}
10401064
case "object":
1041-
if subErrs := validateJSONBody(val, schema); len(subErrs) > 0 {
1065+
if subErrs := validateJSONBody(val, schema, name); len(subErrs) > 0 {
10421066
errs = append(errs, subErrs...)
10431067
}
10441068
case "array":

module/openapi_test.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1766,7 +1766,7 @@ func TestOpenAPIModule_ResponseValidation_InvalidResponse_WarnAction(t *testing.
17661766
}
17671767
}
17681768

1769-
func TestOpenAPIModule_ResponseValidation_DefaultFallback_Valid(t *testing.T) {
1769+
func TestOpenAPIModule_ResponseValidation_DefaultFallback_InvalidFallback(t *testing.T) {
17701770
specPath := writeTempSpec(t, ".yaml", responseValidationYAML)
17711771

17721772
mod := NewOpenAPIModule("resp-api", OpenAPIConfig{
@@ -2246,6 +2246,9 @@ paths:
22462246
router := &testRouter{}
22472247
mod.RegisterRoutes(router)
22482248
h := router.findHandler("GET", "/api/items")
2249+
if h == nil {
2250+
t.Fatal("route GET /api/items not registered")
2251+
}
22492252
w := httptest.NewRecorder()
22502253
r := httptest.NewRequest(http.MethodGet, "/api/items", nil)
22512254
h.Handle(w, r)
@@ -2286,6 +2289,9 @@ paths:
22862289
router := &testRouter{}
22872290
mod.RegisterRoutes(router)
22882291
h := router.findHandler("GET", "/api/items")
2292+
if h == nil {
2293+
t.Fatal("route GET /api/items not registered")
2294+
}
22892295
w := httptest.NewRecorder()
22902296
r := httptest.NewRequest(http.MethodGet, "/api/items", nil)
22912297
h.Handle(w, r)
@@ -2326,6 +2332,9 @@ paths:
23262332
router := &testRouter{}
23272333
mod.RegisterRoutes(router)
23282334
h := router.findHandler("GET", "/api/items")
2335+
if h == nil {
2336+
t.Fatal("route GET /api/items not registered")
2337+
}
23292338
w := httptest.NewRecorder()
23302339
r := httptest.NewRequest(http.MethodGet, "/api/items", nil)
23312340
h.Handle(w, r)
@@ -2373,6 +2382,9 @@ func TestOpenAPIModule_ResponseValidation_DefaultAction_IsWarn(t *testing.T) {
23732382
mod.RegisterRoutes(router)
23742383

23752384
h := router.findHandler("GET", "/api/pets")
2385+
if h == nil {
2386+
t.Fatal("handler for GET /api/pets not found")
2387+
}
23762388
w := httptest.NewRecorder()
23772389
r := httptest.NewRequest(http.MethodGet, "/api/pets", nil)
23782390
h.Handle(w, r)

0 commit comments

Comments
 (0)