Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 46 additions & 25 deletions module/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"html"
"io"
Expand All @@ -12,6 +13,7 @@ import (
"net/http"
"os"
"regexp"
"sort"
"strconv"
"strings"

Expand All @@ -33,13 +35,18 @@ type OpenAPISwaggerUIConfig struct {

// OpenAPIConfig holds the full configuration for an OpenAPI module.
type OpenAPIConfig struct {
SpecFile string `yaml:"spec_file" json:"spec_file"`
BasePath string `yaml:"base_path" json:"base_path"`
Validation OpenAPIValidationConfig `yaml:"validation" json:"validation"`
SwaggerUI OpenAPISwaggerUIConfig `yaml:"swagger_ui" json:"swagger_ui"`
RouterName string `yaml:"router" json:"router"` // optional: explicit router to attach to
SpecFile string `yaml:"spec_file" json:"spec_file"`
BasePath string `yaml:"base_path" json:"base_path"`
Validation OpenAPIValidationConfig `yaml:"validation" json:"validation"`
SwaggerUI OpenAPISwaggerUIConfig `yaml:"swagger_ui" json:"swagger_ui"`
RouterName string `yaml:"router" json:"router"` // optional: explicit router to attach to
MaxBodyBytes int64 `yaml:"max_body_bytes" json:"max_body_bytes"` // max request body size (bytes); 0 = use default
}

// defaultMaxBodyBytes is the default request body size limit (1 MiB) applied
// when Validation.Request is enabled and MaxBodyBytes is not explicitly set.
const defaultMaxBodyBytes int64 = 1 << 20 // 1 MiB

// ---- Minimal OpenAPI v3 structs (parsed from YAML/JSON) ----

// openAPISpec is a minimal representation of an OpenAPI 3.x specification.
Expand Down Expand Up @@ -221,7 +228,7 @@ func (m *OpenAPIModule) RegisterRoutes(router HTTPRouter) {
}
}

// Serve raw spec at /openapi.json and /openapi.yaml
// Serve raw spec at /openapi.json and (when source is YAML) /openapi.yaml
if len(m.specBytes) > 0 {
// Cache the JSON representation once per registration.
if m.specJSON == nil {
Expand All @@ -233,21 +240,21 @@ func (m *OpenAPIModule) RegisterRoutes(router HTTPRouter) {
}

specPathJSON := basePath + "/openapi.json"
specPathYAML := basePath + "/openapi.yaml"

// JSON endpoint: serve re-serialised spec as JSON.
router.AddRoute(http.MethodGet, specPathJSON, &openAPISpecHandler{specJSON: m.specJSON})

// YAML endpoint: serve the original spec bytes with a content-type that
// matches the source format. JSON source files are served as application/json;
// YAML source files are served as application/yaml.
rawContentType := "application/yaml"
if trimmed := strings.TrimSpace(string(m.specBytes)); len(trimmed) > 0 && (trimmed[0] == '{' || trimmed[0] == '[') {
rawContentType = "application/json"
// YAML endpoint: only register /openapi.yaml when the source spec is YAML.
// When the source is JSON, clients can use /openapi.json instead.
trimmed := strings.TrimSpace(string(m.specBytes))
isJSONSource := len(trimmed) > 0 && (trimmed[0] == '{' || trimmed[0] == '[')
if !isJSONSource {
specPathYAML := basePath + "/openapi.yaml"
router.AddRoute(http.MethodGet, specPathYAML, &openAPIRawSpecHandler{specBytes: m.specBytes, contentType: "application/yaml"})
m.logger.Debug("OpenAPI spec endpoint registered", "module", m.name, "paths", []string{specPathJSON, specPathYAML})
} else {
m.logger.Debug("OpenAPI spec endpoint registered", "module", m.name, "paths", []string{specPathJSON})
}
router.AddRoute(http.MethodGet, specPathYAML, &openAPIRawSpecHandler{specBytes: m.specBytes, contentType: rawContentType})

m.logger.Debug("OpenAPI spec endpoint registered", "module", m.name, "paths", []string{specPathJSON, specPathYAML})
}

// Serve Swagger UI
Expand Down Expand Up @@ -365,16 +372,29 @@ func (h *openAPIRouteHandler) validate(r *http.Request) []string {
ct, supportedContentTypes(h.op.RequestBody.Content)))
}

// Enforce a max body size to prevent DoS via arbitrarily large payloads.
// The limit is configurable via OpenAPIConfig.MaxBodyBytes; default is 1 MiB.
maxBytes := h.module.cfg.MaxBodyBytes
if maxBytes <= 0 {
maxBytes = defaultMaxBodyBytes
}
r.Body = http.MaxBytesReader(nil, r.Body, maxBytes)

// Read the body once so we can both check for presence (when required)
// and validate against the schema. Restore it afterwards for downstream handlers.
bodyBytes, readErr := io.ReadAll(r.Body)
if readErr != nil {
h.module.logger.Error("failed to read request body for validation",
"module", h.module.name,
"path", h.specPath,
"error", readErr,
)
errs = append(errs, "failed to read request body")
var maxBytesErr *http.MaxBytesError
if errors.As(readErr, &maxBytesErr) {
errs = append(errs, fmt.Sprintf("request body exceeds maximum allowed size of %d bytes", maxBytes))
} else {
h.module.logger.Error("failed to read request body for validation",
"module", h.module.name,
"path", h.specPath,
"error", readErr,
)
errs = append(errs, "failed to read request body")
}
} else {
// Always restore body for downstream handlers using the original byte slice
// to avoid a bytes→string→bytes conversion that could corrupt non-UTF-8 payloads.
Expand All @@ -386,7 +406,7 @@ func (h *openAPIRouteHandler) validate(r *http.Request) []string {
var bodyData any
if jsonErr := json.Unmarshal(bodyBytes, &bodyData); jsonErr != nil {
errs = append(errs, fmt.Sprintf("request body contains invalid JSON: %v", jsonErr))
} else if bodyErrs := validateJSONBody(bodyData, mediaType.Schema); len(bodyErrs) > 0 {
} else if bodyErrs := validateJSONValue(bodyData, "body", mediaType.Schema); len(bodyErrs) > 0 {
errs = append(errs, bodyErrs...)
}
}
Expand Down Expand Up @@ -728,12 +748,13 @@ func htmlEscape(s string) string {
return html.EscapeString(s)
}

// supportedContentTypes returns a comma-joined list of content types defined
// in the requestBody.content map, used in validation error messages.
// supportedContentTypes returns a sorted, comma-joined list of content types
// defined in the requestBody.content map, used in validation error messages.
func supportedContentTypes(content map[string]openAPIMediaType) string {
types := make([]string, 0, len(content))
for ct := range content {
types = append(types, ct)
}
sort.Strings(types)
return strings.Join(types, ", ")
}
64 changes: 64 additions & 0 deletions module/openapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,35 @@ func TestOpenAPIModule_ParseJSON(t *testing.T) {
}
}

func TestOpenAPIModule_JSONSourceNoYAMLEndpoint(t *testing.T) {
specPath := writeTempSpec(t, ".json", petstoreJSON)

mod := NewOpenAPIModule("json-api", OpenAPIConfig{
SpecFile: specPath,
BasePath: "/api",
})
if err := mod.Init(nil); err != nil {
t.Fatalf("Init: %v", err)
}

router := &testRouter{}
mod.RegisterRoutes(router)

paths := make(map[string]bool)
for _, rt := range router.routes {
paths[rt.method+":"+rt.path] = true
}

// JSON source spec: /openapi.json should be registered
if !paths["GET:/api/openapi.json"] {
t.Error("expected GET:/api/openapi.json to be registered for JSON source")
}
// /openapi.yaml should NOT be registered for a JSON source spec
if paths["GET:/api/openapi.yaml"] {
t.Error("expected GET:/api/openapi.yaml NOT to be registered for JSON source")
}
}

func TestOpenAPIModule_MissingSpecFile(t *testing.T) {
mod := NewOpenAPIModule("bad", OpenAPIConfig{})
if err := mod.Init(nil); err == nil {
Expand Down Expand Up @@ -429,6 +458,41 @@ func TestOpenAPIModule_RequestValidation_Body(t *testing.T) {
})
}

func TestOpenAPIModule_MaxBodySize(t *testing.T) {
specPath := writeTempSpec(t, ".yaml", petstoreYAML)

mod := NewOpenAPIModule("petstore", OpenAPIConfig{
SpecFile: specPath,
BasePath: "/api/v1",
Validation: OpenAPIValidationConfig{Request: true},
MaxBodyBytes: 10, // very small limit to trigger the check
})
if err := mod.Init(nil); err != nil {
t.Fatalf("Init: %v", err)
}

router := &testRouter{}
mod.RegisterRoutes(router)

h := router.findHandler("POST", "/api/v1/pets")
if h == nil {
t.Fatal("POST /api/v1/pets handler not found")
}

body := `{"name": "Fluffy", "tag": "cat"}` // 33 bytes, exceeds limit of 10
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/v1/pets", bytes.NewBufferString(body))
r.Header.Set("Content-Type", "application/json")
h.Handle(w, r)

if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 for oversized body, got %d: %s", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), "exceeds maximum") {
t.Errorf("expected error message about size limit, got: %s", w.Body.String())
}
}

func TestOpenAPIModule_NoValidation(t *testing.T) {
specPath := writeTempSpec(t, ".yaml", petstoreYAML)

Expand Down
44 changes: 37 additions & 7 deletions plugins/openapi/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ func (p *Plugin) ModuleFactories() map[string]plugin.ModuleFactory {
if v, ok := cfg["router"].(string); ok {
oacfg.RouterName = v
}
if v, ok := cfg["max_body_bytes"].(int); ok && v > 0 {
oacfg.MaxBodyBytes = int64(v)
} else if v, ok := cfg["max_body_bytes"].(int64); ok && v > 0 {
oacfg.MaxBodyBytes = v
} else if v, ok := cfg["max_body_bytes"].(float64); ok && v > 0 {
oacfg.MaxBodyBytes = int64(v)
}

if valCfg, ok := cfg["validation"].(map[string]any); ok {
if v, ok2 := valCfg["request"].(bool); ok2 {
Expand Down Expand Up @@ -151,9 +158,16 @@ func (p *Plugin) ModuleSchemas() []*schema.ModuleSchema {
DefaultValue: map[string]any{"enabled": false, "path": "/docs"},
Group: "swagger_ui",
},
{
Key: "max_body_bytes",
Label: "Max Body Size",
Type: schema.FieldTypeNumber,
Description: "Maximum allowed request body size in bytes when validation is enabled (default: 1048576 = 1 MiB)",
Placeholder: "1048576",
Group: "validation",
},
},
DefaultConfig: map[string]any{
"base_path": "/api/v1",
"validation": map[string]any{"request": true, "response": false},
"swagger_ui": map[string]any{"enabled": false, "path": "/docs"},
},
Expand All @@ -178,12 +192,16 @@ func (p *Plugin) WiringHooks() []plugin.WiringHook {
func wireOpenAPIRoutes(app modular.Application, cfg *config.WorkflowConfig) error {
// Build name→router lookup from config dependsOn
routerNames := make(map[string]bool)
openAPIDeps := make(map[string][]string) // openapi module name → dependsOn
serverToRouter := make(map[string]string) // http.server name → router name
openAPIDeps := make(map[string][]string) // openapi module name → dependsOn
for _, modCfg := range cfg.Modules {
if modCfg.Type == "http.router" {
switch modCfg.Type {
case "http.router":
routerNames[modCfg.Name] = true
}
if modCfg.Type == "openapi" {
for _, dep := range modCfg.DependsOn {
serverToRouter[dep] = modCfg.Name
}
case "openapi":
openAPIDeps[modCfg.Name] = modCfg.DependsOn
}
}
Expand Down Expand Up @@ -214,7 +232,7 @@ func wireOpenAPIRoutes(app modular.Application, cfg *config.WorkflowConfig) erro
targetRouter = routers[rName]
}

// 2) dependsOn router reference
// 2) dependsOn: direct router reference
if targetRouter == nil {
for _, dep := range openAPIDeps[oaMod.Name()] {
if routerNames[dep] {
Expand All @@ -226,7 +244,19 @@ func wireOpenAPIRoutes(app modular.Application, cfg *config.WorkflowConfig) erro
}
}

// 3) Fall back to first available router
// 3) dependsOn: server reference → follow server→router mapping
if targetRouter == nil {
for _, dep := range openAPIDeps[oaMod.Name()] {
if rName, ok := serverToRouter[dep]; ok {
if router, found := routers[rName]; found {
targetRouter = router
break
}
}
}
}

// 4) Fall back to first available router
if targetRouter == nil {
targetRouter = firstRouter
}
Expand Down