Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions e2e_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,144 @@ func TestE2E_Middleware_CORS(t *testing.T) {
t.Log("E2E Middleware CORS: Allowed, disallowed, headers, and preflight scenarios verified")
}

// TestE2E_Middleware_CORS_FullConfig verifies that the CORS middleware factory correctly
// applies allowedHeaders, allowCredentials, maxAge, and wildcard subdomain origin matching.
func TestE2E_Middleware_CORS_FullConfig(t *testing.T) {
port := getFreePort(t)
addr := fmt.Sprintf(":%d", port)
baseURL := fmt.Sprintf("http://127.0.0.1:%d", port)

cfg := &config.WorkflowConfig{
Modules: []config.ModuleConfig{
{Name: "fc-server", Type: "http.server", Config: map[string]any{"address": addr}},
{Name: "fc-router", Type: "http.router", DependsOn: []string{"fc-server"}},
{Name: "fc-handler", Type: "http.handler", DependsOn: []string{"fc-router"}, Config: map[string]any{"contentType": "application/json"}},
{Name: "fc-cors", Type: "http.middleware.cors", DependsOn: []string{"fc-router"}, Config: map[string]any{
"allowedOrigins": []any{"*.example.com", "https://trusted.io"},
"allowedMethods": []any{"GET", "POST", "OPTIONS"},
"allowedHeaders": []any{"Authorization", "Content-Type", "X-CSRF-Token", "X-Request-Id"},
"allowCredentials": true,
"maxAge": 3600,
}},
},
Workflows: map[string]any{
"http": map[string]any{
"server": "fc-server",
"router": "fc-router",
"routes": []any{
map[string]any{
"method": "GET",
"path": "/api/fc-test",
"handler": "fc-handler",
"middlewares": []any{"fc-cors"},
},
},
},
},
Triggers: map[string]any{},
}

logger := &mockLogger{}
app := modular.NewStdApplication(modular.NewStdConfigProvider(nil), logger)
engine := NewStdEngine(app, logger)
loadAllPlugins(t, engine)
engine.RegisterWorkflowHandler(handlers.NewHTTPWorkflowHandler())

if err := engine.BuildFromConfig(cfg); err != nil {
t.Fatalf("BuildFromConfig failed: %v", err)
}

ctx := t.Context()
if err := engine.Start(ctx); err != nil {
t.Fatalf("Engine start failed: %v", err)
}
defer engine.Stop(context.Background())

waitForServer(t, baseURL, 5*time.Second)
client := &http.Client{Timeout: 5 * time.Second}

// Subtest 1: Configurable allowedHeaders are reflected
t.Run("configurable_headers", func(t *testing.T) {
req, _ := http.NewRequest("GET", baseURL+"/api/fc-test", nil)
req.Header.Set("Origin", "http://app.example.com")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()

acah := resp.Header.Get("Access-Control-Allow-Headers")
want := "Authorization, Content-Type, X-CSRF-Token, X-Request-Id"
if acah != want {
t.Errorf("Expected Access-Control-Allow-Headers %q, got %q", want, acah)
}
})

// Subtest 2: allowCredentials sets the Credentials header
t.Run("allow_credentials", func(t *testing.T) {
req, _ := http.NewRequest("GET", baseURL+"/api/fc-test", nil)
req.Header.Set("Origin", "https://trusted.io")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()

if resp.Header.Get("Access-Control-Allow-Credentials") != "true" {
t.Errorf("Expected Access-Control-Allow-Credentials: true, got %q", resp.Header.Get("Access-Control-Allow-Credentials"))
}
})

// Subtest 3: maxAge is set on responses
t.Run("max_age", func(t *testing.T) {
req, _ := http.NewRequest("GET", baseURL+"/api/fc-test", nil)
req.Header.Set("Origin", "https://trusted.io")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()

if resp.Header.Get("Access-Control-Max-Age") != "3600" {
t.Errorf("Expected Access-Control-Max-Age: 3600, got %q", resp.Header.Get("Access-Control-Max-Age"))
}
})

// Subtest 4: Wildcard subdomain matching
t.Run("wildcard_subdomain", func(t *testing.T) {
req, _ := http.NewRequest("GET", baseURL+"/api/fc-test", nil)
req.Header.Set("Origin", "http://admin.example.com")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()

acao := resp.Header.Get("Access-Control-Allow-Origin")
if acao != "http://admin.example.com" {
t.Errorf("Expected Access-Control-Allow-Origin 'http://admin.example.com', got %q", acao)
}
})

// Subtest 5: Wildcard subdomain does not match unrelated domains
t.Run("wildcard_subdomain_no_match", func(t *testing.T) {
req, _ := http.NewRequest("GET", baseURL+"/api/fc-test", nil)
req.Header.Set("Origin", "http://evil.com")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()

acao := resp.Header.Get("Access-Control-Allow-Origin")
if acao != "" {
t.Errorf("Expected no Access-Control-Allow-Origin for disallowed domain, got %q", acao)
}
})

t.Log("E2E Middleware CORS FullConfig: all new features verified")
}

// TestE2E_Middleware_RequestID verifies the RequestID middleware adds an
// X-Request-ID header to every response, and preserves a client-supplied one.
func TestE2E_Middleware_RequestID(t *testing.T) {
Expand Down
102 changes: 84 additions & 18 deletions module/http_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"math"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -318,19 +319,59 @@ func (m *LoggingMiddleware) RequiresServices() []modular.ServiceDependency {
return nil
}

// CORSMiddlewareConfig holds configuration for the CORS middleware.
type CORSMiddlewareConfig struct {
// AllowedOrigins is the list of origins allowed to make cross-origin requests.
// Use "*" to allow all origins. Supports wildcard subdomain patterns like "*.example.com".
AllowedOrigins []string
// AllowedMethods is the list of HTTP methods allowed in CORS requests.
AllowedMethods []string
// AllowedHeaders is the list of HTTP headers allowed in CORS requests.
// Defaults to ["Content-Type", "Authorization"] when empty.
AllowedHeaders []string
// AllowCredentials indicates whether the request can include user credentials.
// When true, the actual request Origin is reflected (never "*").
AllowCredentials bool
// MaxAge specifies how long (in seconds) the preflight response may be cached.
// Zero means no caching directive is sent.
MaxAge int
}

// CORSMiddleware provides CORS support
type CORSMiddleware struct {
name string
allowedOrigins []string
allowedMethods []string
name string
allowedOrigins []string
allowedMethods []string
allowedHeaders []string
allowCredentials bool
maxAge int
}

// NewCORSMiddleware creates a new CORS middleware
// defaultCORSHeaders is the default set of allowed headers for backward compatibility.
var defaultCORSHeaders = []string{"Content-Type", "Authorization"}

// NewCORSMiddleware creates a new CORS middleware with default allowed headers.
func NewCORSMiddleware(name string, allowedOrigins, allowedMethods []string) *CORSMiddleware {
return NewCORSMiddlewareWithConfig(name, CORSMiddlewareConfig{
AllowedOrigins: allowedOrigins,
AllowedMethods: allowedMethods,
})
}

// NewCORSMiddlewareWithConfig creates a new CORS middleware with full configuration.
// If AllowedHeaders is empty, it defaults to ["Content-Type", "Authorization"].
func NewCORSMiddlewareWithConfig(name string, cfg CORSMiddlewareConfig) *CORSMiddleware {
headers := cfg.AllowedHeaders
if len(headers) == 0 {
headers = defaultCORSHeaders
}
return &CORSMiddleware{
name: name,
allowedOrigins: allowedOrigins,
allowedMethods: allowedMethods,
name: name,
allowedOrigins: cfg.AllowedOrigins,
allowedMethods: cfg.AllowedMethods,
allowedHeaders: headers,
allowCredentials: cfg.AllowCredentials,
maxAge: cfg.MaxAge,
}
}

Expand All @@ -344,24 +385,49 @@ func (m *CORSMiddleware) Init(app modular.Application) error {
return nil
}

// corsOriginAllowed checks if the given origin is in the allowed list.
// It supports exact matching, "*" wildcard, and subdomain wildcards like "*.example.com".
// Wildcard patterns are matched against the parsed hostname only, so ports are handled correctly:
// "*.example.com" will match "http://sub.example.com:3000".
func corsOriginAllowed(origin string, allowedOrigins []string) bool {
if origin == "" {
return false
}
for _, allowed := range allowedOrigins {
if allowed == "*" || allowed == origin {
return true
}
// Wildcard subdomain matching: "*.example.com" matches "sub.example.com" (any port).
// Parse the request origin to extract just the hostname for comparison.
if strings.HasPrefix(allowed, "*.") {
suffix := allowed[1:] // ".example.com"
u, err := url.Parse(origin)
if err == nil && strings.HasSuffix(u.Hostname(), suffix) {
return true
}
}
}
return false
}

// Process implements middleware processing
func (m *CORSMiddleware) Process(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")

// Check if origin is allowed
allowed := false
for _, allowedOrigin := range m.allowedOrigins {
if allowedOrigin == "*" || allowedOrigin == origin {
allowed = true
break
}
}

if allowed {
// Only apply CORS headers when the request includes an Origin header.
// Requests without Origin are not cross-origin requests and need no CORS response.
if origin != "" && corsOriginAllowed(origin, m.allowedOrigins) {
w.Header().Add("Vary", "Origin")
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods", strings.Join(m.allowedMethods, ", "))
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
w.Header().Set("Access-Control-Allow-Headers", strings.Join(m.allowedHeaders, ", "))
if m.allowCredentials {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
if m.maxAge > 0 {
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(m.maxAge))
}
}

// Handle preflight requests
Expand Down
Loading
Loading