Skip to content

Commit 6f5d7ee

Browse files
Copilotintel352
andauthored
fix(openapi): address review feedback — correctness, security, and performance improvements (#146)
* Initial plan * fix: apply all review feedback to OpenAPI module Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: intel352 <77607+intel352@users.noreply.github.com>
1 parent 037a746 commit 6f5d7ee

4 files changed

Lines changed: 280 additions & 83 deletions

File tree

example/openapi-petstore.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ modules:
2222
dependsOn:
2323
- petstore-router
2424
config:
25+
# NOTE: spec_file is resolved relative to this config file's directory
26+
# via the _config_dir mechanism in config.ResolvePathInConfig.
2527
spec_file: specs/petstore.yaml
2628
base_path: /api/v1
2729
router: petstore-router

module/openapi.go

Lines changed: 137 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"html"
78
"io"
89
"log/slog"
10+
"math"
911
"net/http"
1012
"os"
1113
"regexp"
@@ -109,7 +111,8 @@ type OpenAPIModule struct {
109111
name string
110112
cfg OpenAPIConfig
111113
spec *openAPISpec
112-
specBytes []byte // raw spec bytes for serving
114+
specBytes []byte // raw spec bytes for serving (original file content)
115+
specJSON []byte // cached JSON-serialised spec for /openapi.json endpoint
113116
routerName string
114117
logger *slog.Logger
115118
}
@@ -219,11 +222,26 @@ func (m *OpenAPIModule) RegisterRoutes(router HTTPRouter) {
219222

220223
// Serve raw spec at /openapi.json and /openapi.yaml
221224
if len(m.specBytes) > 0 {
225+
// Cache the JSON representation once per registration.
226+
if m.specJSON == nil {
227+
specJSON, err := json.Marshal(m.spec)
228+
if err != nil {
229+
specJSON = m.specBytes // fallback to raw bytes
230+
}
231+
m.specJSON = specJSON
232+
}
233+
222234
specPathJSON := basePath + "/openapi.json"
223235
specPathYAML := basePath + "/openapi.yaml"
224-
specHandler := m.buildSpecHandler()
225-
router.AddRoute(http.MethodGet, specPathJSON, specHandler)
226-
router.AddRoute(http.MethodGet, specPathYAML, specHandler)
236+
237+
// JSON endpoint: serve re-serialised spec as JSON.
238+
router.AddRoute(http.MethodGet, specPathJSON, &openAPISpecHandler{specJSON: m.specJSON})
239+
240+
// YAML endpoint: serve the original spec bytes with a YAML content-type.
241+
// This preserves the source format; if the original file was YAML it is
242+
// served as YAML, and if it was JSON it is served as-is (JSON is valid YAML).
243+
router.AddRoute(http.MethodGet, specPathYAML, &openAPIRawSpecHandler{specBytes: m.specBytes, contentType: "application/yaml"})
244+
227245
m.logger.Debug("OpenAPI spec endpoint registered", "module", m.name, "paths", []string{specPathJSON, specPathYAML})
228246
}
229247

@@ -257,17 +275,6 @@ func (m *OpenAPIModule) buildRouteHandler(specPath, method string, op *openAPIOp
257275
}
258276
}
259277

260-
// buildSpecHandler serves the raw spec bytes as JSON (re-serialised from the
261-
// parsed spec) so consumers always get valid JSON regardless of whether the
262-
// original file was YAML.
263-
func (m *OpenAPIModule) buildSpecHandler() HTTPHandler {
264-
specJSON, err := json.Marshal(m.spec)
265-
if err != nil {
266-
specJSON = m.specBytes // fallback to raw bytes
267-
}
268-
return &openAPISpecHandler{specJSON: specJSON}
269-
}
270-
271278
// buildSwaggerUIHandler returns an inline Swagger UI page that loads the spec
272279
// from specURL. This avoids any asset bundling — a CDN-hosted swagger-ui is used.
273280
func (m *OpenAPIModule) buildSwaggerUIHandler(specURL string) HTTPHandler {
@@ -340,23 +347,36 @@ func (h *openAPIRouteHandler) validate(r *http.Request) []string {
340347
if mt, ok := h.op.RequestBody.Content[ct]; ok {
341348
mediaType = &mt
342349
} else if mt, ok := h.op.RequestBody.Content["application/json"]; ok && ct == "" {
343-
// Default to application/json when no Content-Type is sent
350+
// NOTE: Intentionally treat a missing Content-Type as application/json for request body
351+
// validation. Per HTTP semantics, an absent Content-Type would normally imply
352+
// application/octet-stream, but this engine is primarily used for JSON APIs and this
353+
// default simplifies client usage.
344354
mediaType = &mt
345355
}
346356

347-
if h.op.RequestBody.Required && r.ContentLength == 0 && r.Body == http.NoBody {
348-
errs = append(errs, "request body is required but missing")
349-
} else if mediaType != nil && mediaType.Schema != nil {
350-
bodyBytes, err := io.ReadAll(r.Body)
351-
if err == nil && len(bodyBytes) > 0 {
357+
// Read the body once so we can both check for presence (when required)
358+
// and validate against the schema. Restore it afterwards for downstream handlers.
359+
bodyBytes, readErr := io.ReadAll(r.Body)
360+
if readErr != nil {
361+
h.module.logger.Error("failed to read request body for validation",
362+
"module", h.module.name,
363+
"path", h.specPath,
364+
"error", readErr,
365+
)
366+
errs = append(errs, "failed to read request body")
367+
} else {
368+
// Always restore body for downstream handlers.
369+
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
370+
371+
if h.op.RequestBody.Required && len(bodyBytes) == 0 {
372+
errs = append(errs, "request body is required but missing")
373+
} else if mediaType != nil && mediaType.Schema != nil && len(bodyBytes) > 0 {
352374
var bodyData any
353375
if jsonErr := json.Unmarshal(bodyBytes, &bodyData); jsonErr == nil {
354376
if bodyErrs := validateJSONBody(bodyData, mediaType.Schema); len(bodyErrs) > 0 {
355377
errs = append(errs, bodyErrs...)
356378
}
357379
}
358-
// Restore body for downstream handlers
359-
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
360380
}
361381
}
362382
}
@@ -376,6 +396,21 @@ func (h *openAPISpecHandler) Handle(w http.ResponseWriter, _ *http.Request) {
376396
_, _ = w.Write(h.specJSON) //nolint:gosec // G705: spec JSON is loaded from a trusted config file, not user input
377397
}
378398

399+
// ---- openAPIRawSpecHandler ----
400+
401+
// openAPIRawSpecHandler serves the raw spec bytes with the given content-type.
402+
// Used for the /openapi.yaml endpoint to preserve the original source format.
403+
type openAPIRawSpecHandler struct {
404+
specBytes []byte
405+
contentType string
406+
}
407+
408+
func (h *openAPIRawSpecHandler) Handle(w http.ResponseWriter, _ *http.Request) {
409+
w.Header().Set("Content-Type", h.contentType)
410+
w.WriteHeader(http.StatusOK)
411+
_, _ = w.Write(h.specBytes) //nolint:gosec // G705: spec bytes are loaded from a trusted config file, not user input
412+
}
413+
379414
// ---- openAPISwaggerUIHandler ----
380415

381416
type openAPISwaggerUIHandler struct {
@@ -392,15 +427,18 @@ func (h *openAPISwaggerUIHandler) Handle(w http.ResponseWriter, _ *http.Request)
392427

393428
// parseOpenAPISpec parses a YAML or JSON byte slice into an openAPISpec.
394429
func parseOpenAPISpec(data []byte) (*openAPISpec, error) {
430+
if len(data) == 0 {
431+
return nil, fmt.Errorf("openapi spec data is empty")
432+
}
395433
var spec openAPISpec
396434
// Try YAML first (which also handles JSON since JSON is valid YAML)
397435
if err := yaml.Unmarshal(data, &spec); err != nil {
398436
return nil, fmt.Errorf("yaml parse: %w", err)
399437
}
400438
if spec.OpenAPI == "" {
401-
// May be JSON that yaml couldn't decode properly; try JSON directly
439+
// YAML parse succeeded but produced an empty OpenAPI field; try JSON directly.
402440
if err := json.Unmarshal(data, &spec); err != nil {
403-
return nil, fmt.Errorf("neither yaml nor json parse succeeded: %w", err)
441+
return nil, fmt.Errorf("yaml parse produced empty OpenAPI field; json parse also failed: %w", err)
404442
}
405443
}
406444
return &spec, nil
@@ -445,6 +483,28 @@ func extractParam(r *http.Request, p openAPIParameter) string {
445483
return ""
446484
}
447485

486+
// validateStringConstraints validates the string constraints (minLength, maxLength,
487+
// pattern) for a string value. The kind parameter ("parameter" or "field") is used
488+
// in error messages.
489+
func validateStringConstraints(s, name, kind string, schema *openAPISchema) []string {
490+
var errs []string
491+
if schema.MinLength != nil && len(s) < *schema.MinLength {
492+
errs = append(errs, fmt.Sprintf("%s %q must have minLength %d", kind, name, *schema.MinLength))
493+
}
494+
if schema.MaxLength != nil && len(s) > *schema.MaxLength {
495+
errs = append(errs, fmt.Sprintf("%s %q must have maxLength %d", kind, name, *schema.MaxLength))
496+
}
497+
if schema.Pattern != "" {
498+
re, err := regexp.Compile(schema.Pattern)
499+
if err != nil {
500+
errs = append(errs, fmt.Sprintf("%s %q has an invalid pattern %q: %v", kind, name, schema.Pattern, err))
501+
} else if !re.MatchString(s) {
502+
errs = append(errs, fmt.Sprintf("%s %q does not match pattern %q", kind, name, schema.Pattern))
503+
}
504+
}
505+
return errs
506+
}
507+
448508
// validateScalarValue validates a string value against a schema (type/format/enum checks).
449509
func validateScalarValue(val, name string, schema *openAPISchema) []string {
450510
var errs []string
@@ -478,22 +538,16 @@ func validateScalarValue(val, name string, schema *openAPISchema) []string {
478538
errs = append(errs, fmt.Sprintf("parameter %q must be 'true' or 'false', got %q", name, val))
479539
}
480540
case "string":
481-
if schema.MinLength != nil && len(val) < *schema.MinLength {
482-
errs = append(errs, fmt.Sprintf("parameter %q must have minLength %d", name, *schema.MinLength))
483-
}
484-
if schema.MaxLength != nil && len(val) > *schema.MaxLength {
485-
errs = append(errs, fmt.Sprintf("parameter %q must have maxLength %d", name, *schema.MaxLength))
486-
}
487-
if schema.Pattern != "" {
488-
if ok, _ := regexp.MatchString(schema.Pattern, val); !ok {
489-
errs = append(errs, fmt.Sprintf("parameter %q does not match pattern %q", name, schema.Pattern))
490-
}
491-
}
541+
errs = append(errs, validateStringConstraints(val, name, "parameter", schema)...)
492542
}
493-
// Enum validation
543+
// Enum validation: query/path parameters are always strings, so compare the
544+
// string form of each enum value against the string parameter value.
494545
if len(schema.Enum) > 0 {
495546
found := false
496547
for _, e := range schema.Enum {
548+
if e == nil {
549+
continue
550+
}
497551
if fmt.Sprintf("%v", e) == val {
498552
found = true
499553
break
@@ -542,23 +596,16 @@ func validateJSONValue(val any, name string, schema *openAPISchema) []string {
542596
case "string":
543597
s, ok := val.(string)
544598
if !ok {
545-
return []string{fmt.Sprintf("field %q must be a string", name)}
546-
}
547-
if schema.MinLength != nil && len(s) < *schema.MinLength {
548-
errs = append(errs, fmt.Sprintf("field %q must have minLength %d", name, *schema.MinLength))
549-
}
550-
if schema.MaxLength != nil && len(s) > *schema.MaxLength {
551-
errs = append(errs, fmt.Sprintf("field %q must have maxLength %d", name, *schema.MaxLength))
552-
}
553-
if schema.Pattern != "" {
554-
if ok2, _ := regexp.MatchString(schema.Pattern, s); !ok2 {
555-
errs = append(errs, fmt.Sprintf("field %q does not match pattern %q", name, schema.Pattern))
556-
}
599+
return []string{fmt.Sprintf("field %q must be a string, got %T", name, val)}
557600
}
601+
errs = append(errs, validateStringConstraints(s, name, "field", schema)...)
558602
case "integer":
559603
f, ok := val.(float64)
560604
if !ok {
561-
return []string{fmt.Sprintf("field %q must be an integer", name)}
605+
return []string{fmt.Sprintf("field %q must be an integer, got %T", name, val)}
606+
}
607+
if f != math.Trunc(f) {
608+
return []string{fmt.Sprintf("field %q must be an integer, got %v", name, f)}
562609
}
563610
if schema.Minimum != nil && f < *schema.Minimum {
564611
errs = append(errs, fmt.Sprintf("field %q must be >= %v", name, *schema.Minimum))
@@ -569,7 +616,7 @@ func validateJSONValue(val any, name string, schema *openAPISchema) []string {
569616
case "number":
570617
f, ok := val.(float64)
571618
if !ok {
572-
return []string{fmt.Sprintf("field %q must be a number", name)}
619+
return []string{fmt.Sprintf("field %q must be a number, got %T", name, val)}
573620
}
574621
if schema.Minimum != nil && f < *schema.Minimum {
575622
errs = append(errs, fmt.Sprintf("field %q must be >= %v", name, *schema.Minimum))
@@ -579,21 +626,41 @@ func validateJSONValue(val any, name string, schema *openAPISchema) []string {
579626
}
580627
case "boolean":
581628
if _, ok := val.(bool); !ok {
582-
errs = append(errs, fmt.Sprintf("field %q must be a boolean", name))
629+
errs = append(errs, fmt.Sprintf("field %q must be a boolean, got %T", name, val))
583630
}
584631
case "object":
585632
if subErrs := validateJSONBody(val, schema); len(subErrs) > 0 {
586633
errs = append(errs, subErrs...)
587634
}
588635
}
589-
// Enum validation
636+
// Enum validation: use type-aware comparison to prevent e.g. int 1 matching string "1".
590637
if len(schema.Enum) > 0 {
591638
found := false
592639
for _, e := range schema.Enum {
593-
if fmt.Sprintf("%v", e) == fmt.Sprintf("%v", val) {
640+
if e == nil {
641+
continue
642+
}
643+
// Direct equality covers string==string, bool==bool, float64==float64.
644+
if e == val {
594645
found = true
595646
break
596647
}
648+
// Handle numeric type mismatch: YAML decodes integers as int, but JSON
649+
// decodes all numbers as float64, so int(1) != float64(1) even though
650+
// they represent the same value.
651+
switch ev := e.(type) {
652+
case int:
653+
if fv, ok := val.(float64); ok && float64(ev) == fv {
654+
found = true
655+
}
656+
case int64:
657+
if fv, ok := val.(float64); ok && float64(ev) == fv {
658+
found = true
659+
}
660+
}
661+
if found {
662+
break
663+
}
597664
}
598665
if !found {
599666
errs = append(errs, fmt.Sprintf("field %q must be one of %v", name, schema.Enum))
@@ -604,21 +671,33 @@ func validateJSONValue(val any, name string, schema *openAPISchema) []string {
604671

605672
// swaggerUIHTML returns a minimal, self-contained Swagger UI HTML page that
606673
// loads the spec from specURL using the official Swagger UI CDN bundle.
674+
//
675+
// The base URL for the Swagger UI assets can be overridden via the
676+
// SWAGGER_UI_ASSETS_BASE_URL environment variable. If unset, it defaults to
677+
// "https://unpkg.com/swagger-ui-dist@5". This is useful for air-gapped
678+
// environments or when a local mirror is preferred.
607679
func swaggerUIHTML(title, specURL string) string {
608680
if title == "" {
609681
title = "API Documentation"
610682
}
683+
baseURL := os.Getenv("SWAGGER_UI_ASSETS_BASE_URL")
684+
if baseURL == "" {
685+
baseURL = "https://unpkg.com/swagger-ui-dist@5"
686+
}
687+
baseURL = strings.TrimRight(baseURL, "/")
688+
cssURL := baseURL + "/swagger-ui.css"
689+
jsURL := baseURL + "/swagger-ui-bundle.js"
611690
return `<!DOCTYPE html>
612691
<html lang="en">
613692
<head>
614693
<meta charset="UTF-8">
615694
<meta name="viewport" content="width=device-width, initial-scale=1.0">
616695
<title>` + htmlEscape(title) + `</title>
617-
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css">
696+
<link rel="stylesheet" href="` + htmlEscape(cssURL) + `">
618697
</head>
619698
<body>
620699
<div id="swagger-ui"></div>
621-
<script src="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
700+
<script src="` + htmlEscape(jsURL) + `"></script>
622701
<script>
623702
SwaggerUIBundle({
624703
url: "` + htmlEscape(specURL) + `",
@@ -632,11 +711,7 @@ func swaggerUIHTML(title, specURL string) string {
632711
}
633712

634713
// htmlEscape escapes a string for safe embedding in HTML attributes/text.
714+
// It delegates to the standard library html.EscapeString for robust escaping.
635715
func htmlEscape(s string) string {
636-
s = strings.ReplaceAll(s, "&", "&amp;")
637-
s = strings.ReplaceAll(s, "<", "&lt;")
638-
s = strings.ReplaceAll(s, ">", "&gt;")
639-
s = strings.ReplaceAll(s, `"`, "&#34;")
640-
s = strings.ReplaceAll(s, "'", "&#39;")
641-
return s
716+
return html.EscapeString(s)
642717
}

0 commit comments

Comments
 (0)