From 9b4e20c294ef6b9a4cea52a015188c020ffc7958 Mon Sep 17 00:00:00 2001 From: Eshaan Agrawal Date: Sun, 31 May 2026 23:10:53 +0530 Subject: [PATCH] fix(ratelimit): prune expired endpoint identities --- backend/secuscan/ratelimit.py | 21 +++++++ .../unit/test_endpoint_rate_limiter.py | 60 +++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/backend/secuscan/ratelimit.py b/backend/secuscan/ratelimit.py index ed231d61..ed3df69f 100644 --- a/backend/secuscan/ratelimit.py +++ b/backend/secuscan/ratelimit.py @@ -167,14 +167,34 @@ def __init__(self, bucket_name: str, limit: int, window_seconds: int): self.limit = limit self.window_seconds = window_seconds self.history: Dict[str, List[datetime]] = defaultdict(list) + self.last_cleanup: datetime | None = None self.lock = asyncio.Lock() + def _cleanup_expired_identities(self, cutoff: datetime, now: datetime): + cleanup_interval = timedelta(seconds=max(1, self.window_seconds)) + if self.last_cleanup and now - self.last_cleanup < cleanup_interval: + return + + expired_identities = [] + for identity, timestamps in self.history.items(): + active_timestamps = [ts for ts in timestamps if ts > cutoff] + if active_timestamps: + self.history[identity] = active_timestamps + else: + expired_identities.append(identity) + + for identity in expired_identities: + self.history.pop(identity, None) + + self.last_cleanup = now + async def __call__(self, request: Request, response: Response): identity = resolve_client_identity(request) async with self.lock: now = datetime.now() cutoff = now - timedelta(seconds=self.window_seconds) + self._cleanup_expired_identities(cutoff, now) # Filter history to keep only timestamps within the sliding window self.history[identity] = [ts for ts in self.history[identity] if ts > cutoff] @@ -214,6 +234,7 @@ async def reset(self): """Clear all rate limiting history for this bucket.""" async with self.lock: self.history.clear() + self.last_cleanup = None # Global instances diff --git a/testing/backend/unit/test_endpoint_rate_limiter.py b/testing/backend/unit/test_endpoint_rate_limiter.py index afc2cfa9..f45aa222 100644 --- a/testing/backend/unit/test_endpoint_rate_limiter.py +++ b/testing/backend/unit/test_endpoint_rate_limiter.py @@ -100,6 +100,66 @@ def __init__(self): assert res.headers["X-RateLimit-Remaining"] == "1" +@pytest.mark.asyncio +async def test_endpoint_rate_limiter_prunes_expired_identity_buckets(): + """Expired identities should not stay resident forever.""" + limiter = EndpointRateLimiter("test_bucket", limit=5, window_seconds=10) + await limiter.reset() + + class MockRequest: + def __init__(self, user_id): + self.client = type("Client", (), {"host": "127.0.0.1"})() + self.headers = {"x-user-id": user_id} + self.state = type("State", (), {})() + + class MockResponse: + def __init__(self): + self.headers = {} + + now = datetime.now() + async with limiter.lock: + limiter.history["user:expired_a"] = [now - timedelta(seconds=30)] + limiter.history["user:expired_b"] = [now - timedelta(seconds=20)] + limiter.history["user:active"] = [now - timedelta(seconds=2)] + limiter.last_cleanup = now - timedelta(seconds=11) + + await limiter(MockRequest("current"), MockResponse()) + + async with limiter.lock: + assert "user:expired_a" not in limiter.history + assert "user:expired_b" not in limiter.history + assert "user:active" in limiter.history + assert "user:current" in limiter.history + + +@pytest.mark.asyncio +async def test_endpoint_rate_limiter_cleanup_is_interval_bounded(): + """Cleanup should not scan every request inside the cleanup interval.""" + limiter = EndpointRateLimiter("test_bucket", limit=5, window_seconds=10) + await limiter.reset() + + class MockRequest: + def __init__(self, user_id): + self.client = type("Client", (), {"host": "127.0.0.1"})() + self.headers = {"x-user-id": user_id} + self.state = type("State", (), {})() + + class MockResponse: + def __init__(self): + self.headers = {} + + now = datetime.now() + async with limiter.lock: + limiter.history["user:expired"] = [now - timedelta(seconds=30)] + limiter.last_cleanup = now + + await limiter(MockRequest("current"), MockResponse()) + + async with limiter.lock: + assert "user:expired" in limiter.history + assert "user:current" in limiter.history + + def test_priority_client_identity_resolution(): """ Verify client identity resolves correctly in priority order: