44"""
55
66import logging
7+ import os
78from functools import wraps
89from typing import Any , Callable , Dict , Optional , Tuple
910
@@ -208,16 +209,18 @@ def __init__(self, requests_per_minute: int = 100):
208209 self .requests_per_minute = requests_per_minute
209210 self .user_requests = {} # In production, use Redis
210211 self .blocked_users = set () # Temporarily blocked users
211- self .rate_limit_enabled = os .getenv ("RATE_LIMIT_ENABLED" , "true" ).lower () == "true"
212-
212+ self .rate_limit_enabled = (
213+ os .getenv ("RATE_LIMIT_ENABLED" , "true" ).lower () == "true"
214+ )
215+
213216 # Different limits for different operations
214217 self .endpoint_limits = {
215218 "auth" : 20 , # Auth operations
216219 "projects" : 50 , # Project operations
217220 "chat" : 30 , # Chat operations
218- "default" : requests_per_minute
221+ "default" : requests_per_minute ,
219222 }
220-
223+
221224 logger .info (
222225 f"RateLimitMiddleware initialized with { requests_per_minute } requests/minute"
223226 )
@@ -233,70 +236,76 @@ def _get_endpoint_category(self, path: str) -> str:
233236 else :
234237 return "default"
235238
236- async def check_rate_limit (self , user_id : str , endpoint_path : str = "" ) -> Tuple [bool , Dict [str , Any ]]:
239+ async def check_rate_limit (
240+ self , user_id : str , endpoint_path : str = ""
241+ ) -> Tuple [bool , Dict [str , Any ]]:
237242 """Check if user has exceeded rate limit"""
238243 if not self .rate_limit_enabled :
239244 return True , {}
240-
245+
241246 # Check if user is temporarily blocked
242247 if user_id in self .blocked_users :
243248 return False , {
244249 "reason" : "Temporarily blocked due to excessive requests" ,
245- "retry_after" : 300 # 5 minutes
250+ "retry_after" : 300 , # 5 minutes
246251 }
247-
252+
248253 # Get appropriate limit for endpoint
249254 category = self ._get_endpoint_category (endpoint_path )
250255 limit = self .endpoint_limits .get (category , self .endpoint_limits ["default" ])
251-
256+
252257 # Get current time window
253258 import time
259+
254260 current_time = time .time ()
255261 window_start = int (current_time // 60 ) * 60 # Start of current minute
256-
262+
257263 # Initialize user request tracking
258264 if user_id not in self .user_requests :
259265 self .user_requests [user_id ] = {}
260-
266+
261267 # Clean old windows (keep last 2 minutes for analysis)
262268 user_windows = self .user_requests [user_id ]
263269 old_windows = [w for w in user_windows .keys () if w < window_start - 120 ]
264270 for old_window in old_windows :
265271 del user_windows [old_window ]
266-
272+
267273 # Count requests in current window
268274 current_requests = user_windows .get (window_start , 0 )
269-
275+
270276 if current_requests >= limit :
271277 # Check if user should be temporarily blocked
272278 recent_requests = sum (user_windows .values ())
273279 if recent_requests >= limit * 3 : # 3x the limit across windows
274280 self .blocked_users .add (user_id )
275- logger .warning (f"User { user_id } temporarily blocked for excessive requests" )
281+ logger .warning (
282+ f"User { user_id } temporarily blocked for excessive requests"
283+ )
276284 return False , {
277285 "reason" : "Temporarily blocked due to excessive requests" ,
278- "retry_after" : 300
286+ "retry_after" : 300 ,
279287 }
280-
288+
281289 return False , {
282290 "reason" : "Rate limit exceeded" ,
283291 "limit" : limit ,
284292 "current" : current_requests ,
285- "retry_after" : 60
293+ "retry_after" : 60 ,
286294 }
287-
295+
288296 # Record this request
289297 user_windows [window_start ] = current_requests + 1
290-
298+
291299 return True , {
292300 "limit" : limit ,
293301 "current" : current_requests + 1 ,
294- "remaining" : limit - current_requests - 1
302+ "remaining" : limit - current_requests - 1 ,
295303 }
296304
297305 async def apply_rate_limit (
298- self , current_user : Optional [UserInDB ] = Depends (get_current_user_optional ),
299- request : Request = None
306+ self ,
307+ current_user : Optional [UserInDB ] = Depends (get_current_user_optional ),
308+ request : Request = None ,
300309 ) -> bool :
301310 """Apply rate limiting based on user"""
302311 if not current_user :
@@ -306,14 +315,14 @@ async def apply_rate_limit(
306315
307316 endpoint_path = str (request .url .path ) if request else ""
308317 allowed , info = await self .check_rate_limit (str (current_user .id ), endpoint_path )
309-
318+
310319 if not allowed :
311320 raise HTTPException (
312321 status_code = 429 ,
313322 detail = info .get ("reason" , "Rate limit exceeded" ),
314- headers = {"Retry-After" : str (info .get ("retry_after" , 60 ))}
323+ headers = {"Retry-After" : str (info .get ("retry_after" , 60 ))},
315324 )
316-
325+
317326 return True
318327
319328
0 commit comments