Skip to content
566 changes: 566 additions & 0 deletions DOCUMENTATION.md

Large diffs are not rendered by default.

20 changes: 18 additions & 2 deletions admin/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,20 @@ modules:
requestsPerMinute: 120
burstSize: 20

# Strict per-IP rate limiter for login: 10 attempts/minute, burst 10
- name: auth-login-ratelimit
type: http.middleware.ratelimit
config:
requestsPerMinute: 10
burstSize: 10

# Strict per-IP rate limiter for registration: 5 attempts/hour, burst 5
- name: auth-register-ratelimit
type: http.middleware.ratelimit
config:
requestsPerHour: 5
burstSize: 5

# --- Data Layer ---
- name: admin-db
type: storage.sqlite
Expand Down Expand Up @@ -143,14 +157,16 @@ workflows:
middlewares: [admin-cors, admin-ratelimit]

# === Auth (unauthenticated) ===
# login: strict per-IP rate limit (10/minute) to defend against brute-force
- method: POST
path: "/api/v1/auth/login"
handler: admin-auth
middlewares: [admin-cors, admin-ratelimit]
middlewares: [admin-cors, auth-login-ratelimit]
# register: strict per-IP rate limit (5/hour) to defend against account enumeration
- method: POST
path: "/api/v1/auth/register"
handler: admin-auth
middlewares: [admin-cors, admin-ratelimit]
middlewares: [admin-cors, auth-register-ratelimit]
- method: POST
path: "/api/v1/auth/refresh"
handler: admin-auth
Expand Down
63 changes: 50 additions & 13 deletions module/api_gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,21 @@ type AuthConfig struct {

// APIGateway is a composable gateway module that combines routing, auth,
// rate limiting, and proxying into a single module.
//
// Each APIGateway instance maintains its own independent rate limiter state.
// Rate limiters are never shared across instances, so multiple APIGateway
// instances (e.g. in multi-tenant deployments) do not interfere with each other.
type APIGateway struct {
name string
routes []GatewayRoute
cors *CORSConfig
auth *AuthConfig

// internal state
sortedRoutes []GatewayRoute // sorted by prefix length (longest first)
proxies map[string]*httputil.ReverseProxy
rateLimiters map[string]*gatewayRateLimiter // keyed by path prefix
globalLimiter *gatewayRateLimiter
sortedRoutes []GatewayRoute // sorted by prefix length (longest first)
proxies map[string]*httputil.ReverseProxy
rateLimiters map[string]*gatewayRateLimiter // keyed by path prefix
instanceRateLimiter *gatewayRateLimiter // instance-scoped limiter applied before per-route limits
}

// gatewayRateLimiter is a simple per-client token bucket limiter for the gateway.
Expand Down Expand Up @@ -90,13 +94,36 @@ func (rl *gatewayRateLimiter) allow(clientIP string) bool {
return bucket.allow()
}

// NewAPIGateway creates a new APIGateway module.
func NewAPIGateway(name string) *APIGateway {
return &APIGateway{
// APIGatewayOption is a functional option for configuring an APIGateway at construction time.
type APIGatewayOption func(*APIGateway)

// WithRateLimit sets an instance-level rate limit applied to all requests before per-route
// limits are checked. The limiter is scoped to this APIGateway instance and does not affect
// any other instance.
func WithRateLimit(cfg *RateLimitConfig) APIGatewayOption {
return func(g *APIGateway) {
if cfg != nil && cfg.RequestsPerMinute > 0 {
burst := cfg.BurstSize
if burst <= 0 {
burst = cfg.RequestsPerMinute
}
g.instanceRateLimiter = newGatewayRateLimiter(cfg.RequestsPerMinute, burst)
}
}
}

// NewAPIGateway creates a new APIGateway module. Optional functional options can be
// provided to configure the instance at construction time (e.g. WithRateLimit).
func NewAPIGateway(name string, opts ...APIGatewayOption) *APIGateway {
g := &APIGateway{
name: name,
proxies: make(map[string]*httputil.ReverseProxy),
rateLimiters: make(map[string]*gatewayRateLimiter),
}
for _, opt := range opts {
opt(g)
}
return g
}

// SetRoutes configures the gateway routes.
Expand Down Expand Up @@ -146,17 +173,27 @@ func (g *APIGateway) SetRoutes(routes []GatewayRoute) error {
return nil
}

// SetGlobalRateLimit configures a global rate limit applied to all routes.
func (g *APIGateway) SetGlobalRateLimit(cfg *RateLimitConfig) {
// SetRateLimit configures an instance-level rate limit applied to all routes on this gateway.
// The limiter is scoped to this APIGateway instance and does not affect any other instance.
// Prefer injecting rate limit config via WithRateLimit at construction time when possible.
func (g *APIGateway) SetRateLimit(cfg *RateLimitConfig) {
if cfg != nil && cfg.RequestsPerMinute > 0 {
burst := cfg.BurstSize
if burst <= 0 {
burst = cfg.RequestsPerMinute
}
g.globalLimiter = newGatewayRateLimiter(cfg.RequestsPerMinute, burst)
g.instanceRateLimiter = newGatewayRateLimiter(cfg.RequestsPerMinute, burst)
}
}

// SetGlobalRateLimit is deprecated: use SetRateLimit instead.
// The rate limiter has always been instance-scoped; this method was misleadingly named.
//
// Deprecated: Use SetRateLimit.
func (g *APIGateway) SetGlobalRateLimit(cfg *RateLimitConfig) {
g.SetRateLimit(cfg)
}

// SetCORS configures CORS settings.
func (g *APIGateway) SetCORS(cfg *CORSConfig) {
g.cors = cfg
Expand Down Expand Up @@ -214,9 +251,9 @@ func (g *APIGateway) Handle(w http.ResponseWriter, r *http.Request) {

clientIP := extractClientIP(r)

// Global rate limiting
if g.globalLimiter != nil {
if !g.globalLimiter.allow(clientIP) {
// Instance-level rate limiting (applied before per-route limits)
if g.instanceRateLimiter != nil {
if !g.instanceRateLimiter.allow(clientIP) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
_ = json.NewEncoder(w).Encode(map[string]string{
Expand Down
73 changes: 71 additions & 2 deletions module/api_gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,14 @@ func TestAPIGateway_CORS(t *testing.T) {
}
}

func TestAPIGateway_GlobalRateLimit(t *testing.T) {
func TestAPIGateway_InstanceRateLimit_SetRateLimit(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer backend.Close()

gw := NewAPIGateway("gw")
gw.SetGlobalRateLimit(&RateLimitConfig{RequestsPerMinute: 60, BurstSize: 2})
gw.SetRateLimit(&RateLimitConfig{RequestsPerMinute: 60, BurstSize: 2})
_ = gw.SetRoutes([]GatewayRoute{
{PathPrefix: "/api", Backend: backend.URL},
})
Expand All @@ -252,6 +252,75 @@ func TestAPIGateway_GlobalRateLimit(t *testing.T) {
}
}

func TestAPIGateway_InstanceRateLimit_WithRateLimit(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer backend.Close()

gw := NewAPIGateway("gw", WithRateLimit(&RateLimitConfig{RequestsPerMinute: 60, BurstSize: 1}))
_ = gw.SetRoutes([]GatewayRoute{
{PathPrefix: "/api", Backend: backend.URL},
})

// First should succeed (burst=1)
req := httptest.NewRequest("GET", "/api/test", nil)
req.RemoteAddr = "10.0.0.2:1234"
w := httptest.NewRecorder()
gw.Handle(w, req)
if w.Code != http.StatusOK {
t.Errorf("first request expected 200, got %d", w.Code)
}

// Second should be rate limited
req = httptest.NewRequest("GET", "/api/test", nil)
req.RemoteAddr = "10.0.0.2:1234"
w = httptest.NewRecorder()
gw.Handle(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("expected 429, got %d", w.Code)
}
}

func TestAPIGateway_InstanceRateLimiters_AreIsolated(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer backend.Close()

cfg := &RateLimitConfig{RequestsPerMinute: 60, BurstSize: 1}
gw1 := NewAPIGateway("gw1", WithRateLimit(cfg))
gw2 := NewAPIGateway("gw2", WithRateLimit(cfg))
_ = gw1.SetRoutes([]GatewayRoute{{PathPrefix: "/api", Backend: backend.URL}})
_ = gw2.SetRoutes([]GatewayRoute{{PathPrefix: "/api", Backend: backend.URL}})

// Exhaust gw1's burst for this client
req := httptest.NewRequest("GET", "/api/test", nil)
req.RemoteAddr = "10.0.0.3:1234"
w := httptest.NewRecorder()
gw1.Handle(w, req)
if w.Code != http.StatusOK {
t.Errorf("gw1 first request expected 200, got %d", w.Code)
}

req = httptest.NewRequest("GET", "/api/test", nil)
req.RemoteAddr = "10.0.0.3:1234"
w = httptest.NewRecorder()
gw1.Handle(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("gw1 second request expected 429, got %d", w.Code)
}

// gw2 should be unaffected — its burst is independent
req = httptest.NewRequest("GET", "/api/test", nil)
req.RemoteAddr = "10.0.0.3:1234"
w = httptest.NewRecorder()
gw2.Handle(w, req)
if w.Code != http.StatusOK {
t.Errorf("gw2 should be isolated from gw1; expected 200, got %d", w.Code)
}
}

func TestAPIGateway_PerRouteRateLimit(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
Expand Down
39 changes: 32 additions & 7 deletions module/http_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ const (
type RateLimitMiddleware struct {
name string
requestsPerMinute int
ratePerMinute float64 // fractional rate, used when requestsPerHour is set
burstSize int
strategy RateLimitStrategy
tokenHeader string // HTTP header to extract token from
Expand All @@ -44,7 +45,7 @@ type RateLimitMiddleware struct {

// client tracks the rate limiting state for a single client
type client struct {
tokens int
tokens float64
lastTimestamp time.Time
}

Expand All @@ -53,6 +54,7 @@ func NewRateLimitMiddleware(name string, requestsPerMinute, burstSize int) *Rate
return &RateLimitMiddleware{
name: name,
requestsPerMinute: requestsPerMinute,
ratePerMinute: float64(requestsPerMinute),
burstSize: burstSize,
strategy: RateLimitByIP,
tokenHeader: "Authorization",
Expand All @@ -62,6 +64,24 @@ func NewRateLimitMiddleware(name string, requestsPerMinute, burstSize int) *Rate
}
}

// NewRateLimitMiddlewareWithHourlyRate creates a rate limiting middleware using
// a per-hour rate. Useful for low-frequency endpoints like registration where
// fractional per-minute rates are needed.
func NewRateLimitMiddlewareWithHourlyRate(name string, requestsPerHour, burstSize int) *RateLimitMiddleware {
m := &RateLimitMiddleware{
name: name,
requestsPerMinute: 0, // not used when ratePerMinute is set
ratePerMinute: float64(requestsPerHour) / 60.0,
burstSize: burstSize,
strategy: RateLimitByIP,
tokenHeader: "Authorization",
clients: make(map[string]*client),
cleanupInterval: 5 * time.Minute,
stopCleanup: make(chan struct{}),
}
return m
}

// NewRateLimitMiddlewareWithStrategy creates a rate limiting middleware with
// a specific client identification strategy.
func NewRateLimitMiddlewareWithStrategy(name string, requestsPerMinute, burstSize int, strategy RateLimitStrategy) *RateLimitMiddleware {
Expand Down Expand Up @@ -132,20 +152,20 @@ func (m *RateLimitMiddleware) Process(next http.Handler) http.Handler {
m.mu.Lock()
c, exists := m.clients[key]
if !exists {
c = &client{tokens: m.burstSize, lastTimestamp: time.Now()}
c = &client{tokens: float64(m.burstSize), lastTimestamp: time.Now()}
m.clients[key] = c
} else {
// Refill tokens based on elapsed time
// Refill tokens based on elapsed time using fractional rate
elapsed := time.Since(c.lastTimestamp).Minutes()
tokensToAdd := int(elapsed * float64(m.requestsPerMinute))
tokensToAdd := elapsed * m.ratePerMinute
if tokensToAdd > 0 {
c.tokens = min(c.tokens+tokensToAdd, m.burstSize)
c.tokens = min(c.tokens+tokensToAdd, float64(m.burstSize))
c.lastTimestamp = time.Now()
}
}

// Check if request can proceed
if c.tokens <= 0 {
if c.tokens < 1 {
m.mu.Unlock()
w.Header().Set("Retry-After", "60")
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
Expand All @@ -163,7 +183,12 @@ func (m *RateLimitMiddleware) Process(next http.Handler) http.Handler {
// cleanupStaleClients removes client entries that haven't been seen in over
// twice the refill window. This prevents unbounded memory growth.
func (m *RateLimitMiddleware) cleanupStaleClients() {
staleThreshold := 2 * time.Minute * time.Duration(max(1, m.burstSize/max(1, m.requestsPerMinute)))
// Use fractional ratePerMinute to compute refill window correctly
refillWindow := 1.0
if m.ratePerMinute > 0 {
refillWindow = float64(m.burstSize) / m.ratePerMinute
}
staleThreshold := time.Duration(2*refillWindow) * time.Minute
if staleThreshold < 10*time.Minute {
staleThreshold = 10 * time.Minute
}
Expand Down
Loading
Loading