Skip to content

Commit b3fedf3

Browse files
Copilotintel352
andauthored
fix: rate limit factory validation, dynamic Retry-After, fractional refill test, schema accuracy (#107)
* Initial plan * fix: validate rate limit config values, dynamic Retry-After, fractional refill test, fix schema description Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: intel352 <77607+intel352@users.noreply.github.com>
1 parent 3289baf commit b3fedf3

5 files changed

Lines changed: 116 additions & 10 deletions

File tree

module/http_middleware.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ package module
33
import (
44
"context"
55
"fmt"
6+
"math"
67
"net"
78
"net/http"
9+
"strconv"
810
"strings"
911
"sync"
1012
"time"
@@ -167,7 +169,14 @@ func (m *RateLimitMiddleware) Process(next http.Handler) http.Handler {
167169
// Check if request can proceed
168170
if c.tokens < 1 {
169171
m.mu.Unlock()
170-
w.Header().Set("Retry-After", "60")
172+
// Compute how many seconds until 1 token refills, based on the
173+
// fractional per-minute rate (ratePerMinute tokens/minute).
174+
retryAfter := "60"
175+
if m.ratePerMinute > 0 {
176+
secondsUntilToken := 60.0 / m.ratePerMinute
177+
retryAfter = strconv.Itoa(int(math.Ceil(secondsUntilToken)))
178+
}
179+
w.Header().Set("Retry-After", retryAfter)
171180
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
172181
return
173182
}

module/http_middleware_test.go

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,8 +460,8 @@ func TestRateLimitMiddleware_RetryAfterHeader(t *testing.T) {
460460
if rec2.Code != http.StatusTooManyRequests {
461461
t.Errorf("expected 429, got %d", rec2.Code)
462462
}
463-
if rec2.Header().Get("Retry-After") != "60" {
464-
t.Errorf("expected Retry-After header '60', got %q", rec2.Header().Get("Retry-After"))
463+
if rec2.Header().Get("Retry-After") != "1" {
464+
t.Errorf("expected Retry-After header '1', got %q", rec2.Header().Get("Retry-After"))
465465
}
466466
}
467467

@@ -640,3 +640,53 @@ func TestNewRateLimitMiddlewareWithHourlyRate_RatePerMinute(t *testing.T) {
640640
t.Errorf("expected ratePerMinute=1.0, got %f", m.ratePerMinute)
641641
}
642642
}
643+
644+
func TestNewRateLimitMiddlewareWithHourlyRate_FractionalRefill(t *testing.T) {
645+
// 3600 requests/hour -> ratePerMinute = 60.0, timePerToken = 1 second.
646+
// Using a high hourly rate keeps the sleep short while still exercising
647+
// the fractional refill path.
648+
m := NewRateLimitMiddlewareWithHourlyRate("rl-hour-fractional", 3600, 1)
649+
650+
handler := m.Process(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
651+
w.WriteHeader(http.StatusOK)
652+
}))
653+
654+
// First request should be allowed (uses the single burst token).
655+
req1 := httptest.NewRequest("GET", "/fractional", nil)
656+
rec1 := httptest.NewRecorder()
657+
handler.ServeHTTP(rec1, req1)
658+
if rec1.Code != http.StatusOK {
659+
t.Fatalf("first request: expected 200, got %d", rec1.Code)
660+
}
661+
662+
// Second immediate request must be rate-limited (burst exhausted, no refill yet).
663+
req2 := httptest.NewRequest("GET", "/fractional", nil)
664+
rec2 := httptest.NewRecorder()
665+
handler.ServeHTTP(rec2, req2)
666+
if rec2.Code != http.StatusTooManyRequests {
667+
t.Fatalf("second request: expected 429, got %d", rec2.Code)
668+
}
669+
670+
// Wait slightly longer than the time needed to refill one token.
671+
if m.ratePerMinute <= 0 {
672+
t.Fatalf("ratePerMinute must be positive, got %f", m.ratePerMinute)
673+
}
674+
timePerToken := time.Duration(float64(time.Minute) / m.ratePerMinute)
675+
time.Sleep(timePerToken + 100*time.Millisecond)
676+
677+
// After waiting, exactly one additional request should be allowed.
678+
req3 := httptest.NewRequest("GET", "/fractional", nil)
679+
rec3 := httptest.NewRecorder()
680+
handler.ServeHTTP(rec3, req3)
681+
if rec3.Code != http.StatusOK {
682+
t.Fatalf("third request after refill: expected 200, got %d", rec3.Code)
683+
}
684+
685+
// An immediately following request must still be rate-limited.
686+
req4 := httptest.NewRequest("GET", "/fractional", nil)
687+
rec4 := httptest.NewRecorder()
688+
handler.ServeHTTP(rec4, req4)
689+
if rec4.Code != http.StatusTooManyRequests {
690+
t.Fatalf("fourth request after refill: expected 429, got %d", rec4.Code)
691+
}
692+
}

plugins/http/modules.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,24 +135,36 @@ func loggingMiddlewareFactory(name string, cfg map[string]any) modular.Module {
135135
func rateLimitMiddlewareFactory(name string, cfg map[string]any) modular.Module {
136136
burstSize := 10
137137
if bs, ok := cfg["burstSize"].(int); ok {
138-
burstSize = bs
138+
if bs > 0 {
139+
burstSize = bs
140+
}
139141
} else if bs, ok := cfg["burstSize"].(float64); ok {
140-
burstSize = int(bs)
142+
if intBS := int(bs); intBS > 0 {
143+
burstSize = intBS
144+
}
141145
}
142146

143147
// requestsPerHour takes precedence over requestsPerMinute for low-frequency
144148
// endpoints (e.g. registration) where fractional per-minute rates are needed.
145149
if rph, ok := cfg["requestsPerHour"].(int); ok {
146-
return module.NewRateLimitMiddlewareWithHourlyRate(name, rph, burstSize)
150+
if rph > 0 {
151+
return module.NewRateLimitMiddlewareWithHourlyRate(name, rph, burstSize)
152+
}
147153
} else if rph, ok := cfg["requestsPerHour"].(float64); ok {
148-
return module.NewRateLimitMiddlewareWithHourlyRate(name, int(rph), burstSize)
154+
if intRPH := int(rph); intRPH > 0 {
155+
return module.NewRateLimitMiddlewareWithHourlyRate(name, intRPH, burstSize)
156+
}
149157
}
150158

151159
requestsPerMinute := 60
152160
if rpm, ok := cfg["requestsPerMinute"].(int); ok {
153-
requestsPerMinute = rpm
161+
if rpm > 0 {
162+
requestsPerMinute = rpm
163+
}
154164
} else if rpm, ok := cfg["requestsPerMinute"].(float64); ok {
155-
requestsPerMinute = int(rpm)
165+
if intRPM := int(rpm); intRPM > 0 {
166+
requestsPerMinute = intRPM
167+
}
156168
}
157169
return module.NewRateLimitMiddleware(name, requestsPerMinute, burstSize)
158170
}

plugins/http/plugin_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,41 @@ func TestRateLimitMiddlewareFactory_RequestsPerHour(t *testing.T) {
356356
}
357357
}
358358

359+
func TestRateLimitMiddlewareFactory_InvalidValues(t *testing.T) {
360+
factories := moduleFactories()
361+
factory, ok := factories["http.middleware.ratelimit"]
362+
if !ok {
363+
t.Fatal("no factory for http.middleware.ratelimit")
364+
}
365+
366+
// Zero requestsPerHour must fall through to requestsPerMinute path (not crash).
367+
modZeroRPH := factory("rl-zero-rph", map[string]any{
368+
"requestsPerHour": 0,
369+
"requestsPerMinute": 30,
370+
"burstSize": 5,
371+
})
372+
if modZeroRPH == nil {
373+
t.Fatal("factory returned nil for zero requestsPerHour config")
374+
}
375+
376+
// Negative requestsPerMinute must use default (60).
377+
modNegRPM := factory("rl-neg-rpm", map[string]any{
378+
"requestsPerMinute": -10,
379+
})
380+
if modNegRPM == nil {
381+
t.Fatal("factory returned nil for negative requestsPerMinute config")
382+
}
383+
384+
// Zero burstSize must keep default (10).
385+
modZeroBurst := factory("rl-zero-burst", map[string]any{
386+
"requestsPerMinute": 60,
387+
"burstSize": 0,
388+
})
389+
if modZeroBurst == nil {
390+
t.Fatal("factory returned nil for zero burstSize config")
391+
}
392+
}
393+
359394
func TestPluginLoaderIntegration(t *testing.T) {
360395
p := New()
361396

plugins/http/schemas.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ func rateLimitMiddlewareSchema() *schema.ModuleSchema {
161161
ConfigFields: []schema.ConfigFieldDef{
162162
{Key: "requestsPerMinute", Label: "Requests Per Minute", Type: schema.FieldTypeNumber, DefaultValue: 60, Description: "Maximum number of requests per minute per client (mutually exclusive with requestsPerHour)"},
163163
{Key: "requestsPerHour", Label: "Requests Per Hour", Type: schema.FieldTypeNumber, DefaultValue: 0, Description: "Maximum number of requests per hour per client; takes precedence over requestsPerMinute when set"},
164-
{Key: "burstSize", Label: "Burst Size", Type: schema.FieldTypeNumber, DefaultValue: 10, Description: "Maximum burst of requests allowed above the rate limit"},
164+
{Key: "burstSize", Label: "Burst Size", Type: schema.FieldTypeNumber, DefaultValue: 10, Description: "Maximum number of tokens in the bucket; determines how many requests can burst when the bucket is full"},
165165
},
166166
DefaultConfig: map[string]any{"requestsPerMinute": 60, "burstSize": 10},
167167
}

0 commit comments

Comments
 (0)