From 1f3b746847f047c628b839e3e81c37b7db53d164 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 11 Mar 2026 22:52:53 +0000 Subject: [PATCH 1/3] Initial plan From 237351f9bed77d18d2df7b9b606cf5192d7d6b61 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 11 Mar 2026 23:05:47 +0000 Subject: [PATCH 2/3] feat: support configurable CORS headers, credentials, maxAge, and wildcard subdomain origins Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --- e2e_middleware_test.go | 138 +++++++++++++++++++++++ module/http_middleware.go | 88 ++++++++++++--- module/http_middleware_test.go | 166 +++++++++++++++++++++++++++- module/platform_do_database.go | 6 +- module/platform_do_database_test.go | 12 +- module/scan_provider_test.go | 10 +- plugins/http/modules.go | 30 +++-- 7 files changed, 413 insertions(+), 37 deletions(-) diff --git a/e2e_middleware_test.go b/e2e_middleware_test.go index c7d5c713..e0ff7693 100644 --- a/e2e_middleware_test.go +++ b/e2e_middleware_test.go @@ -416,6 +416,144 @@ func TestE2E_Middleware_CORS(t *testing.T) { t.Log("E2E Middleware CORS: Allowed, disallowed, headers, and preflight scenarios verified") } +// TestE2E_Middleware_CORS_FullConfig verifies that the CORS middleware factory correctly +// applies allowedHeaders, allowCredentials, maxAge, and wildcard subdomain origin matching. +func TestE2E_Middleware_CORS_FullConfig(t *testing.T) { + port := getFreePort(t) + addr := fmt.Sprintf(":%d", port) + baseURL := fmt.Sprintf("http://127.0.0.1:%d", port) + + cfg := &config.WorkflowConfig{ + Modules: []config.ModuleConfig{ + {Name: "fc-server", Type: "http.server", Config: map[string]any{"address": addr}}, + {Name: "fc-router", Type: "http.router", DependsOn: []string{"fc-server"}}, + {Name: "fc-handler", Type: "http.handler", DependsOn: []string{"fc-router"}, Config: map[string]any{"contentType": "application/json"}}, + {Name: "fc-cors", Type: "http.middleware.cors", DependsOn: []string{"fc-router"}, Config: map[string]any{ + "allowedOrigins": []any{"*.example.com", "https://trusted.io"}, + "allowedMethods": []any{"GET", "POST", "OPTIONS"}, + "allowedHeaders": []any{"Authorization", "Content-Type", "X-CSRF-Token", "X-Request-Id"}, + "allowCredentials": true, + "maxAge": 3600, + }}, + }, + Workflows: map[string]any{ + "http": map[string]any{ + "server": "fc-server", + "router": "fc-router", + "routes": []any{ + map[string]any{ + "method": "GET", + "path": "/api/fc-test", + "handler": "fc-handler", + "middlewares": []any{"fc-cors"}, + }, + }, + }, + }, + Triggers: map[string]any{}, + } + + logger := &mockLogger{} + app := modular.NewStdApplication(modular.NewStdConfigProvider(nil), logger) + engine := NewStdEngine(app, logger) + loadAllPlugins(t, engine) + engine.RegisterWorkflowHandler(handlers.NewHTTPWorkflowHandler()) + + if err := engine.BuildFromConfig(cfg); err != nil { + t.Fatalf("BuildFromConfig failed: %v", err) + } + + ctx := t.Context() + if err := engine.Start(ctx); err != nil { + t.Fatalf("Engine start failed: %v", err) + } + defer engine.Stop(context.Background()) + + waitForServer(t, baseURL, 5*time.Second) + client := &http.Client{Timeout: 5 * time.Second} + + // Subtest 1: Configurable allowedHeaders are reflected + t.Run("configurable_headers", func(t *testing.T) { + req, _ := http.NewRequest("GET", baseURL+"/api/fc-test", nil) + req.Header.Set("Origin", "http://app.example.com") + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + acah := resp.Header.Get("Access-Control-Allow-Headers") + want := "Authorization, Content-Type, X-CSRF-Token, X-Request-Id" + if acah != want { + t.Errorf("Expected Access-Control-Allow-Headers %q, got %q", want, acah) + } + }) + + // Subtest 2: allowCredentials sets the Credentials header + t.Run("allow_credentials", func(t *testing.T) { + req, _ := http.NewRequest("GET", baseURL+"/api/fc-test", nil) + req.Header.Set("Origin", "https://trusted.io") + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.Header.Get("Access-Control-Allow-Credentials") != "true" { + t.Errorf("Expected Access-Control-Allow-Credentials: true, got %q", resp.Header.Get("Access-Control-Allow-Credentials")) + } + }) + + // Subtest 3: maxAge is set on responses + t.Run("max_age", func(t *testing.T) { + req, _ := http.NewRequest("GET", baseURL+"/api/fc-test", nil) + req.Header.Set("Origin", "https://trusted.io") + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.Header.Get("Access-Control-Max-Age") != "3600" { + t.Errorf("Expected Access-Control-Max-Age: 3600, got %q", resp.Header.Get("Access-Control-Max-Age")) + } + }) + + // Subtest 4: Wildcard subdomain matching + t.Run("wildcard_subdomain", func(t *testing.T) { + req, _ := http.NewRequest("GET", baseURL+"/api/fc-test", nil) + req.Header.Set("Origin", "http://admin.example.com") + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + acao := resp.Header.Get("Access-Control-Allow-Origin") + if acao != "http://admin.example.com" { + t.Errorf("Expected Access-Control-Allow-Origin 'http://admin.example.com', got %q", acao) + } + }) + + // Subtest 5: Wildcard subdomain does not match unrelated domains + t.Run("wildcard_subdomain_no_match", func(t *testing.T) { + req, _ := http.NewRequest("GET", baseURL+"/api/fc-test", nil) + req.Header.Set("Origin", "http://evil.com") + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + acao := resp.Header.Get("Access-Control-Allow-Origin") + if acao != "" { + t.Errorf("Expected no Access-Control-Allow-Origin for disallowed domain, got %q", acao) + } + }) + + t.Log("E2E Middleware CORS FullConfig: all new features verified") +} + // TestE2E_Middleware_RequestID verifies the RequestID middleware adds an // X-Request-ID header to every response, and preserves a client-supplied one. func TestE2E_Middleware_RequestID(t *testing.T) { diff --git a/module/http_middleware.go b/module/http_middleware.go index 7eaa2ff5..0272ff7a 100644 --- a/module/http_middleware.go +++ b/module/http_middleware.go @@ -318,19 +318,59 @@ func (m *LoggingMiddleware) RequiresServices() []modular.ServiceDependency { return nil } +// CORSMiddlewareConfig holds configuration for the CORS middleware. +type CORSMiddlewareConfig struct { + // AllowedOrigins is the list of origins allowed to make cross-origin requests. + // Use "*" to allow all origins. Supports wildcard subdomain patterns like "*.example.com". + AllowedOrigins []string + // AllowedMethods is the list of HTTP methods allowed in CORS requests. + AllowedMethods []string + // AllowedHeaders is the list of HTTP headers allowed in CORS requests. + // Defaults to ["Content-Type", "Authorization"] when empty. + AllowedHeaders []string + // AllowCredentials indicates whether the request can include user credentials. + // When true, the actual request Origin is reflected (never "*"). + AllowCredentials bool + // MaxAge specifies how long (in seconds) the preflight response may be cached. + // Zero means no caching directive is sent. + MaxAge int +} + // CORSMiddleware provides CORS support type CORSMiddleware struct { - name string - allowedOrigins []string - allowedMethods []string + name string + allowedOrigins []string + allowedMethods []string + allowedHeaders []string + allowCredentials bool + maxAge int } -// NewCORSMiddleware creates a new CORS middleware +// defaultCORSHeaders is the default set of allowed headers for backward compatibility. +var defaultCORSHeaders = []string{"Content-Type", "Authorization"} + +// NewCORSMiddleware creates a new CORS middleware with default allowed headers. func NewCORSMiddleware(name string, allowedOrigins, allowedMethods []string) *CORSMiddleware { + return NewCORSMiddlewareWithConfig(name, CORSMiddlewareConfig{ + AllowedOrigins: allowedOrigins, + AllowedMethods: allowedMethods, + }) +} + +// NewCORSMiddlewareWithConfig creates a new CORS middleware with full configuration. +// If AllowedHeaders is empty, it defaults to ["Content-Type", "Authorization"]. +func NewCORSMiddlewareWithConfig(name string, cfg CORSMiddlewareConfig) *CORSMiddleware { + headers := cfg.AllowedHeaders + if len(headers) == 0 { + headers = defaultCORSHeaders + } return &CORSMiddleware{ - name: name, - allowedOrigins: allowedOrigins, - allowedMethods: allowedMethods, + name: name, + allowedOrigins: cfg.AllowedOrigins, + allowedMethods: cfg.AllowedMethods, + allowedHeaders: headers, + allowCredentials: cfg.AllowCredentials, + maxAge: cfg.MaxAge, } } @@ -344,24 +384,42 @@ func (m *CORSMiddleware) Init(app modular.Application) error { return nil } +// corsOriginAllowed checks if the given origin is in the allowed list. +// It supports exact matching, "*" wildcard, and subdomain wildcards like "*.example.com". +func corsOriginAllowed(origin string, allowedOrigins []string) bool { + for _, allowed := range allowedOrigins { + if allowed == "*" || allowed == origin { + return true + } + // Wildcard subdomain matching: "*.example.com" matches "sub.example.com" + if strings.HasPrefix(allowed, "*.") { + suffix := allowed[1:] // ".example.com" + if strings.HasSuffix(origin, suffix) { + return true + } + } + } + return false +} + // Process implements middleware processing func (m *CORSMiddleware) Process(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") // Check if origin is allowed - allowed := false - for _, allowedOrigin := range m.allowedOrigins { - if allowedOrigin == "*" || allowedOrigin == origin { - allowed = true - break - } - } + allowed := corsOriginAllowed(origin, m.allowedOrigins) if allowed { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Methods", strings.Join(m.allowedMethods, ", ")) - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + w.Header().Set("Access-Control-Allow-Headers", strings.Join(m.allowedHeaders, ", ")) + if m.allowCredentials { + w.Header().Set("Access-Control-Allow-Credentials", "true") + } + if m.maxAge > 0 { + w.Header().Set("Access-Control-Max-Age", strconv.Itoa(m.maxAge)) + } } // Handle preflight requests diff --git a/module/http_middleware_test.go b/module/http_middleware_test.go index 60e3f72d..d9da7e32 100644 --- a/module/http_middleware_test.go +++ b/module/http_middleware_test.go @@ -272,7 +272,171 @@ func TestCORSMiddleware_Process_Preflight(t *testing.T) { } } -func TestCORSMiddleware_ProvidesServices(t *testing.T) { +func TestCORSMiddlewareWithConfig_AllowedHeaders(t *testing.T) { + m := NewCORSMiddlewareWithConfig("cors", CORSMiddlewareConfig{ + AllowedOrigins: []string{"http://localhost:3000"}, + AllowedMethods: []string{"GET", "POST"}, + AllowedHeaders: []string{"Content-Type", "Authorization", "X-CSRF-Token", "X-Request-Id"}, + }) + + handler := m.Process(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "http://localhost:3000") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + got := rec.Header().Get("Access-Control-Allow-Headers") + want := "Content-Type, Authorization, X-CSRF-Token, X-Request-Id" + if got != want { + t.Errorf("expected Access-Control-Allow-Headers %q, got %q", want, got) + } +} + +func TestCORSMiddlewareWithConfig_DefaultHeaders(t *testing.T) { + // When AllowedHeaders is omitted, defaults to Content-Type and Authorization. + m := NewCORSMiddlewareWithConfig("cors", CORSMiddlewareConfig{ + AllowedOrigins: []string{"http://localhost:3000"}, + AllowedMethods: []string{"GET"}, + }) + + handler := m.Process(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "http://localhost:3000") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + got := rec.Header().Get("Access-Control-Allow-Headers") + want := "Content-Type, Authorization" + if got != want { + t.Errorf("expected default Access-Control-Allow-Headers %q, got %q", want, got) + } +} + +func TestCORSMiddlewareWithConfig_AllowCredentials(t *testing.T) { + m := NewCORSMiddlewareWithConfig("cors", CORSMiddlewareConfig{ + AllowedOrigins: []string{"http://app.example.com"}, + AllowedMethods: []string{"GET", "POST"}, + AllowCredentials: true, + }) + + handler := m.Process(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "http://app.example.com") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Header().Get("Access-Control-Allow-Origin") != "http://app.example.com" { + t.Errorf("expected origin reflected, got %q", rec.Header().Get("Access-Control-Allow-Origin")) + } + if rec.Header().Get("Access-Control-Allow-Credentials") != "true" { + t.Errorf("expected Access-Control-Allow-Credentials: true, got %q", rec.Header().Get("Access-Control-Allow-Credentials")) + } +} + +func TestCORSMiddlewareWithConfig_NoCredentialsFlagNotSet(t *testing.T) { + m := NewCORSMiddlewareWithConfig("cors", CORSMiddlewareConfig{ + AllowedOrigins: []string{"http://app.example.com"}, + AllowedMethods: []string{"GET"}, + }) + + handler := m.Process(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "http://app.example.com") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Header().Get("Access-Control-Allow-Credentials") != "" { + t.Errorf("expected no Access-Control-Allow-Credentials header, got %q", rec.Header().Get("Access-Control-Allow-Credentials")) + } +} + +func TestCORSMiddlewareWithConfig_MaxAge(t *testing.T) { + m := NewCORSMiddlewareWithConfig("cors", CORSMiddlewareConfig{ + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"GET", "POST"}, + MaxAge: 3600, + }) + + handler := m.Process(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "http://anything.com") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Header().Get("Access-Control-Max-Age") != "3600" { + t.Errorf("expected Access-Control-Max-Age: 3600, got %q", rec.Header().Get("Access-Control-Max-Age")) + } +} + +func TestCORSMiddlewareWithConfig_NoMaxAge(t *testing.T) { + m := NewCORSMiddlewareWithConfig("cors", CORSMiddlewareConfig{ + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"GET"}, + }) + + handler := m.Process(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "http://anything.com") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Header().Get("Access-Control-Max-Age") != "" { + t.Errorf("expected no Access-Control-Max-Age header, got %q", rec.Header().Get("Access-Control-Max-Age")) + } +} + +func TestCORSMiddlewareWithConfig_WildcardSubdomain(t *testing.T) { + m := NewCORSMiddlewareWithConfig("cors", CORSMiddlewareConfig{ + AllowedOrigins: []string{"*.example.com"}, + AllowedMethods: []string{"GET"}, + }) + + handler := m.Process(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + origin string + allowed bool + }{ + {"http://app.example.com", true}, + {"http://admin.example.com", true}, + {"http://evil.com", false}, + {"http://notexample.com", false}, + } + + for _, tt := range tests { + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", tt.origin) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + hasHeader := rec.Header().Get("Access-Control-Allow-Origin") != "" + if hasHeader != tt.allowed { + t.Errorf("origin %q: expected allowed=%v, got header=%q", tt.origin, tt.allowed, rec.Header().Get("Access-Control-Allow-Origin")) + } + } +} + +func TestCORSMiddlewareWithConfig_ProvidesServices(t *testing.T) { m := NewCORSMiddleware("cors-mw", nil, nil) svcs := m.ProvidesServices() if len(svcs) != 1 { diff --git a/module/platform_do_database.go b/module/platform_do_database.go index 8b550f33..ea44c7a5 100644 --- a/module/platform_do_database.go +++ b/module/platform_do_database.go @@ -13,12 +13,12 @@ import ( type DODatabaseState struct { ID string `json:"id"` Name string `json:"name"` - Engine string `json:"engine"` // pg, mysql, redis, mongodb, kafka + Engine string `json:"engine"` // pg, mysql, redis, mongodb, kafka Version string `json:"version"` - Size string `json:"size"` // e.g. db-s-1vcpu-1gb + Size string `json:"size"` // e.g. db-s-1vcpu-1gb Region string `json:"region"` NumNodes int `json:"numNodes"` - Status string `json:"status"` // pending, online, resizing, migrating, error + Status string `json:"status"` // pending, online, resizing, migrating, error Host string `json:"host"` Port int `json:"port"` DatabaseName string `json:"databaseName"` diff --git a/module/platform_do_database_test.go b/module/platform_do_database_test.go index 52e36f98..b1a28dcd 100644 --- a/module/platform_do_database_test.go +++ b/module/platform_do_database_test.go @@ -6,13 +6,13 @@ func TestPlatformDODatabase_MockBackend(t *testing.T) { m := &PlatformDODatabase{ name: "test-db", config: map[string]any{ - "provider": "mock", - "engine": "pg", - "version": "16", - "size": "db-s-1vcpu-1gb", - "region": "nyc1", + "provider": "mock", + "engine": "pg", + "version": "16", + "size": "db-s-1vcpu-1gb", + "region": "nyc1", "num_nodes": 1, - "name": "test-db", + "name": "test-db", }, state: &DODatabaseState{ Name: "test-db", diff --git a/module/scan_provider_test.go b/module/scan_provider_test.go index d6b83439..187e6d86 100644 --- a/module/scan_provider_test.go +++ b/module/scan_provider_test.go @@ -58,7 +58,7 @@ func (a *scanMockApp) GetService(name string, target any) error { return nil } -func (a *scanMockApp) RegisterService(name string, svc any) error { a.services[name] = svc; return nil } +func (a *scanMockApp) RegisterService(name string, svc any) error { a.services[name] = svc; return nil } func (a *scanMockApp) RegisterConfigSection(string, modular.ConfigProvider) {} func (a *scanMockApp) GetConfigSection(string) (modular.ConfigProvider, error) { return nil, nil @@ -67,7 +67,7 @@ func (a *scanMockApp) ConfigSections() map[string]modular.ConfigProvider { retur func (a *scanMockApp) Logger() modular.Logger { return nil } func (a *scanMockApp) SetLogger(modular.Logger) {} func (a *scanMockApp) ConfigProvider() modular.ConfigProvider { return nil } -func (a *scanMockApp) SvcRegistry() modular.ServiceRegistry { return a.services } +func (a *scanMockApp) SvcRegistry() modular.ServiceRegistry { return a.services } func (a *scanMockApp) RegisterModule(modular.Module) {} func (a *scanMockApp) Init() error { return nil } func (a *scanMockApp) Start() error { return nil } @@ -83,9 +83,9 @@ func (a *scanMockApp) GetServiceEntry(string) (*modular.ServiceRegistryEntry, bo func (a *scanMockApp) GetServicesByInterface(_ reflect.Type) []*modular.ServiceRegistryEntry { return nil } -func (a *scanMockApp) GetModule(string) modular.Module { return nil } -func (a *scanMockApp) GetAllModules() map[string]modular.Module { return nil } -func (a *scanMockApp) StartTime() time.Time { return time.Time{} } +func (a *scanMockApp) GetModule(string) modular.Module { return nil } +func (a *scanMockApp) GetAllModules() map[string]modular.Module { return nil } +func (a *scanMockApp) StartTime() time.Time { return time.Time{} } func (a *scanMockApp) OnConfigLoaded(func(modular.Application) error) {} func newScanApp(provider SecurityScannerProvider) *scanMockApp { diff --git a/plugins/http/modules.go b/plugins/http/modules.go index 84be5730..8df90ba7 100644 --- a/plugins/http/modules.go +++ b/plugins/http/modules.go @@ -173,25 +173,41 @@ func rateLimitMiddlewareFactory(name string, cfg map[string]any) modular.Module } func corsMiddlewareFactory(name string, cfg map[string]any) modular.Module { - allowedOrigins := []string{"*"} - allowedMethods := []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"} + corsCfg := module.CORSMiddlewareConfig{ + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + } if origins, ok := cfg["allowedOrigins"].([]any); ok { - allowedOrigins = make([]string, len(origins)) + corsCfg.AllowedOrigins = make([]string, len(origins)) for i, origin := range origins { if str, ok := origin.(string); ok { - allowedOrigins[i] = str + corsCfg.AllowedOrigins[i] = str } } } if methods, ok := cfg["allowedMethods"].([]any); ok { - allowedMethods = make([]string, len(methods)) + corsCfg.AllowedMethods = make([]string, len(methods)) for i, method := range methods { if str, ok := method.(string); ok { - allowedMethods[i] = str + corsCfg.AllowedMethods[i] = str + } + } + } + if headers, ok := cfg["allowedHeaders"].([]any); ok { + corsCfg.AllowedHeaders = make([]string, len(headers)) + for i, header := range headers { + if str, ok := header.(string); ok { + corsCfg.AllowedHeaders[i] = str } } } - return module.NewCORSMiddleware(name, allowedOrigins, allowedMethods) + if allowCreds, ok := cfg["allowCredentials"].(bool); ok { + corsCfg.AllowCredentials = allowCreds + } + if maxAge, ok := cfg["maxAge"].(int); ok { + corsCfg.MaxAge = maxAge + } + return module.NewCORSMiddlewareWithConfig(name, corsCfg) } func requestIDMiddlewareFactory(name string, _ map[string]any) modular.Module { From 1eb012bcd6b092a77d58ff2a5c9e9aa1aebc4e2e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 12 Mar 2026 00:25:04 +0000 Subject: [PATCH 3/3] fix: address CORS middleware review comments - Vary header, empty origin guard, port-aware wildcard, float64 maxAge, schema update Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --- example/go.mod | 14 +++++------ example/go.sum | 28 +++++++++++----------- module/http_middleware.go | 20 +++++++++++----- module/http_middleware_test.go | 43 +++++++++++++++++++++++++++++++++- plugins/http/modules.go | 2 ++ plugins/http/schemas.go | 12 +++++++--- 6 files changed, 88 insertions(+), 31 deletions(-) diff --git a/example/go.mod b/example/go.mod index b00541e5..ae6989fb 100644 --- a/example/go.mod +++ b/example/go.mod @@ -5,7 +5,7 @@ go 1.26.0 replace github.com/GoCodeAlone/workflow => ../ require ( - github.com/GoCodeAlone/modular v1.12.0 + github.com/GoCodeAlone/modular v1.12.3 github.com/GoCodeAlone/workflow v0.0.0-00010101000000-000000000000 ) @@ -20,12 +20,12 @@ require ( cloud.google.com/go/storage v1.60.0 // indirect github.com/BurntSushi/toml v1.6.0 // indirect github.com/DataDog/datadog-go/v5 v5.4.0 // indirect - github.com/GoCodeAlone/modular/modules/auth v1.12.0 // indirect - github.com/GoCodeAlone/modular/modules/cache v1.12.0 // indirect - github.com/GoCodeAlone/modular/modules/eventbus/v2 v2.5.0 // indirect - github.com/GoCodeAlone/modular/modules/jsonschema v1.12.0 // indirect - github.com/GoCodeAlone/modular/modules/reverseproxy/v2 v2.5.0 // indirect - github.com/GoCodeAlone/modular/modules/scheduler v1.12.0 // indirect + github.com/GoCodeAlone/modular/modules/auth v1.14.0 // indirect + github.com/GoCodeAlone/modular/modules/cache v1.14.0 // indirect + github.com/GoCodeAlone/modular/modules/eventbus/v2 v2.7.0 // indirect + github.com/GoCodeAlone/modular/modules/jsonschema v1.14.0 // indirect + github.com/GoCodeAlone/modular/modules/reverseproxy/v2 v2.7.0 // indirect + github.com/GoCodeAlone/modular/modules/scheduler v1.14.0 // indirect github.com/GoCodeAlone/yaegi v0.17.1 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 // indirect diff --git a/example/go.sum b/example/go.sum index f1c81b9c..b872c630 100644 --- a/example/go.sum +++ b/example/go.sum @@ -30,20 +30,20 @@ github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2 github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= github.com/DataDog/datadog-go/v5 v5.4.0 h1:Ea3eXUVwrVV28F/fo3Dr3aa+TL/Z7Xi6SUPKW8L99aI= github.com/DataDog/datadog-go/v5 v5.4.0/go.mod h1:K9kcYBlxkcPP8tvvjZZKs/m1edNAUFzBbdpTUKfCsuw= -github.com/GoCodeAlone/modular v1.12.0 h1:C4tLfJe65rrUQsbtndiVfldtT8IRKZcHczNRNbBK4wo= -github.com/GoCodeAlone/modular v1.12.0/go.mod h1:ET7mlekRjkRq9mwJdWmaC2KDUWvjla2IqKVFrYO2JnY= -github.com/GoCodeAlone/modular/modules/auth v1.12.0 h1:eO4iq8tkz8W5sLKRSG5dC+ACITMtxZrtSJ+ReE3fKdA= -github.com/GoCodeAlone/modular/modules/auth v1.12.0/go.mod h1:D+yfkgN3MTkyl1xe8h2UL7uqB9Vj1lO3wUrscfnJ/NU= -github.com/GoCodeAlone/modular/modules/cache v1.12.0 h1:Ue6aXytFq1I+OnC3PcV2KlUg4lHiuGWH0Qq+v/lqyp0= -github.com/GoCodeAlone/modular/modules/cache v1.12.0/go.mod h1:kSaT8wNy/3YGmtIpDqPbW6MRqKOp2yc8a5MHdAag2CE= -github.com/GoCodeAlone/modular/modules/eventbus/v2 v2.5.0 h1:K6X+X1sOq+lpI1Oa+XUzH+GlSRYJQfDTTcvMjZfkbFU= -github.com/GoCodeAlone/modular/modules/eventbus/v2 v2.5.0/go.mod h1:Q0TpCFTtd0q20okDyi63ALS+1xmkYU4wNUOqwczyih0= -github.com/GoCodeAlone/modular/modules/jsonschema v1.12.0 h1:urGK8Xtwku4tn8nBeVZn9UqvldnCptZ3rLCXO21vSz4= -github.com/GoCodeAlone/modular/modules/jsonschema v1.12.0/go.mod h1:+/0p1alfSbhhshcNRId1HRRIupeu0DPC7BH8AYiBQ1I= -github.com/GoCodeAlone/modular/modules/reverseproxy/v2 v2.5.0 h1:zcF46oZ7MJFfZCmzqc1n9ZTw6wrTJSFr04yaz6EYKeo= -github.com/GoCodeAlone/modular/modules/reverseproxy/v2 v2.5.0/go.mod h1:ycmJYst0dgaeLYBDOFGYz3ZiVK0fVcbl59omBySpKis= -github.com/GoCodeAlone/modular/modules/scheduler v1.12.0 h1:kxeLUpFFZ2HWV5B7Ra1WaOr1DDee5G6kAZ6F1BUXX/Y= -github.com/GoCodeAlone/modular/modules/scheduler v1.12.0/go.mod h1:VpDSAU0Guj8geVz19YCSknyCJp0j3TMBaxLEYXedkZc= +github.com/GoCodeAlone/modular v1.12.3 h1:WcNqc1ZG+Lv/xzF8wTDavGIOeAvlV4wEd5HO2mVTUwE= +github.com/GoCodeAlone/modular v1.12.3/go.mod h1:nDdyW/eJu4gDFNueb6vWwLvti3bPHSZJHkWGiwEmi2I= +github.com/GoCodeAlone/modular/modules/auth v1.14.0 h1:Y+p4/HIcxkajlcNhcPlqpwAt1SCHjB4AaDMEys50E3I= +github.com/GoCodeAlone/modular/modules/auth v1.14.0/go.mod h1:fkwPn2svDsCHBI19gtUHxo064SL+EudjB+o7VjL9ug8= +github.com/GoCodeAlone/modular/modules/cache v1.14.0 h1:ykQRwXJGXaRtAsnW9Tgs0LvXExonkKr8P7XIHxPaYdY= +github.com/GoCodeAlone/modular/modules/cache v1.14.0/go.mod h1:tcIjHJHZ5fVU8sstILrXeVQgjpZcUkErnNjRaxkBSR8= +github.com/GoCodeAlone/modular/modules/eventbus/v2 v2.7.0 h1:clGAyaOfyDc9iY63ONfZiHReVccVhK/yH19QEb14SSI= +github.com/GoCodeAlone/modular/modules/eventbus/v2 v2.7.0/go.mod h1:0AnfWGVmrqyv91rduc6mrPqW6WQchDAa2WtM0Qmw/WA= +github.com/GoCodeAlone/modular/modules/jsonschema v1.14.0 h1:dCiPIO+NvJPizfCeUQqGXHD1WitOVYpKuL3fxMEjRlw= +github.com/GoCodeAlone/modular/modules/jsonschema v1.14.0/go.mod h1:5Hm+R9G41wwb0hKefx9+9PMqffjU1tA7roW3t3sTaLE= +github.com/GoCodeAlone/modular/modules/reverseproxy/v2 v2.7.0 h1:TtVD+tE8ABN98n50MFVyMAvMsBM4JE86KRgCRDzPDC4= +github.com/GoCodeAlone/modular/modules/reverseproxy/v2 v2.7.0/go.mod h1:N7d8aSV4eqr90qjlIOs/8EmW7avt9gwX06Uh+zKDr4s= +github.com/GoCodeAlone/modular/modules/scheduler v1.14.0 h1:JSrzo4FB7uGASExv+fCLRd6pXWULV1mJYvzmM9PzUeM= +github.com/GoCodeAlone/modular/modules/scheduler v1.14.0/go.mod h1:emkR2AnilabLJZv1rOTDO9eGpRBmZs487H00Lnp9jIc= github.com/GoCodeAlone/yaegi v0.17.1 h1:aPAwU29L9cGceRAff02c5pjQcT5KapDB4fWFZK9tElE= github.com/GoCodeAlone/yaegi v0.17.1/go.mod h1:z5Pr6Wse6QJcQvpgxTxzMAevFarH0N37TG88Y9dprx0= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 h1:sBEjpZlNHzK1voKq9695PJSX2o5NEXl7/OL3coiIY0c= diff --git a/module/http_middleware.go b/module/http_middleware.go index 0272ff7a..e78128e0 100644 --- a/module/http_middleware.go +++ b/module/http_middleware.go @@ -6,6 +6,7 @@ import ( "math" "net" "net/http" + "net/url" "strconv" "strings" "sync" @@ -386,15 +387,22 @@ func (m *CORSMiddleware) Init(app modular.Application) error { // corsOriginAllowed checks if the given origin is in the allowed list. // It supports exact matching, "*" wildcard, and subdomain wildcards like "*.example.com". +// Wildcard patterns are matched against the parsed hostname only, so ports are handled correctly: +// "*.example.com" will match "http://sub.example.com:3000". func corsOriginAllowed(origin string, allowedOrigins []string) bool { + if origin == "" { + return false + } for _, allowed := range allowedOrigins { if allowed == "*" || allowed == origin { return true } - // Wildcard subdomain matching: "*.example.com" matches "sub.example.com" + // Wildcard subdomain matching: "*.example.com" matches "sub.example.com" (any port). + // Parse the request origin to extract just the hostname for comparison. if strings.HasPrefix(allowed, "*.") { suffix := allowed[1:] // ".example.com" - if strings.HasSuffix(origin, suffix) { + u, err := url.Parse(origin) + if err == nil && strings.HasSuffix(u.Hostname(), suffix) { return true } } @@ -407,10 +415,10 @@ func (m *CORSMiddleware) Process(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") - // Check if origin is allowed - allowed := corsOriginAllowed(origin, m.allowedOrigins) - - if allowed { + // Only apply CORS headers when the request includes an Origin header. + // Requests without Origin are not cross-origin requests and need no CORS response. + if origin != "" && corsOriginAllowed(origin, m.allowedOrigins) { + w.Header().Add("Vary", "Origin") w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Methods", strings.Join(m.allowedMethods, ", ")) w.Header().Set("Access-Control-Allow-Headers", strings.Join(m.allowedHeaders, ", ")) diff --git a/module/http_middleware_test.go b/module/http_middleware_test.go index d9da7e32..bac86e7b 100644 --- a/module/http_middleware_test.go +++ b/module/http_middleware_test.go @@ -419,13 +419,19 @@ func TestCORSMiddlewareWithConfig_WildcardSubdomain(t *testing.T) { }{ {"http://app.example.com", true}, {"http://admin.example.com", true}, + // Port should be handled correctly via hostname parsing + {"http://app.example.com:3000", true}, {"http://evil.com", false}, {"http://notexample.com", false}, + // Empty origin must not match wildcard + {"", false}, } for _, tt := range tests { req := httptest.NewRequest("GET", "/test", nil) - req.Header.Set("Origin", tt.origin) + if tt.origin != "" { + req.Header.Set("Origin", tt.origin) + } rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) @@ -436,6 +442,41 @@ func TestCORSMiddlewareWithConfig_WildcardSubdomain(t *testing.T) { } } +func TestCORSMiddleware_VaryHeader(t *testing.T) { + m := NewCORSMiddleware("cors", []string{"http://localhost:3000"}, []string{"GET"}) + + handler := m.Process(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "http://localhost:3000") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Header().Get("Vary") != "Origin" { + t.Errorf("expected Vary: Origin header, got %q", rec.Header().Get("Vary")) + } +} + +func TestCORSMiddleware_EmptyOriginSkipped(t *testing.T) { + // When no Origin header is sent, CORS headers must not be set. + m := NewCORSMiddleware("cors", []string{"*"}, []string{"GET"}) + + handler := m.Process(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + // No Origin header set + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Header().Get("Access-Control-Allow-Origin") != "" { + t.Errorf("expected no CORS headers when Origin is absent, got %q", rec.Header().Get("Access-Control-Allow-Origin")) + } +} + func TestCORSMiddlewareWithConfig_ProvidesServices(t *testing.T) { m := NewCORSMiddleware("cors-mw", nil, nil) svcs := m.ProvidesServices() diff --git a/plugins/http/modules.go b/plugins/http/modules.go index 8df90ba7..239f62fc 100644 --- a/plugins/http/modules.go +++ b/plugins/http/modules.go @@ -206,6 +206,8 @@ func corsMiddlewareFactory(name string, cfg map[string]any) modular.Module { } if maxAge, ok := cfg["maxAge"].(int); ok { corsCfg.MaxAge = maxAge + } else if maxAgeFloat, ok := cfg["maxAge"].(float64); ok { + corsCfg.MaxAge = int(maxAgeFloat) } return module.NewCORSMiddlewareWithConfig(name, corsCfg) } diff --git a/plugins/http/schemas.go b/plugins/http/schemas.go index 73212cd3..000e6650 100644 --- a/plugins/http/schemas.go +++ b/plugins/http/schemas.go @@ -176,12 +176,18 @@ func corsMiddlewareSchema() *schema.ModuleSchema { Inputs: []schema.ServiceIODef{{Name: "request", Type: "http.Request", Description: "HTTP request needing CORS headers"}}, Outputs: []schema.ServiceIODef{{Name: "cors", Type: "http.Request", Description: "HTTP request with CORS headers applied"}}, ConfigFields: []schema.ConfigFieldDef{ - {Key: "allowedOrigins", Label: "Allowed Origins", Type: schema.FieldTypeArray, ArrayItemType: "string", DefaultValue: []string{"*"}, Description: "Allowed origins (e.g. https://example.com, http://localhost:3000)"}, + {Key: "allowedOrigins", Label: "Allowed Origins", Type: schema.FieldTypeArray, ArrayItemType: "string", DefaultValue: []string{"*"}, Description: "Allowed origins (e.g. https://example.com, http://localhost:3000, *.example.com)"}, {Key: "allowedMethods", Label: "Allowed Methods", Type: schema.FieldTypeArray, ArrayItemType: "string", DefaultValue: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, Description: "Allowed HTTP methods"}, + {Key: "allowedHeaders", Label: "Allowed Headers", Type: schema.FieldTypeArray, ArrayItemType: "string", DefaultValue: []string{"Content-Type", "Authorization"}, Description: "Allowed request headers (e.g. Authorization, X-CSRF-Token, X-Request-Id)"}, + {Key: "allowCredentials", Label: "Allow Credentials", Type: schema.FieldTypeBool, DefaultValue: false, Description: "Whether to allow requests with credentials (cookies, authorization headers). When true, the actual Origin is reflected instead of *"}, + {Key: "maxAge", Label: "Max Age (sec)", Type: schema.FieldTypeNumber, DefaultValue: 0, Description: "How long (in seconds) the preflight response may be cached. 0 means no caching directive is sent"}, }, DefaultConfig: map[string]any{ - "allowedOrigins": []string{"*"}, - "allowedMethods": []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + "allowedOrigins": []string{"*"}, + "allowedMethods": []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + "allowedHeaders": []string{"Content-Type", "Authorization"}, + "allowCredentials": false, + "maxAge": 0, }, } }