11package httprate
22
33import (
4+ "context"
45 "fmt"
56 "math"
67 "net/http"
@@ -15,6 +16,11 @@ type LimitCounter interface {
1516 Get (key string , previousWindow , currentWindow time.Time ) (int , int , error )
1617}
1718
19+ type ContextLimitCounter interface {
20+ Increment (ctx context.Context , key string , currentWindow time.Time ) error
21+ Get (ctx context.Context , key string , previousWindow , currentWindow time.Time ) (int , int , error )
22+ }
23+
1824func NewRateLimiter (requestLimit int , windowLength time.Duration , options ... Option ) * rateLimiter {
1925 return newRateLimiter (requestLimit , windowLength , options ... )
2026}
@@ -58,24 +64,42 @@ func LimitCounterKey(key string, window time.Time) uint64 {
5864 return h .Sum64 ()
5965}
6066
67+ // limitCounterWrap implements the LimitCounter interface without context.
68+ // Calls ContextLimitCounter with context.Background(), exists to maintain compatibility.
69+ type limitCounterWrap struct {
70+ ContextLimitCounter
71+ }
72+
73+ func (l * limitCounterWrap ) Increment (key string , currentWindow time.Time ) error {
74+ return l .ContextLimitCounter .Increment (context .Background (), key , currentWindow )
75+ }
76+
77+ func (l * limitCounterWrap ) Get (key string , previousWindow , currentWindow time.Time ) (int , int , error ) {
78+ return l .ContextLimitCounter .Get (context .Background (), key , previousWindow , currentWindow )
79+ }
80+
6181type rateLimiter struct {
6282 requestLimit int
6383 windowLength time.Duration
6484 keyFn KeyFunc
65- limitCounter LimitCounter
85+ limitCounter ContextLimitCounter
6686 onRequestLimit http.HandlerFunc
6787}
6888
6989func (r * rateLimiter ) Counter () LimitCounter {
90+ return & limitCounterWrap {ContextLimitCounter : r .limitCounter }
91+ }
92+
93+ func (r * rateLimiter ) ContextCounter () ContextLimitCounter {
7094 return r .limitCounter
7195}
7296
73- func (r * rateLimiter ) Status ( key string ) (bool , float64 , error ) {
97+ func (r * rateLimiter ) ContextStatus ( ctx context. Context , key string ) (bool , float64 , error ) {
7498 t := time .Now ().UTC ()
7599 currentWindow := t .Truncate (r .windowLength )
76100 previousWindow := currentWindow .Add (- r .windowLength )
77101
78- currCount , prevCount , err := r .limitCounter .Get (key , currentWindow , previousWindow )
102+ currCount , prevCount , err := r .limitCounter .Get (ctx , key , currentWindow , previousWindow )
79103 if err != nil {
80104 return false , 0 , err
81105 }
@@ -89,8 +113,14 @@ func (r *rateLimiter) Status(key string) (bool, float64, error) {
89113 return true , rate , nil
90114}
91115
116+ func (r * rateLimiter ) Status (key string ) (bool , float64 , error ) {
117+ return r .ContextStatus (context .Background (), key )
118+ }
119+
92120func (l * rateLimiter ) Handler (next http.Handler ) http.Handler {
93121 return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
122+ ctx := r .Context ()
123+
94124 key , err := l .keyFn (r )
95125 if err != nil {
96126 http .Error (w , err .Error (), http .StatusPreconditionRequired )
@@ -120,7 +150,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
120150 return
121151 }
122152
123- err = l .limitCounter .Increment (key , currentWindow )
153+ err = l .limitCounter .Increment (ctx , key , currentWindow )
124154 if err != nil {
125155 http .Error (w , err .Error (), http .StatusInternalServerError )
126156 return
@@ -137,14 +167,17 @@ type localCounter struct {
137167 mu sync.Mutex
138168}
139169
140- var _ LimitCounter = & localCounter {}
170+ var (
171+ _ LimitCounter = & limitCounterWrap {ContextLimitCounter : & localCounter {}}
172+ _ ContextLimitCounter = & localCounter {}
173+ )
141174
142175type count struct {
143176 value int
144177 updatedAt time.Time
145178}
146179
147- func (c * localCounter ) Increment (key string , currentWindow time.Time ) error {
180+ func (c * localCounter ) Increment (_ context. Context , key string , currentWindow time.Time ) error {
148181 c .evict ()
149182
150183 c .mu .Lock ()
@@ -163,7 +196,7 @@ func (c *localCounter) Increment(key string, currentWindow time.Time) error {
163196 return nil
164197}
165198
166- func (c * localCounter ) Get (key string , currentWindow , previousWindow time.Time ) (int , int , error ) {
199+ func (c * localCounter ) Get (_ context. Context , key string , currentWindow , previousWindow time.Time ) (int , int , error ) {
167200 c .mu .Lock ()
168201 defer c .mu .Unlock ()
169202
0 commit comments