diff --git a/README.md b/README.md index 9662841..e806437 100644 --- a/README.md +++ b/README.md @@ -116,8 +116,20 @@ Intake provides helper functions for creating endpoints with specific HTTP metho // These all create endpoint objects getEndpoint := intake.GET("/users", listUsersHandler) postEndpoint := intake.POST("/users", createUserHandler) -putEndpoint := intake.PUT("/users/:id", updateUserHandler) -deleteEndpoint := intake.DELETE("/users/:id", deleteUserHandler) +putEndpoint := intake.PUT("/users/{id}", updateUserHandler) +deleteEndpoint := intake.DELETE("/users/{id}", deleteUserHandler) +``` + +### Path Parameters (Go 1.22+) + +Intake registers Go 1.22+ ServeMux patterns, so you can use `{name}` segments +and retrieve them with `r.PathValue`. + +```go +app.AddEndpoint(http.MethodGet, "/users/{id}", func(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + fmt.Println("user id:", id) +}) ``` ### Managing Groups of Endpoints diff --git a/benchmark_test.go b/benchmark_test.go new file mode 100644 index 0000000..a1c21de --- /dev/null +++ b/benchmark_test.go @@ -0,0 +1,53 @@ +package intake + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func BenchmarkServeHTTP(b *testing.B) { + app := New() + app.AddGlobalMiddleware(func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + next(w, r) + } + }) + app.AddEndpoint(http.MethodGet, "/bench", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/bench", nil) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr := httptest.NewRecorder() + app.Mux.ServeHTTP(rr, req) + } +} + +func BenchmarkCORSPreflight(b *testing.B) { + app := New() + app.AddGlobalMiddleware(CORS(CORSConfig{ + AllowedOrigins: []string{"https://example.com"}, + AllowedMethods: []string{http.MethodGet, http.MethodOptions}, + AllowedHeaders: []string{"X-Token", "Content-Type"}, + MaxAge: 600, + })) + app.AddEndpoint(http.MethodGet, "/bench", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodOptions, "/bench", nil) + req.Header.Set("Origin", "https://example.com") + req.Header.Set("Access-Control-Request-Method", http.MethodGet) + req.Header.Set("Access-Control-Request-Headers", "X-Token") + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr := httptest.NewRecorder() + app.Mux.ServeHTTP(rr, req) + } +} diff --git a/cors.go b/cors.go index 4c7fabc..fb5e8ac 100644 --- a/cors.go +++ b/cors.go @@ -4,10 +4,34 @@ package intake import ( "net/http" + "net/url" + "slices" "strconv" "strings" ) +type originPattern struct { + scheme string + suffix string +} + +type corsPolicy struct { + allowedMethods []string + allowedMethodsSet map[string]struct{} + allowedMethodsHeader string + allowedHeaders []string + allowedHeadersSet map[string]struct{} + allowedHeadersHeader string + allowedOrigins map[string]struct{} + allowedPatterns []originPattern + allowAnyOrigin bool + allowAnyHeader bool + exposeHeaders []string + exposeHeadersHeader string + allowCredentials bool + maxAge int +} + // CORSConfig defines the configuration options for the CORS middleware. // This struct allows for fine-grained control over CORS policy implementation. // Each field corresponds to a specific CORS header or behavior as defined in @@ -102,19 +126,7 @@ func CORS(config CORSConfig) MiddleWare { config.AllowedMethods = []string{http.MethodGet, http.MethodPost, http.MethodHead} } - // Check for invalid configuration: wildcard origin with credentials - // According to spec, this is an invalid combination for security reasons - // If detected, we remove the wildcard to maintain security - if config.AllowCredentials && containsWildcard(config.AllowedOrigins) { - // Remove wildcard from allowed origins - newAllowedOrigins := make([]string, 0, len(config.AllowedOrigins)) - for _, origin := range config.AllowedOrigins { - if origin != "*" { - newAllowedOrigins = append(newAllowedOrigins, origin) - } - } - config.AllowedOrigins = newAllowedOrigins - } + policy := buildPolicy(config) return func(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { @@ -126,7 +138,7 @@ func CORS(config CORSConfig) MiddleWare { } // Check if the origin is allowed by the configured policy - originAllowed := isOriginAllowed(origin, config.AllowedOrigins) + originAllowed := policy.isOriginAllowed(origin) if !originAllowed { // Origin not allowed, pass through without CORS headers // This maintains security by not acknowledging invalid cross-origin requests @@ -139,18 +151,18 @@ func CORS(config CORSConfig) MiddleWare { // if the CORS request is allowed by the server if r.Method == http.MethodOptions { // Set standard CORS headers for all responses - corsHeaders(w, config, origin) + corsHeaders(w, policy, origin) // Set cache duration for preflight response // This helps reduce the number of preflight requests - if config.MaxAge > 0 { - w.Header().Set("Access-Control-Max-Age", strconv.Itoa(config.MaxAge)) + if policy.maxAge > 0 { + w.Header().Set("Access-Control-Max-Age", strconv.Itoa(policy.maxAge)) } // Check if the requested HTTP method is allowed requestMethod := r.Header.Get("Access-Control-Request-Method") if requestMethod != "" { - methodAllowed := contains(config.AllowedMethods, requestMethod) + _, methodAllowed := policy.allowedMethodsSet[requestMethod] if !methodAllowed { // Method not allowed - respond with 403 Forbidden w.WriteHeader(http.StatusForbidden) @@ -159,27 +171,34 @@ func CORS(config CORSConfig) MiddleWare { } // Set the list of allowed HTTP methods - if len(config.AllowedMethods) > 0 { - w.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", ")) + if policy.allowedMethodsHeader != "" { + w.Header().Set("Access-Control-Allow-Methods", policy.allowedMethodsHeader) } // Handle the requested headers check requestHeaders := r.Header.Get("Access-Control-Request-Headers") - if len(config.AllowedHeaders) > 0 { - if containsWildcard(config.AllowedHeaders) { + if len(policy.allowedHeaders) > 0 || policy.allowAnyHeader { + if policy.allowAnyHeader { // If wildcard is configured for headers, mirror the requested headers // This allows the browser to send any headers it needs if requestHeaders != "" { w.Header().Set("Access-Control-Allow-Headers", requestHeaders) } } else { - // Otherwise, only allow the specifically configured headers - w.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", ")) + // Otherwise, only allow the specifically configured headers, + // and reject preflights that ask for disallowed headers. + if requestHeaders != "" && !policy.areHeadersAllowed(requestHeaders) { + w.WriteHeader(http.StatusForbidden) + return + } + if policy.allowedHeadersHeader != "" { + w.Header().Set("Access-Control-Allow-Headers", policy.allowedHeadersHeader) + } } } else if requestHeaders != "" { - // No allowed headers explicitly configured, but client requested headers - // Mirror the headers since no restrictions were explicitly set - w.Header().Set("Access-Control-Allow-Headers", requestHeaders) + // No allowed headers configured: reject explicit header requests. + w.WriteHeader(http.StatusForbidden) + return } // Preflight requests only need headers, not content @@ -190,7 +209,7 @@ func CORS(config CORSConfig) MiddleWare { // Handle actual CORS request (not a preflight) // Apply the CORS headers and continue with request processing - corsHeaders(w, config, origin) + corsHeaders(w, policy, origin) next(w, r) } } @@ -202,14 +221,14 @@ func CORS(config CORSConfig) MiddleWare { // // Parameters: // - w: The HTTP response writer to set headers on -// - config: The CORS configuration to apply +// - config: The CORS policy to apply // - origin: The requesting Origin header value -func corsHeaders(w http.ResponseWriter, config CORSConfig, origin string) { +func corsHeaders(w http.ResponseWriter, config corsPolicy, origin string) { // Set Access-Control-Allow-Origin header // There are two strategies based on configuration: // 1. Use "*" when wildcard origins are allowed and credentials aren't required // 2. Mirror the specific origin otherwise (required when using credentials) - if containsWildcard(config.AllowedOrigins) && !config.AllowCredentials { + if config.allowAnyOrigin && !config.allowCredentials { w.Header().Set("Access-Control-Allow-Origin", "*") } else { // Echo back the specific origin @@ -221,81 +240,131 @@ func corsHeaders(w http.ResponseWriter, config CORSConfig, origin string) { // Set Access-Control-Allow-Credentials header if credentials are allowed // This enables sending cookies, authorization headers, and TLS client certs - if config.AllowCredentials { + if config.allowCredentials { w.Header().Set("Access-Control-Allow-Credentials", "true") } // Set Access-Control-Expose-Headers header if specific headers should be // accessible to JavaScript in the browser - if len(config.ExposeHeaders) > 0 { - w.Header().Set("Access-Control-Expose-Headers", strings.Join(config.ExposeHeaders, ", ")) + if config.exposeHeadersHeader != "" { + w.Header().Set("Access-Control-Expose-Headers", config.exposeHeadersHeader) } } -// containsWildcard checks if the slice contains the wildcard "*" value. -// This is a helper function used to determine if wildcard patterns exist -// in configuration settings like AllowedOrigins or AllowedHeaders. -// -// Parameters: -// - s: The string slice to check for wildcards -// -// Returns: -// - true if the slice contains "*", false otherwise -func containsWildcard(s []string) bool { - return contains(s, "*") -} +func buildPolicy(config CORSConfig) corsPolicy { + policy := corsPolicy{ + allowedMethods: config.AllowedMethods, + allowedMethodsSet: make(map[string]struct{}, len(config.AllowedMethods)), + allowedMethodsHeader: strings.Join(config.AllowedMethods, ", "), + allowedHeaders: config.AllowedHeaders, + allowedHeadersSet: make(map[string]struct{}, len(config.AllowedHeaders)), + allowedHeadersHeader: strings.Join(config.AllowedHeaders, ", "), + allowedOrigins: make(map[string]struct{}, len(config.AllowedOrigins)), + allowAnyOrigin: false, + allowAnyHeader: containsWildcard(config.AllowedHeaders), + exposeHeaders: config.ExposeHeaders, + exposeHeadersHeader: strings.Join(config.ExposeHeaders, ", "), + allowCredentials: config.AllowCredentials, + maxAge: config.MaxAge, + } -// isOriginAllowed checks if the origin is allowed based on the allowed origins list. -// This function supports multiple matching strategies: -// - Exact match with a specific origin -// - Wildcard match allowing all origins -// - Domain pattern matching (e.g., "https://*.example.com") -// -// Parameters: -// - origin: The origin from the request to check -// - allowedOrigins: The list of origins allowed by configuration -// -// Returns: -// - true if the origin is allowed, false otherwise -func isOriginAllowed(origin string, allowedOrigins []string) bool { - if len(allowedOrigins) == 0 { - return false + for _, method := range config.AllowedMethods { + policy.allowedMethodsSet[method] = struct{}{} } - for _, allowedOrigin := range allowedOrigins { - // Check for wildcard allowing all origins - if allowedOrigin == "*" { - return true + for _, header := range config.AllowedHeaders { + policy.allowedHeadersSet[strings.ToLower(header)] = struct{}{} + } + + for _, origin := range config.AllowedOrigins { + if origin == "*" { + policy.allowAnyOrigin = true + continue } - // Check for exact match - if allowedOrigin == origin { - return true + if strings.HasPrefix(origin, "https://*.") { + policy.allowedPatterns = append(policy.allowedPatterns, originPattern{ + scheme: "https", + suffix: origin[len("https://*."):], + }) + continue } + if strings.HasPrefix(origin, "http://*.") { + policy.allowedPatterns = append(policy.allowedPatterns, originPattern{ + scheme: "http", + suffix: origin[len("http://*."):], + }) + continue + } + + policy.allowedOrigins[origin] = struct{}{} + } + + // Invalid configuration: wildcard origin with credentials. + // Remove wildcard to maintain security. + if policy.allowCredentials && policy.allowAnyOrigin { + policy.allowAnyOrigin = false + } + + return policy +} + +func (p corsPolicy) isOriginAllowed(origin string) bool { + if p.allowAnyOrigin { + return true + } + if _, ok := p.allowedOrigins[origin]; ok { + return true + } + if len(p.allowedPatterns) == 0 { + return false + } + + u, err := url.Parse(origin) + if err != nil { + return false + } + host := u.Hostname() + if host == "" { + return false + } - // Support for origin patterns like "https://*.example.com" - // This allows any subdomain of example.com to be matched by a single rule - if strings.HasPrefix(allowedOrigin, "https://*.") && strings.HasSuffix(origin, allowedOrigin[10:]) { + for _, pattern := range p.allowedPatterns { + if u.Scheme != pattern.scheme { + continue + } + if host != pattern.suffix && strings.HasSuffix(host, "."+pattern.suffix) { return true } } return false } -// contains checks if a string exists in a slice. -// This is a general utility function for string slice membership testing. +func (p corsPolicy) areHeadersAllowed(requestHeaders string) bool { + if requestHeaders == "" { + return true + } + for _, h := range strings.Split(requestHeaders, ",") { + header := strings.ToLower(strings.TrimSpace(h)) + if header == "" { + continue + } + if _, ok := p.allowedHeadersSet[header]; !ok { + return false + } + } + return true +} + +// containsWildcard checks if the slice contains the wildcard "*" value. +// This is a helper function used to determine if wildcard patterns exist +// in configuration settings like AllowedOrigins or AllowedHeaders. // // Parameters: -// - s: The string slice to search in -// - str: The string to search for +// - s: The string slice to check for wildcards // // Returns: -// - true if the string is found in the slice, false otherwise -func contains(s []string, str string) bool { - for _, v := range s { - if v == str { - return true - } - } - return false +// - true if the slice contains "*", false otherwise +func containsWildcard(s []string) bool { + return slices.Contains(s, "*") } diff --git a/cors_test.go b/cors_test.go new file mode 100644 index 0000000..cdab76a --- /dev/null +++ b/cors_test.go @@ -0,0 +1,118 @@ +package intake + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestCORSPreflightHeaderValidation(t *testing.T) { + app := New() + app.AddGlobalMiddleware(CORS(CORSConfig{ + AllowedOrigins: []string{"https://example.com"}, + AllowedMethods: []string{http.MethodGet}, + AllowedHeaders: []string{"X-Token", "Content-Type"}, + })) + app.AddEndpoint(http.MethodGet, "/data", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + app.AddOptionsEndpoints() + + t.Run("rejects disallowed headers", func(t *testing.T) { + req := httptest.NewRequest(http.MethodOptions, "/data", nil) + req.Header.Set("Origin", "https://example.com") + req.Header.Set("Access-Control-Request-Method", http.MethodGet) + req.Header.Set("Access-Control-Request-Headers", "X-Token, X-Other") + rr := httptest.NewRecorder() + + app.Mux.ServeHTTP(rr, req) + + if rr.Code != http.StatusForbidden { + t.Fatalf("expected status %d, got %d", http.StatusForbidden, rr.Code) + } + }) + + t.Run("allows configured headers", func(t *testing.T) { + req := httptest.NewRequest(http.MethodOptions, "/data", nil) + req.Header.Set("Origin", "https://example.com") + req.Header.Set("Access-Control-Request-Method", http.MethodGet) + req.Header.Set("Access-Control-Request-Headers", "x-token") + rr := httptest.NewRecorder() + + app.Mux.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("expected status %d, got %d", http.StatusNoContent, rr.Code) + } + if got := rr.Header().Get("Access-Control-Allow-Headers"); got == "" { + t.Fatalf("expected Access-Control-Allow-Headers, got empty") + } + }) +} + +func TestCORSWildcardHeadersEcho(t *testing.T) { + app := New() + app.AddGlobalMiddleware(CORS(CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{http.MethodGet}, + AllowedHeaders: []string{"*"}, + })) + app.AddEndpoint(http.MethodGet, "/data", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + app.AddOptionsEndpoints() + + req := httptest.NewRequest(http.MethodOptions, "/data", nil) + req.Header.Set("Origin", "https://example.com") + req.Header.Set("Access-Control-Request-Method", http.MethodGet) + req.Header.Set("Access-Control-Request-Headers", "X-Foo, X-Bar") + rr := httptest.NewRecorder() + + app.Mux.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("expected status %d, got %d", http.StatusNoContent, rr.Code) + } + if got := rr.Header().Get("Access-Control-Allow-Headers"); got != "X-Foo, X-Bar" { + t.Fatalf("expected Access-Control-Allow-Headers to echo request, got %q", got) + } + if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "*" { + t.Fatalf("expected Access-Control-Allow-Origin '*', got %q", got) + } +} + +func TestCORSWildcardOriginSchemeMatch(t *testing.T) { + app := New() + app.AddGlobalMiddleware(CORS(CORSConfig{ + AllowedOrigins: []string{"http://*.example.com"}, + AllowedMethods: []string{http.MethodGet}, + })) + app.AddEndpoint(http.MethodGet, "/data", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + app.AddOptionsEndpoints() + + t.Run("allows http subdomain", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/data", nil) + req.Header.Set("Origin", "http://api.example.com") + rr := httptest.NewRecorder() + + app.Mux.ServeHTTP(rr, req) + + if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "http://api.example.com" { + t.Fatalf("expected Access-Control-Allow-Origin to echo origin, got %q", got) + } + }) + + t.Run("rejects https subdomain", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/data", nil) + req.Header.Set("Origin", "https://api.example.com") + rr := httptest.NewRecorder() + + app.Mux.ServeHTTP(rr, req) + + if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "" { + t.Fatalf("expected no Access-Control-Allow-Origin, got %q", got) + } + }) +} diff --git a/delme.md b/delme.md new file mode 100644 index 0000000..e69de29 diff --git a/intake.go b/intake.go index 7b37c33..c3002d1 100644 --- a/intake.go +++ b/intake.go @@ -92,7 +92,8 @@ func (a *Intake) AddEndpoints(e ...Endpoints) { // // Parameters: // - verb: The HTTP method (GET, POST, PUT, DELETE, etc.) -// - path: The URL path to register the handler for +// - path: The URL path to register the handler for. Go 1.22 path parameters +// like /users/{id} are supported. // - finalHandler: The handler function that will process the request // - middleware: Optional route-specific middleware functions func (a *Intake) AddEndpoint(verb string, path string, finalHandler http.HandlerFunc, middleware ...MiddleWare) { @@ -105,34 +106,35 @@ func (a *Intake) AddEndpoint(verb string, path string, finalHandler http.Handler handlerKey := fmt.Sprintf("%s %s", verb, path) - // Apply global middleware first - var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { - // Apply panic recovery here to capture panics in both global and route middleware - if a.PanicHandler != nil { - defer func() { - if err := recover(); err != nil { - a.PanicHandler(w, r, err) - } - }() - } - - // Apply route middleware and the final handler - routeHandler := finalHandler - for i := len(middleware) - 1; i >= 0; i-- { - if middleware[i] != nil { - routeHandler = middleware[i](routeHandler) - } + // Build route-specific chain first. + routeHandler := finalHandler + for i := len(middleware) - 1; i >= 0; i-- { + if middleware[i] != nil { + routeHandler = middleware[i](routeHandler) } - routeHandler(w, r) } // Apply global middleware in reverse order + handler := routeHandler for i := len(a.GlobalMiddleware) - 1; i >= 0; i-- { if a.GlobalMiddleware[i] != nil { handler = a.GlobalMiddleware[i](handler) } } + // Apply panic recovery last so it wraps global and route middleware. + if a.PanicHandler != nil { + inner := handler + handler = func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + a.PanicHandler(w, r, err) + } + }() + inner(w, r) + } + } + a.Mux.HandleFunc(handlerKey, handler) } diff --git a/intake_test.go b/intake_test.go index 372a6e4..90daaee 100644 --- a/intake_test.go +++ b/intake_test.go @@ -1,7 +1,6 @@ package intake import ( - "context" "encoding/json" "io/ioutil" "net/http" @@ -46,6 +45,28 @@ func TestIntake(t *testing.T) { } }) + t.Run("test path params", func(t *testing.T) { + paramApp := New() + var got string + handler := func(w http.ResponseWriter, r *http.Request) { + got = r.PathValue("hello") + w.WriteHeader(http.StatusOK) + } + + paramApp.AddEndpoint(http.MethodGet, "/api/{hello}/world", handler) + + r := httptest.NewRequest(http.MethodGet, "/api/hi/world", nil) + w := httptest.NewRecorder() + paramApp.Mux.ServeHTTP(w, r) + + if w.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) + } + if got != "hi" { + t.Errorf("Expected path value %q, got %q", "hi", got) + } + }) + t.Run("test middleware execution", func(t *testing.T) { middlewareCalled := false middleware := func(next http.HandlerFunc) http.HandlerFunc { @@ -249,18 +270,12 @@ func TestIntake(t *testing.T) { panicApp := New() middlewareCalled := false panicHandlerCalled := false - requestID := "test-request-123" errorMessage := "Middleware panic" - // Set up a panic handler that checks for request context + // Set up a panic handler panicApp.SetPanicHandler(func(w http.ResponseWriter, r *http.Request, err any) { panicHandlerCalled = true - // Verify middleware was executed by checking for request ID - if id := r.Context().Value("requestID"); id != requestID { - t.Errorf("Expected requestID %q in context, got %v", requestID, id) - } - w.WriteHeader(http.StatusInternalServerError) errMsg, ok := err.(string) if !ok { @@ -269,15 +284,11 @@ func TestIntake(t *testing.T) { w.Write([]byte(errMsg)) }) - // Create middleware that adds a request ID to context + // Create middleware that marks execution requestIDMiddleware := func(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { middlewareCalled = true - // Create a new request with context containing request ID - ctx := context.WithValue(r.Context(), "requestID", requestID) - r = r.WithContext(ctx) - next(w, r) } } @@ -329,3 +340,25 @@ func TestIntake(t *testing.T) { } }) } + +func TestIsOriginAllowedWildcardSubdomain(t *testing.T) { + policy := buildPolicy(CORSConfig{ + AllowedOrigins: []string{"https://*.example.com"}, + }) + + cases := []struct { + origin string + want bool + }{ + {origin: "https://api.example.com", want: true}, + {origin: "https://example.com", want: false}, + {origin: "https://badexample.com", want: false}, + {origin: "http://api.example.com", want: false}, + } + + for _, tc := range cases { + if got := policy.isOriginAllowed(tc.origin); got != tc.want { + t.Fatalf("origin %q allowed=%v, want %v", tc.origin, got, tc.want) + } + } +}