Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions example/openapi-petstore.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ modules:
dependsOn:
- petstore-router
config:
# NOTE: spec_file is resolved relative to this config file's directory
# via the _config_dir mechanism in config.ResolvePathInConfig.
spec_file: specs/petstore.yaml
base_path: /api/v1
router: petstore-router
Expand Down
199 changes: 137 additions & 62 deletions module/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"context"
"encoding/json"
"fmt"
"html"
"io"
"log/slog"
"math"
"net/http"
"os"
"regexp"
Expand Down Expand Up @@ -109,7 +111,8 @@ type OpenAPIModule struct {
name string
cfg OpenAPIConfig
spec *openAPISpec
specBytes []byte // raw spec bytes for serving
specBytes []byte // raw spec bytes for serving (original file content)
specJSON []byte // cached JSON-serialised spec for /openapi.json endpoint
routerName string
logger *slog.Logger
}
Expand Down Expand Up @@ -219,11 +222,26 @@ func (m *OpenAPIModule) RegisterRoutes(router HTTPRouter) {

// Serve raw spec at /openapi.json and /openapi.yaml
if len(m.specBytes) > 0 {
// Cache the JSON representation once per registration.
if m.specJSON == nil {
specJSON, err := json.Marshal(m.spec)
if err != nil {
specJSON = m.specBytes // fallback to raw bytes
}
m.specJSON = specJSON
}

specPathJSON := basePath + "/openapi.json"
specPathYAML := basePath + "/openapi.yaml"
specHandler := m.buildSpecHandler()
router.AddRoute(http.MethodGet, specPathJSON, specHandler)
router.AddRoute(http.MethodGet, specPathYAML, specHandler)

// JSON endpoint: serve re-serialised spec as JSON.
router.AddRoute(http.MethodGet, specPathJSON, &openAPISpecHandler{specJSON: m.specJSON})

// YAML endpoint: serve the original spec bytes with a YAML content-type.
// This preserves the source format; if the original file was YAML it is
// served as YAML, and if it was JSON it is served as-is (JSON is valid YAML).
router.AddRoute(http.MethodGet, specPathYAML, &openAPIRawSpecHandler{specBytes: m.specBytes, contentType: "application/yaml"})

m.logger.Debug("OpenAPI spec endpoint registered", "module", m.name, "paths", []string{specPathJSON, specPathYAML})
}

Expand Down Expand Up @@ -257,17 +275,6 @@ func (m *OpenAPIModule) buildRouteHandler(specPath, method string, op *openAPIOp
}
}

// buildSpecHandler serves the raw spec bytes as JSON (re-serialised from the
// parsed spec) so consumers always get valid JSON regardless of whether the
// original file was YAML.
func (m *OpenAPIModule) buildSpecHandler() HTTPHandler {
specJSON, err := json.Marshal(m.spec)
if err != nil {
specJSON = m.specBytes // fallback to raw bytes
}
return &openAPISpecHandler{specJSON: specJSON}
}

// buildSwaggerUIHandler returns an inline Swagger UI page that loads the spec
// from specURL. This avoids any asset bundling — a CDN-hosted swagger-ui is used.
func (m *OpenAPIModule) buildSwaggerUIHandler(specURL string) HTTPHandler {
Expand Down Expand Up @@ -340,23 +347,36 @@ func (h *openAPIRouteHandler) validate(r *http.Request) []string {
if mt, ok := h.op.RequestBody.Content[ct]; ok {
mediaType = &mt
} else if mt, ok := h.op.RequestBody.Content["application/json"]; ok && ct == "" {
// Default to application/json when no Content-Type is sent
// NOTE: Intentionally treat a missing Content-Type as application/json for request body
// validation. Per HTTP semantics, an absent Content-Type would normally imply
// application/octet-stream, but this engine is primarily used for JSON APIs and this
// default simplifies client usage.
mediaType = &mt
}

if h.op.RequestBody.Required && r.ContentLength == 0 && r.Body == http.NoBody {
errs = append(errs, "request body is required but missing")
} else if mediaType != nil && mediaType.Schema != nil {
bodyBytes, err := io.ReadAll(r.Body)
if err == nil && len(bodyBytes) > 0 {
// Read the body once so we can both check for presence (when required)
// and validate against the schema. Restore it afterwards for downstream handlers.
bodyBytes, readErr := io.ReadAll(r.Body)
if readErr != nil {
h.module.logger.Error("failed to read request body for validation",
"module", h.module.name,
"path", h.specPath,
"error", readErr,
)
errs = append(errs, "failed to read request body")
} else {
// Always restore body for downstream handlers.
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))

if h.op.RequestBody.Required && len(bodyBytes) == 0 {
errs = append(errs, "request body is required but missing")
} else if mediaType != nil && mediaType.Schema != nil && len(bodyBytes) > 0 {
var bodyData any
if jsonErr := json.Unmarshal(bodyBytes, &bodyData); jsonErr == nil {
if bodyErrs := validateJSONBody(bodyData, mediaType.Schema); len(bodyErrs) > 0 {
errs = append(errs, bodyErrs...)
}
}
// Restore body for downstream handlers
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
}
}
}
Expand All @@ -376,6 +396,21 @@ func (h *openAPISpecHandler) Handle(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write(h.specJSON) //nolint:gosec // G705: spec JSON is loaded from a trusted config file, not user input
}

// ---- openAPIRawSpecHandler ----

// openAPIRawSpecHandler serves the raw spec bytes with the given content-type.
// Used for the /openapi.yaml endpoint to preserve the original source format.
type openAPIRawSpecHandler struct {
specBytes []byte
contentType string
}

func (h *openAPIRawSpecHandler) Handle(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", h.contentType)
w.WriteHeader(http.StatusOK)
_, _ = w.Write(h.specBytes) //nolint:gosec // G705: spec bytes are loaded from a trusted config file, not user input
}

// ---- openAPISwaggerUIHandler ----

type openAPISwaggerUIHandler struct {
Expand All @@ -392,15 +427,18 @@ func (h *openAPISwaggerUIHandler) Handle(w http.ResponseWriter, _ *http.Request)

// parseOpenAPISpec parses a YAML or JSON byte slice into an openAPISpec.
func parseOpenAPISpec(data []byte) (*openAPISpec, error) {
if len(data) == 0 {
return nil, fmt.Errorf("openapi spec data is empty")
}
var spec openAPISpec
// Try YAML first (which also handles JSON since JSON is valid YAML)
if err := yaml.Unmarshal(data, &spec); err != nil {
return nil, fmt.Errorf("yaml parse: %w", err)
}
if spec.OpenAPI == "" {
// May be JSON that yaml couldn't decode properly; try JSON directly
// YAML parse succeeded but produced an empty OpenAPI field; try JSON directly.
if err := json.Unmarshal(data, &spec); err != nil {
return nil, fmt.Errorf("neither yaml nor json parse succeeded: %w", err)
return nil, fmt.Errorf("yaml parse produced empty OpenAPI field; json parse also failed: %w", err)
}
}
return &spec, nil
Expand Down Expand Up @@ -445,6 +483,28 @@ func extractParam(r *http.Request, p openAPIParameter) string {
return ""
}

// validateStringConstraints validates the string constraints (minLength, maxLength,
// pattern) for a string value. The kind parameter ("parameter" or "field") is used
// in error messages.
func validateStringConstraints(s, name, kind string, schema *openAPISchema) []string {
var errs []string
if schema.MinLength != nil && len(s) < *schema.MinLength {
errs = append(errs, fmt.Sprintf("%s %q must have minLength %d", kind, name, *schema.MinLength))
}
if schema.MaxLength != nil && len(s) > *schema.MaxLength {
errs = append(errs, fmt.Sprintf("%s %q must have maxLength %d", kind, name, *schema.MaxLength))
}
if schema.Pattern != "" {
re, err := regexp.Compile(schema.Pattern)
if err != nil {
errs = append(errs, fmt.Sprintf("%s %q has an invalid pattern %q: %v", kind, name, schema.Pattern, err))
} else if !re.MatchString(s) {
errs = append(errs, fmt.Sprintf("%s %q does not match pattern %q", kind, name, schema.Pattern))
}
}
return errs
}

// validateScalarValue validates a string value against a schema (type/format/enum checks).
func validateScalarValue(val, name string, schema *openAPISchema) []string {
var errs []string
Expand Down Expand Up @@ -478,22 +538,16 @@ func validateScalarValue(val, name string, schema *openAPISchema) []string {
errs = append(errs, fmt.Sprintf("parameter %q must be 'true' or 'false', got %q", name, val))
}
case "string":
if schema.MinLength != nil && len(val) < *schema.MinLength {
errs = append(errs, fmt.Sprintf("parameter %q must have minLength %d", name, *schema.MinLength))
}
if schema.MaxLength != nil && len(val) > *schema.MaxLength {
errs = append(errs, fmt.Sprintf("parameter %q must have maxLength %d", name, *schema.MaxLength))
}
if schema.Pattern != "" {
if ok, _ := regexp.MatchString(schema.Pattern, val); !ok {
errs = append(errs, fmt.Sprintf("parameter %q does not match pattern %q", name, schema.Pattern))
}
}
errs = append(errs, validateStringConstraints(val, name, "parameter", schema)...)
}
// Enum validation
// Enum validation: query/path parameters are always strings, so compare the
// string form of each enum value against the string parameter value.
if len(schema.Enum) > 0 {
found := false
for _, e := range schema.Enum {
if e == nil {
continue
}
if fmt.Sprintf("%v", e) == val {
found = true
break
Expand Down Expand Up @@ -542,23 +596,16 @@ func validateJSONValue(val any, name string, schema *openAPISchema) []string {
case "string":
s, ok := val.(string)
if !ok {
return []string{fmt.Sprintf("field %q must be a string", name)}
}
if schema.MinLength != nil && len(s) < *schema.MinLength {
errs = append(errs, fmt.Sprintf("field %q must have minLength %d", name, *schema.MinLength))
}
if schema.MaxLength != nil && len(s) > *schema.MaxLength {
errs = append(errs, fmt.Sprintf("field %q must have maxLength %d", name, *schema.MaxLength))
}
if schema.Pattern != "" {
if ok2, _ := regexp.MatchString(schema.Pattern, s); !ok2 {
errs = append(errs, fmt.Sprintf("field %q does not match pattern %q", name, schema.Pattern))
}
return []string{fmt.Sprintf("field %q must be a string, got %T", name, val)}
}
errs = append(errs, validateStringConstraints(s, name, "field", schema)...)
case "integer":
f, ok := val.(float64)
if !ok {
return []string{fmt.Sprintf("field %q must be an integer", name)}
return []string{fmt.Sprintf("field %q must be an integer, got %T", name, val)}
}
if f != math.Trunc(f) {
return []string{fmt.Sprintf("field %q must be an integer, got %v", name, f)}
}
if schema.Minimum != nil && f < *schema.Minimum {
errs = append(errs, fmt.Sprintf("field %q must be >= %v", name, *schema.Minimum))
Expand All @@ -569,7 +616,7 @@ func validateJSONValue(val any, name string, schema *openAPISchema) []string {
case "number":
f, ok := val.(float64)
if !ok {
return []string{fmt.Sprintf("field %q must be a number", name)}
return []string{fmt.Sprintf("field %q must be a number, got %T", name, val)}
}
if schema.Minimum != nil && f < *schema.Minimum {
errs = append(errs, fmt.Sprintf("field %q must be >= %v", name, *schema.Minimum))
Expand All @@ -579,21 +626,41 @@ func validateJSONValue(val any, name string, schema *openAPISchema) []string {
}
case "boolean":
if _, ok := val.(bool); !ok {
errs = append(errs, fmt.Sprintf("field %q must be a boolean", name))
errs = append(errs, fmt.Sprintf("field %q must be a boolean, got %T", name, val))
}
case "object":
if subErrs := validateJSONBody(val, schema); len(subErrs) > 0 {
errs = append(errs, subErrs...)
}
}
// Enum validation
// Enum validation: use type-aware comparison to prevent e.g. int 1 matching string "1".
if len(schema.Enum) > 0 {
found := false
for _, e := range schema.Enum {
if fmt.Sprintf("%v", e) == fmt.Sprintf("%v", val) {
if e == nil {
continue
}
// Direct equality covers string==string, bool==bool, float64==float64.
if e == val {
found = true
break
}
// Handle numeric type mismatch: YAML decodes integers as int, but JSON
// decodes all numbers as float64, so int(1) != float64(1) even though
// they represent the same value.
switch ev := e.(type) {
case int:
if fv, ok := val.(float64); ok && float64(ev) == fv {
found = true
}
case int64:
if fv, ok := val.(float64); ok && float64(ev) == fv {
found = true
}
}
if found {
break
}
}
if !found {
errs = append(errs, fmt.Sprintf("field %q must be one of %v", name, schema.Enum))
Expand All @@ -604,21 +671,33 @@ func validateJSONValue(val any, name string, schema *openAPISchema) []string {

// swaggerUIHTML returns a minimal, self-contained Swagger UI HTML page that
// loads the spec from specURL using the official Swagger UI CDN bundle.
//
// The base URL for the Swagger UI assets can be overridden via the
// SWAGGER_UI_ASSETS_BASE_URL environment variable. If unset, it defaults to
// "https://unpkg.com/swagger-ui-dist@5". This is useful for air-gapped
// environments or when a local mirror is preferred.
func swaggerUIHTML(title, specURL string) string {
if title == "" {
title = "API Documentation"
}
baseURL := os.Getenv("SWAGGER_UI_ASSETS_BASE_URL")
if baseURL == "" {
baseURL = "https://unpkg.com/swagger-ui-dist@5"
}
baseURL = strings.TrimRight(baseURL, "/")
cssURL := baseURL + "/swagger-ui.css"
jsURL := baseURL + "/swagger-ui-bundle.js"
return `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>` + htmlEscape(title) + `</title>
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css">
<link rel="stylesheet" href="` + htmlEscape(cssURL) + `">
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
<script src="` + htmlEscape(jsURL) + `"></script>
<script>
SwaggerUIBundle({
url: "` + htmlEscape(specURL) + `",
Expand All @@ -632,11 +711,7 @@ func swaggerUIHTML(title, specURL string) string {
}

// htmlEscape escapes a string for safe embedding in HTML attributes/text.
// It delegates to the standard library html.EscapeString for robust escaping.
func htmlEscape(s string) string {
s = strings.ReplaceAll(s, "&", "&amp;")
s = strings.ReplaceAll(s, "<", "&lt;")
s = strings.ReplaceAll(s, ">", "&gt;")
s = strings.ReplaceAll(s, `"`, "&#34;")
s = strings.ReplaceAll(s, "'", "&#39;")
return s
return html.EscapeString(s)
}
Loading