-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmiddleware.go
More file actions
148 lines (124 loc) · 3.97 KB
/
middleware.go
File metadata and controls
148 lines (124 loc) · 3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package capacitor
import (
"fmt"
"log/slog"
"net"
"net/http"
"strconv"
)
// KeyFunc extracts the rate-limit key from an incoming request.
// The handler built by NewMiddleware treats an empty key as "do not
// rate-limit": the request is forwarded without consulting any limiter.
type KeyFunc func(r *http.Request) string
// ClassifyFunc determines the rate-limit profile name for a request.
// An empty return value, or a name that has no entry in the configured
// profiles, uses the default limiter.
type ClassifyFunc func(r *http.Request) string
// ProfileConfig maps profile names to Capacitor instances.
// Use it with [WithProfiles] and [WithClassifier] to route
// requests to different rate-limiting policies. Names returned by the
// classifier that are missing from the map fall back to the default
// limiter, so a nil or empty ProfileConfig is valid.
type ProfileConfig map[string]Capacitor
// MiddlewareOption configures the HTTP middleware.
// NewMiddleware applies options in the order they are passed, after the
// defaults have been set, so a later option overrides an earlier one
// that writes the same setting.
type MiddlewareOption func(*middleware)
// middleware holds the settings shared by every request served through
// a handler built by NewMiddleware. It is populated once, by
// NewMiddleware and the MiddlewareOption values passed to it.
type middleware struct {
// limiter is the default Capacitor, used when no profile applies.
limiter Capacitor
// keyFunc derives the rate-limit key; an empty key skips limiting.
keyFunc KeyFunc
// denyHandler writes the response for requests that are not allowed.
denyHandler http.Handler
// profiles maps classifier names to alternative limiters; may be nil.
profiles ProfileConfig
// classifier picks a profile name per request; nil disables profiles.
classifier ClassifyFunc
// logger receives a warning when a limiter's Attempt returns an error.
logger *slog.Logger
}
// WithKeyFunc overrides how the middleware derives the rate-limit key
// for each request. When this option is not supplied, NewMiddleware
// uses KeyFromRemoteIP.
func WithKeyFunc(fn KeyFunc) MiddlewareOption {
	apply := func(mw *middleware) {
		mw.keyFunc = fn
	}
	return apply
}
// WithDenyHandler installs h as the handler that answers requests the
// limiter rejects, in place of the built-in 429 response.
func WithDenyHandler(h http.Handler) MiddlewareOption {
	apply := func(mw *middleware) {
		mw.denyHandler = h
	}
	return apply
}
// WithProfiles registers the named limiters a classifier can select.
// It only has an effect together with [WithClassifier]; a request whose
// profile name is empty or not present in profiles is handled by the
// default limiter. The map is stored as given, not copied.
func WithProfiles(profiles ProfileConfig) MiddlewareOption {
	apply := func(mw *middleware) {
		mw.profiles = profiles
	}
	return apply
}
// WithClassifier installs the function that names the rate-limit
// profile for each request. The name is looked up in the map given to
// [WithProfiles].
func WithClassifier(fn ClassifyFunc) MiddlewareOption {
	apply := func(mw *middleware) {
		mw.classifier = fn
	}
	return apply
}
// resolve picks the limiter that should handle r. The profile limiter
// is returned only when a classifier is configured, it yields a
// non-empty name, and that name has an entry in m.profiles; in every
// other case the default limiter is used.
func (m *middleware) resolve(r *http.Request) Capacitor {
	selected := m.limiter
	if m.classifier != nil {
		if profile := m.classifier(r); profile != "" {
			if lim, found := m.profiles[profile]; found {
				selected = lim
			}
		}
	}
	return selected
}
// KeyFromRemoteIP uses the host part of r.RemoteAddr as the key,
// dropping the port. If RemoteAddr cannot be split into host and port
// it is returned unchanged.
func KeyFromRemoteIP(r *http.Request) string {
	addr := r.RemoteAddr
	if ip, _, splitErr := net.SplitHostPort(addr); splitErr == nil {
		return ip
	}
	return addr
}
// KeyFromHeader builds a KeyFunc that keys requests by the value of the
// request header called name, as returned by http.Header.Get. A request
// without that header yields an empty key.
func KeyFromHeader(name string) KeyFunc {
	var fromHeader KeyFunc = func(req *http.Request) string {
		headers := req.Header
		return headers.Get(name)
	}
	return fromHeader
}
// NewMiddleware returns standard net/http middleware that rate-limits
// requests using the provided Capacitor.
//
// Defaults: keys come from KeyFromRemoteIP, denied requests are answered
// by defaultDeny, and warnings go to the slog.Default() logger captured
// at construction time. Options are applied in order. Nil options are
// skipped, and a nil key function or deny handler left behind by an
// option is replaced with its default, so a misconfigured option cannot
// cause a nil dereference while a request is being served.
//
// For each request the returned handler:
//   - derives the key; an empty key bypasses rate limiting entirely;
//   - selects the limiter through the classifier and profiles, falling
//     back to limiter;
//   - calls Attempt, logs a warning if it returns an error, writes the
//     rate-limit headers, and then either invokes the deny handler or
//     forwards the request to next.
func NewMiddleware(limiter Capacitor, opts ...MiddlewareOption) func(http.Handler) http.Handler {
	m := &middleware{
		limiter:     limiter,
		keyFunc:     KeyFromRemoteIP,
		denyHandler: http.HandlerFunc(defaultDeny),
		logger:      slog.Default(),
	}
	for _, opt := range opts {
		if opt == nil {
			continue
		}
		opt(m)
	}
	// Options such as WithKeyFunc(nil) would otherwise panic on the
	// first request; fall back to the defaults instead.
	if m.keyFunc == nil {
		m.keyFunc = KeyFromRemoteIP
	}
	if m.denyHandler == nil {
		m.denyHandler = http.HandlerFunc(defaultDeny)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			key := m.keyFunc(r)
			if key == "" {
				next.ServeHTTP(w, r)
				return
			}

			lim := m.resolve(r)
			result, err := lim.Attempt(r.Context(), key)
			if err != nil {
				// The result is still honored after an error.
				// NOTE(review): this presumes Attempt returns a usable
				// fallback Result alongside the error — confirm against
				// the Capacitor contract.
				m.logger.Warn("rate limiter degraded, using fallback",
					"error", err, "key", key)
			}

			result.writeHeaders(w)
			if !result.Allowed {
				m.denyHandler.ServeHTTP(w, r)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// writeHeaders sets IETF RateLimit-* headers on the response.
//
// RateLimit-Limit and RateLimit-Remaining are always written. When
// RetryAfter is positive, RateLimit-Reset and Retry-After are both set
// to the wait expressed in whole seconds, rounded up. Rounding up
// matters: truncation would report a positive sub-second wait as "0",
// telling the client it may retry immediately.
func (r Result) writeHeaders(w http.ResponseWriter) {
	h := w.Header()
	h.Set("RateLimit-Limit", strconv.FormatInt(r.Limit, 10))
	h.Set("RateLimit-Remaining", strconv.FormatInt(r.Remaining, 10))
	if r.RetryAfter <= 0 {
		return
	}

	// Ceiling of the wait in seconds, computed without extra imports.
	wait := r.RetryAfter.Seconds()
	whole := int64(wait)
	if float64(whole) < wait {
		whole++
	}

	secs := strconv.FormatInt(whole, 10)
	h.Set("RateLimit-Reset", secs)
	h.Set("Retry-After", secs)
}
// defaultDeny is the built-in handler for rejected requests: it replies
// with status 429 and the standard status text as a plain-text body.
func defaultDeny(w http.ResponseWriter, _ *http.Request) {
	const code = http.StatusTooManyRequests
	body := http.StatusText(code)

	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
	w.WriteHeader(code)
	_, _ = fmt.Fprintln(w, body) //nolint:errcheck // best-effort write to http.ResponseWriter; error unactionable after WriteHeader
}