Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,27 @@ const (
requestLimitKey
)

// WithIncrement sets the increment value in the context.
func WithIncrement(ctx context.Context, value int) context.Context {
return context.WithValue(ctx, incrementKey, value)
}

// getIncrement gets the increment value from the context, which was set by
// [WithIncrement].
func getIncrement(ctx context.Context) int {
if value, ok := ctx.Value(incrementKey).(int); ok {
return value
}
return 1
}

// WithRequestLimit sets the request limit in the context.
func WithRequestLimit(ctx context.Context, value int) context.Context {
return context.WithValue(ctx, requestLimitKey, value)
}

// getRequestLimit gets the request limit from the context, which was set by
// [WithRequestLimit].
func getRequestLimit(ctx context.Context) int {
if value, ok := ctx.Value(requestLimitKey).(int); ok {
return value
Expand Down
9 changes: 5 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ go 1.23.0

toolchain go1.24.1

require github.com/zeebo/xxh3 v1.0.2

require golang.org/x/sys v0.30.0 // indirect
require (
github.com/zeebo/xxh3 v1.0.2
golang.org/x/sync v0.12.0
)

require (
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
golang.org/x/sync v0.12.0
golang.org/x/sys v0.30.0 // indirect
)
89 changes: 79 additions & 10 deletions httprate.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,71 @@ import (
"time"
)

// Limit creates a new [net/http] middleware that limits requests by the given
// request limit and window length. The returned middleware will call the next
// handler if the request limit is not exceeded.
func Limit(requestLimit int, windowLength time.Duration, options ...Option) func(next http.Handler) http.Handler {
return NewRateLimiter(requestLimit, windowLength, options...).Handler
}

// KeyFunc is a function that derives a key for the given request.
type KeyFunc func(r *http.Request) (string, error)

// Option is a function that configures the rate limiter.
type Option func(rl *RateLimiter)

// Set custom response headers. If empty, the header is omitted.
// ResponseHeaders defines custom response headers. If empty, the header is omitted.
type ResponseHeaders struct {
Limit string // Default: X-RateLimit-Limit
Remaining string // Default: X-RateLimit-Remaining
Increment string // Default: X-RateLimit-Increment
Reset string // Default: X-RateLimit-Reset
RetryAfter string // Default: Retry-After
}

// Limit is the total number of requests that are permitted before the rate limit
// is exceeded. Default: "X-RateLimit-Limit".
Limit string
// Remaining is the number of requests remaining before the rate limit is
// exceeded. Default: "X-RateLimit-Remaining".
Remaining string
// Increment is the number of requests incremented by the rate limiter. Default:
// "X-RateLimit-Increment".
Increment string
// Reset is the time at which the rate limit will be reset. Default:
// "X-RateLimit-Reset".
Reset string
// RetryAfter is the time in seconds after which the rate limit will be reset.
// Default: "Retry-After".
RetryAfter string
}

// LimitAll is a shortcut for [Limit] which uses a shared default key, resulting in
// a single rate-limiter for all requests.
func LimitAll(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler {
return Limit(requestLimit, windowLength)
}

// LimitByIP is a shortcut for [Limit] with the key function set to [KeyByIP],
// returning a new [net/http] middleware that limits requests by IP address.
func LimitByIP(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler {
return Limit(requestLimit, windowLength, WithKeyFuncs(KeyByIP))
}

// LimitByRealIP is a shortcut for [Limit] with the key function set to [KeyByRealIP],
// returning a new [net/http] middleware that limits requests by real IP address.
func LimitByRealIP(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler {
return Limit(requestLimit, windowLength, WithKeyFuncs(KeyByRealIP))
}

// LimitByEndpoint is a shortcut for [Limit] with the key function set to [KeyByEndpoint],
// returning a new [net/http] middleware that limits requests by endpoint.
func LimitByEndpoint(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler {
return Limit(requestLimit, windowLength, WithKeyFuncs(KeyByEndpoint))
}

// Key returns a key function that always returns the specified key.
func Key(key string) func(r *http.Request) (string, error) {
return func(r *http.Request) (string, error) {
return key, nil
}
}

// KeyByIP uses the canonicalized remote address, [net/http.Request.RemoteAddr],
// to get the IP address.
func KeyByIP(r *http.Request) (string, error) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
Expand All @@ -49,6 +80,9 @@ func KeyByIP(r *http.Request) (string, error) {
return canonicalizeIP(ip), nil
}

// KeyByRealIP uses the "True-Client-IP", "X-Real-IP", and "X-Forwarded-For"
// headers (in that order of precedence) to get the IP address, after canonicalizing.
// If none of the headers are present, the remote address is used.
func KeyByRealIP(r *http.Request) (string, error) {
var ip string

Expand All @@ -73,10 +107,12 @@ func KeyByRealIP(r *http.Request) (string, error) {
return canonicalizeIP(ip), nil
}

// KeyByEndpoint uses the URL path, [net/url.URL.Path] as the key.
func KeyByEndpoint(r *http.Request) (string, error) {
return r.URL.Path, nil
}

// WithKeyFuncs composes multiple key functions into a single key.
func WithKeyFuncs(keyFuncs ...KeyFunc) Option {
return func(rl *RateLimiter) {
if len(keyFuncs) > 0 {
Expand All @@ -85,42 +121,76 @@ func WithKeyFuncs(keyFuncs ...KeyFunc) Option {
}
}

// WithKeyByIP is an option which sets the key function to [KeyByIP].
func WithKeyByIP() Option {
return WithKeyFuncs(KeyByIP)
}

// WithKeyByRealIP is an option which sets the key function to [KeyByRealIP].
func WithKeyByRealIP() Option {
return WithKeyFuncs(KeyByRealIP)
}

// WithKeyByEndpoint is an option which sets the key function to [KeyByEndpoint].
func WithKeyByEndpoint() Option {
return WithKeyFuncs(KeyByEndpoint)
}

// WithLimitHandler is an option which sets the limit handler to the given
// [http.HandlerFunc]. If not set, the default limit handler is used.
func WithLimitHandler(h http.HandlerFunc) Option {
return func(rl *RateLimiter) {
rl.onRateLimited = h
}
}

// WithErrorHandler is an option which sets the error handler to the given
// function. If not set, the default error handler is used.
func WithErrorHandler(h func(http.ResponseWriter, *http.Request, error)) Option {
return func(rl *RateLimiter) {
rl.onError = h
}
}

// WithLimitCounter is an option which sets the limit counter to the given
// [LimitCounter]. If not set, the default [LocalLimitCounter] is used.
func WithLimitCounter(c LimitCounter) Option {
return func(rl *RateLimiter) {
rl.limitCounter = c
}
}

// WithResponseHeaders is an option which sets the response headers to the given
// [ResponseHeaders]. If not set, the default response headers are used.
func WithResponseHeaders(headers ResponseHeaders) Option {
return func(rl *RateLimiter) {
rl.headers = headers
}
}

// WithNoop is an option which does nothing.
func WithNoop() Option {
return func(rl *RateLimiter) {}
}

// Skip is a middleware that allows the rate limiter headers to be applied onto a
// request, without actually including the request in the rate limit. Use this for
// endpoints that can be used for checking the rate limit, without affecting the
// rate limit. NOTE: This MUST be loaded in your middleware stack before the rate
// limiter.
//
// Example:
//
// rl := httprate.Limit(100, time.Minute)
// r.With(rl).Get(...) // Will be rate limited.
// r.With(httprate.Skip, rl).Get(...) // Will not be rate limited, but still sets appropriate headers.
func Skip(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r.WithContext(WithIncrement(r.Context(), 0)))
})
}

// composedKeyFunc composes multiple key functions into a single key.
func composedKeyFunc(keyFuncs ...KeyFunc) KeyFunc {
return func(r *http.Request) (string, error) {
var key strings.Builder
Expand Down Expand Up @@ -151,11 +221,10 @@ func canonicalizeIP(ip string) string {
case ':':
// IPv6
isIPv6 = true
break
}
}
if !isIPv6 {
// Not an IP address at all
// Not an IP address at all.
return ip
}

Expand Down
50 changes: 49 additions & 1 deletion httprate_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package httprate

import "testing"
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)

func Test_canonicalizeIP(t *testing.T) {
tests := []struct {
Expand Down Expand Up @@ -57,3 +62,46 @@ func Test_canonicalizeIP(t *testing.T) {
})
}
}

func TestSkip(t *testing.T) {
window := time.Minute
limit := 3

inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

t.Run("without skip exhausts limit", func(t *testing.T) {
limited := LimitAll(limit, window)(inner)
want := []int{http.StatusOK, http.StatusOK, http.StatusOK, http.StatusTooManyRequests}
for i, wantCode := range want {
rec := httptest.NewRecorder()
limited.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
if got := rec.Code; got != wantCode {
t.Fatalf("request %d: status = %d, want %d", i, got, wantCode)
}
}
})

t.Run("with skip does not count toward limit", func(t *testing.T) {
limited := LimitAll(limit, window)(inner)
skipped := Skip(limited)
n := limit + 10
for i := range n {
rec := httptest.NewRecorder()
skipped.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
if got := rec.Code; got != http.StatusOK {
t.Fatalf("request %d: status = %d, want %d", i, got, http.StatusOK)
}
}
})

t.Run("with skip still sets rate limit headers", func(t *testing.T) {
limited := LimitAll(limit, window)(inner)
rec := httptest.NewRecorder()
Skip(limited).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
if got := rec.Header().Get("X-RateLimit-Limit"); got != "3" {
t.Errorf("X-RateLimit-Limit = %q, want %q", got, "3")
}
})
}
5 changes: 3 additions & 2 deletions limit_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ import (
"github.com/zeebo/xxh3"
)

// LimitCounterKey computes a hash key for the given key and window.
func LimitCounterKey(key string, window time.Time) uint64 {
h := xxh3.New()
h.WriteString(key)
h.WriteString(strconv.FormatInt(window.Unix(), 10))
_, _ = h.WriteString(key)
_, _ = h.WriteString(strconv.FormatInt(window.Unix(), 10))
return h.Sum64()
}
10 changes: 10 additions & 0 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,19 @@ import (
"time"
)

// LimitCounter is an interface that defines the methods for a rate limit counter.
// It is used to store and retrieve the rate limit counter for a given key and
// window. A default implementation is provided by [NewLocalLimitCounter].
type LimitCounter interface {
Config(requestLimit int, windowLength time.Duration)
Increment(key string, currentWindow time.Time) error
IncrementBy(key string, currentWindow time.Time, amount int) error
Get(key string, currentWindow, previousWindow time.Time) (int, int, error)
}

// NewRateLimiter creates a new [RateLimiter] with the given request limit and
// window length. The returned rate limiter will use the default [LocalLimitCounter]
// implementation, if not overridden.
func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *RateLimiter {
rl := &RateLimiter{
requestLimit: requestLimit,
Expand Down Expand Up @@ -125,10 +131,12 @@ func (l *RateLimiter) RespondOnLimit(w http.ResponseWriter, r *http.Request, key
return onLimit
}

// Counter returns the limit counter used by the rate limiter.
func (l *RateLimiter) Counter() LimitCounter {
return l.limitCounter
}

// Status returns the current status of the rate limiter for the given key.
func (l *RateLimiter) Status(key string) (bool, float64, error) {
return l.calculateRate(key, l.requestLimit)
}
Expand All @@ -149,6 +157,8 @@ func (l *RateLimiter) Handler(next http.Handler) http.Handler {
})
}

// calculateRate calculates the rate for the given key and request limit. It does
// not increment the counter.
func (l *RateLimiter) calculateRate(key string, requestLimit int) (bool, float64, error) {
now := time.Now().UTC()
currentWindow := now.Truncate(l.windowLength)
Expand Down
Loading
Loading