@@ -4,8 +4,10 @@ import (
44 "context"
55 "encoding/json"
66 "fmt"
7+ "html"
78 "io"
89 "log/slog"
10+ "math"
911 "net/http"
1012 "os"
1113 "regexp"
@@ -109,7 +111,8 @@ type OpenAPIModule struct {
109111 name string
110112 cfg OpenAPIConfig
111113 spec * openAPISpec
112- specBytes []byte // raw spec bytes for serving
114+ specBytes []byte // raw spec bytes for serving (original file content)
115+ specJSON []byte // cached JSON-serialised spec for /openapi.json endpoint
113116 routerName string
114117 logger * slog.Logger
115118}
@@ -219,11 +222,26 @@ func (m *OpenAPIModule) RegisterRoutes(router HTTPRouter) {
219222
220223 // Serve raw spec at /openapi.json and /openapi.yaml
221224 if len (m .specBytes ) > 0 {
225+ // Cache the JSON representation once per registration.
226+ if m .specJSON == nil {
227+ specJSON , err := json .Marshal (m .spec )
228+ if err != nil {
229+ specJSON = m .specBytes // fallback to raw bytes
230+ }
231+ m .specJSON = specJSON
232+ }
233+
222234 specPathJSON := basePath + "/openapi.json"
223235 specPathYAML := basePath + "/openapi.yaml"
224- specHandler := m .buildSpecHandler ()
225- router .AddRoute (http .MethodGet , specPathJSON , specHandler )
226- router .AddRoute (http .MethodGet , specPathYAML , specHandler )
236+
237+ // JSON endpoint: serve re-serialised spec as JSON.
238+ router .AddRoute (http .MethodGet , specPathJSON , & openAPISpecHandler {specJSON : m .specJSON })
239+
240+ // YAML endpoint: serve the original spec bytes with a YAML content-type.
241+ // This preserves the source format; if the original file was YAML it is
242+ // served as YAML, and if it was JSON it is served as-is (JSON is valid YAML).
243+ router .AddRoute (http .MethodGet , specPathYAML , & openAPIRawSpecHandler {specBytes : m .specBytes , contentType : "application/yaml" })
244+
227245 m .logger .Debug ("OpenAPI spec endpoint registered" , "module" , m .name , "paths" , []string {specPathJSON , specPathYAML })
228246 }
229247
@@ -257,17 +275,6 @@ func (m *OpenAPIModule) buildRouteHandler(specPath, method string, op *openAPIOp
257275 }
258276}
259277
260- // buildSpecHandler serves the raw spec bytes as JSON (re-serialised from the
261- // parsed spec) so consumers always get valid JSON regardless of whether the
262- // original file was YAML.
263- func (m * OpenAPIModule ) buildSpecHandler () HTTPHandler {
264- specJSON , err := json .Marshal (m .spec )
265- if err != nil {
266- specJSON = m .specBytes // fallback to raw bytes
267- }
268- return & openAPISpecHandler {specJSON : specJSON }
269- }
270-
271278// buildSwaggerUIHandler returns an inline Swagger UI page that loads the spec
272279// from specURL. This avoids any asset bundling — a CDN-hosted swagger-ui is used.
273280func (m * OpenAPIModule ) buildSwaggerUIHandler (specURL string ) HTTPHandler {
@@ -340,23 +347,36 @@ func (h *openAPIRouteHandler) validate(r *http.Request) []string {
340347 if mt , ok := h .op .RequestBody .Content [ct ]; ok {
341348 mediaType = & mt
342349 } else if mt , ok := h .op .RequestBody .Content ["application/json" ]; ok && ct == "" {
343- // Default to application/json when no Content-Type is sent
350+ // NOTE: Intentionally treat a missing Content-Type as application/json for request body
351+ // validation. Per HTTP semantics, an absent Content-Type would normally imply
352+ // application/octet-stream, but this engine is primarily used for JSON APIs and this
353+ // default simplifies client usage.
344354 mediaType = & mt
345355 }
346356
347- if h .op .RequestBody .Required && r .ContentLength == 0 && r .Body == http .NoBody {
348- errs = append (errs , "request body is required but missing" )
349- } else if mediaType != nil && mediaType .Schema != nil {
350- bodyBytes , err := io .ReadAll (r .Body )
351- if err == nil && len (bodyBytes ) > 0 {
357+ // Read the body once so we can both check for presence (when required)
358+ // and validate against the schema. Restore it afterwards for downstream handlers.
359+ bodyBytes , readErr := io .ReadAll (r .Body )
360+ if readErr != nil {
361+ h .module .logger .Error ("failed to read request body for validation" ,
362+ "module" , h .module .name ,
363+ "path" , h .specPath ,
364+ "error" , readErr ,
365+ )
366+ errs = append (errs , "failed to read request body" )
367+ } else {
368+ // Always restore body for downstream handlers.
369+ r .Body = io .NopCloser (strings .NewReader (string (bodyBytes )))
370+
371+ if h .op .RequestBody .Required && len (bodyBytes ) == 0 {
372+ errs = append (errs , "request body is required but missing" )
373+ } else if mediaType != nil && mediaType .Schema != nil && len (bodyBytes ) > 0 {
352374 var bodyData any
353375 if jsonErr := json .Unmarshal (bodyBytes , & bodyData ); jsonErr == nil {
354376 if bodyErrs := validateJSONBody (bodyData , mediaType .Schema ); len (bodyErrs ) > 0 {
355377 errs = append (errs , bodyErrs ... )
356378 }
357379 }
358- // Restore body for downstream handlers
359- r .Body = io .NopCloser (strings .NewReader (string (bodyBytes )))
360380 }
361381 }
362382 }
@@ -376,6 +396,21 @@ func (h *openAPISpecHandler) Handle(w http.ResponseWriter, _ *http.Request) {
376396 _ , _ = w .Write (h .specJSON ) //nolint:gosec // G705: spec JSON is loaded from a trusted config file, not user input
377397}
378398
399+ // ---- openAPIRawSpecHandler ----
400+
401+ // openAPIRawSpecHandler serves the raw spec bytes with the given content-type.
402+ // Used for the /openapi.yaml endpoint to preserve the original source format.
403+ type openAPIRawSpecHandler struct {
404+ specBytes []byte
405+ contentType string
406+ }
407+
408+ func (h * openAPIRawSpecHandler ) Handle (w http.ResponseWriter , _ * http.Request ) {
409+ w .Header ().Set ("Content-Type" , h .contentType )
410+ w .WriteHeader (http .StatusOK )
411+ _ , _ = w .Write (h .specBytes ) //nolint:gosec // G705: spec bytes are loaded from a trusted config file, not user input
412+ }
413+
379414// ---- openAPISwaggerUIHandler ----
380415
381416type openAPISwaggerUIHandler struct {
@@ -392,15 +427,18 @@ func (h *openAPISwaggerUIHandler) Handle(w http.ResponseWriter, _ *http.Request)
392427
393428// parseOpenAPISpec parses a YAML or JSON byte slice into an openAPISpec.
394429func parseOpenAPISpec (data []byte ) (* openAPISpec , error ) {
430+ if len (data ) == 0 {
431+ return nil , fmt .Errorf ("openapi spec data is empty" )
432+ }
395433 var spec openAPISpec
396434 // Try YAML first (which also handles JSON since JSON is valid YAML)
397435 if err := yaml .Unmarshal (data , & spec ); err != nil {
398436 return nil , fmt .Errorf ("yaml parse: %w" , err )
399437 }
400438 if spec .OpenAPI == "" {
401- // May be JSON that yaml couldn't decode properly ; try JSON directly
439+ // YAML parse succeeded but produced an empty OpenAPI field ; try JSON directly.
402440 if err := json .Unmarshal (data , & spec ); err != nil {
403- return nil , fmt .Errorf ("neither yaml nor json parse succeeded : %w" , err )
441+ return nil , fmt .Errorf ("yaml parse produced empty OpenAPI field; json parse also failed : %w" , err )
404442 }
405443 }
406444 return & spec , nil
@@ -445,6 +483,28 @@ func extractParam(r *http.Request, p openAPIParameter) string {
445483 return ""
446484}
447485
486+ // validateStringConstraints validates the string constraints (minLength, maxLength,
487+ // pattern) for a string value. The kind parameter ("parameter" or "field") is used
488+ // in error messages.
489+ func validateStringConstraints (s , name , kind string , schema * openAPISchema ) []string {
490+ var errs []string
491+ if schema .MinLength != nil && len (s ) < * schema .MinLength {
492+ errs = append (errs , fmt .Sprintf ("%s %q must have minLength %d" , kind , name , * schema .MinLength ))
493+ }
494+ if schema .MaxLength != nil && len (s ) > * schema .MaxLength {
495+ errs = append (errs , fmt .Sprintf ("%s %q must have maxLength %d" , kind , name , * schema .MaxLength ))
496+ }
497+ if schema .Pattern != "" {
498+ re , err := regexp .Compile (schema .Pattern )
499+ if err != nil {
500+ errs = append (errs , fmt .Sprintf ("%s %q has an invalid pattern %q: %v" , kind , name , schema .Pattern , err ))
501+ } else if ! re .MatchString (s ) {
502+ errs = append (errs , fmt .Sprintf ("%s %q does not match pattern %q" , kind , name , schema .Pattern ))
503+ }
504+ }
505+ return errs
506+ }
507+
448508// validateScalarValue validates a string value against a schema (type/format/enum checks).
449509func validateScalarValue (val , name string , schema * openAPISchema ) []string {
450510 var errs []string
@@ -478,22 +538,16 @@ func validateScalarValue(val, name string, schema *openAPISchema) []string {
478538 errs = append (errs , fmt .Sprintf ("parameter %q must be 'true' or 'false', got %q" , name , val ))
479539 }
480540 case "string" :
481- if schema .MinLength != nil && len (val ) < * schema .MinLength {
482- errs = append (errs , fmt .Sprintf ("parameter %q must have minLength %d" , name , * schema .MinLength ))
483- }
484- if schema .MaxLength != nil && len (val ) > * schema .MaxLength {
485- errs = append (errs , fmt .Sprintf ("parameter %q must have maxLength %d" , name , * schema .MaxLength ))
486- }
487- if schema .Pattern != "" {
488- if ok , _ := regexp .MatchString (schema .Pattern , val ); ! ok {
489- errs = append (errs , fmt .Sprintf ("parameter %q does not match pattern %q" , name , schema .Pattern ))
490- }
491- }
541+ errs = append (errs , validateStringConstraints (val , name , "parameter" , schema )... )
492542 }
493- // Enum validation
543+ // Enum validation: query/path parameters are always strings, so compare the
544+ // string form of each enum value against the string parameter value.
494545 if len (schema .Enum ) > 0 {
495546 found := false
496547 for _ , e := range schema .Enum {
548+ if e == nil {
549+ continue
550+ }
497551 if fmt .Sprintf ("%v" , e ) == val {
498552 found = true
499553 break
@@ -542,23 +596,16 @@ func validateJSONValue(val any, name string, schema *openAPISchema) []string {
542596 case "string" :
543597 s , ok := val .(string )
544598 if ! ok {
545- return []string {fmt .Sprintf ("field %q must be a string" , name )}
546- }
547- if schema .MinLength != nil && len (s ) < * schema .MinLength {
548- errs = append (errs , fmt .Sprintf ("field %q must have minLength %d" , name , * schema .MinLength ))
549- }
550- if schema .MaxLength != nil && len (s ) > * schema .MaxLength {
551- errs = append (errs , fmt .Sprintf ("field %q must have maxLength %d" , name , * schema .MaxLength ))
552- }
553- if schema .Pattern != "" {
554- if ok2 , _ := regexp .MatchString (schema .Pattern , s ); ! ok2 {
555- errs = append (errs , fmt .Sprintf ("field %q does not match pattern %q" , name , schema .Pattern ))
556- }
599+ return []string {fmt .Sprintf ("field %q must be a string, got %T" , name , val )}
557600 }
601+ errs = append (errs , validateStringConstraints (s , name , "field" , schema )... )
558602 case "integer" :
559603 f , ok := val .(float64 )
560604 if ! ok {
561- return []string {fmt .Sprintf ("field %q must be an integer" , name )}
605+ return []string {fmt .Sprintf ("field %q must be an integer, got %T" , name , val )}
606+ }
607+ if f != math .Trunc (f ) {
608+ return []string {fmt .Sprintf ("field %q must be an integer, got %v" , name , f )}
562609 }
563610 if schema .Minimum != nil && f < * schema .Minimum {
564611 errs = append (errs , fmt .Sprintf ("field %q must be >= %v" , name , * schema .Minimum ))
@@ -569,7 +616,7 @@ func validateJSONValue(val any, name string, schema *openAPISchema) []string {
569616 case "number" :
570617 f , ok := val .(float64 )
571618 if ! ok {
572- return []string {fmt .Sprintf ("field %q must be a number" , name )}
619+ return []string {fmt .Sprintf ("field %q must be a number, got %T " , name , val )}
573620 }
574621 if schema .Minimum != nil && f < * schema .Minimum {
575622 errs = append (errs , fmt .Sprintf ("field %q must be >= %v" , name , * schema .Minimum ))
@@ -579,21 +626,41 @@ func validateJSONValue(val any, name string, schema *openAPISchema) []string {
579626 }
580627 case "boolean" :
581628 if _ , ok := val .(bool ); ! ok {
582- errs = append (errs , fmt .Sprintf ("field %q must be a boolean" , name ))
629+ errs = append (errs , fmt .Sprintf ("field %q must be a boolean, got %T " , name , val ))
583630 }
584631 case "object" :
585632 if subErrs := validateJSONBody (val , schema ); len (subErrs ) > 0 {
586633 errs = append (errs , subErrs ... )
587634 }
588635 }
589- // Enum validation
636+ // Enum validation: use type-aware comparison to prevent e.g. int 1 matching string "1".
590637 if len (schema .Enum ) > 0 {
591638 found := false
592639 for _ , e := range schema .Enum {
593- if fmt .Sprintf ("%v" , e ) == fmt .Sprintf ("%v" , val ) {
640+ if e == nil {
641+ continue
642+ }
643+ // Direct equality covers string==string, bool==bool, float64==float64.
644+ if e == val {
594645 found = true
595646 break
596647 }
648+ // Handle numeric type mismatch: YAML decodes integers as int, but JSON
649+ // decodes all numbers as float64, so int(1) != float64(1) even though
650+ // they represent the same value.
651+ switch ev := e .(type ) {
652+ case int :
653+ if fv , ok := val .(float64 ); ok && float64 (ev ) == fv {
654+ found = true
655+ }
656+ case int64 :
657+ if fv , ok := val .(float64 ); ok && float64 (ev ) == fv {
658+ found = true
659+ }
660+ }
661+ if found {
662+ break
663+ }
597664 }
598665 if ! found {
599666 errs = append (errs , fmt .Sprintf ("field %q must be one of %v" , name , schema .Enum ))
@@ -604,21 +671,33 @@ func validateJSONValue(val any, name string, schema *openAPISchema) []string {
604671
605672// swaggerUIHTML returns a minimal, self-contained Swagger UI HTML page that
606673// loads the spec from specURL using the official Swagger UI CDN bundle.
674+ //
675+ // The base URL for the Swagger UI assets can be overridden via the
676+ // SWAGGER_UI_ASSETS_BASE_URL environment variable. If unset, it defaults to
677+ // "https://unpkg.com/swagger-ui-dist@5". This is useful for air-gapped
678+ // environments or when a local mirror is preferred.
607679func swaggerUIHTML (title , specURL string ) string {
608680 if title == "" {
609681 title = "API Documentation"
610682 }
683+ baseURL := os .Getenv ("SWAGGER_UI_ASSETS_BASE_URL" )
684+ if baseURL == "" {
685+ baseURL = "https://unpkg.com/swagger-ui-dist@5"
686+ }
687+ baseURL = strings .TrimRight (baseURL , "/" )
688+ cssURL := baseURL + "/swagger-ui.css"
689+ jsURL := baseURL + "/swagger-ui-bundle.js"
611690 return `<!DOCTYPE html>
612691<html lang="en">
613692<head>
614693 <meta charset="UTF-8">
615694 <meta name="viewport" content="width=device-width, initial-scale=1.0">
616695 <title>` + htmlEscape (title ) + `</title>
617- <link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css ">
696+ <link rel="stylesheet" href="` + htmlEscape ( cssURL ) + ` ">
618697</head>
619698<body>
620699 <div id="swagger-ui"></div>
621- <script src="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js "></script>
700+ <script src="` + htmlEscape ( jsURL ) + ` "></script>
622701 <script>
623702 SwaggerUIBundle({
624703 url: "` + htmlEscape (specURL ) + `",
@@ -632,11 +711,7 @@ func swaggerUIHTML(title, specURL string) string {
632711}
633712
634713// htmlEscape escapes a string for safe embedding in HTML attributes/text.
714+ // It delegates to the standard library html.EscapeString for robust escaping.
635715func htmlEscape (s string ) string {
636- s = strings .ReplaceAll (s , "&" , "&" )
637- s = strings .ReplaceAll (s , "<" , "<" )
638- s = strings .ReplaceAll (s , ">" , ">" )
639- s = strings .ReplaceAll (s , `"` , """ )
640- s = strings .ReplaceAll (s , "'" , "'" )
641- return s
716+ return html .EscapeString (s )
642717}
0 commit comments