Skip to content

Commit c7b5909

Browse files
committed
Fixed missing import
1 parent 40f2c72 commit c7b5909

1 file changed

Lines changed: 34 additions & 25 deletions

File tree

backend/middleware/auth_middleware.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import logging
7+
import os
78
from functools import wraps
89
from typing import Any, Callable, Dict, Optional, Tuple
910

@@ -208,16 +209,18 @@ def __init__(self, requests_per_minute: int = 100):
208209
self.requests_per_minute = requests_per_minute
209210
self.user_requests = {} # In production, use Redis
210211
self.blocked_users = set() # Temporarily blocked users
211-
self.rate_limit_enabled = os.getenv("RATE_LIMIT_ENABLED", "true").lower() == "true"
212-
212+
self.rate_limit_enabled = (
213+
os.getenv("RATE_LIMIT_ENABLED", "true").lower() == "true"
214+
)
215+
213216
# Different limits for different operations
214217
self.endpoint_limits = {
215218
"auth": 20, # Auth operations
216219
"projects": 50, # Project operations
217220
"chat": 30, # Chat operations
218-
"default": requests_per_minute
221+
"default": requests_per_minute,
219222
}
220-
223+
221224
logger.info(
222225
f"RateLimitMiddleware initialized with {requests_per_minute} requests/minute"
223226
)
@@ -233,70 +236,76 @@ def _get_endpoint_category(self, path: str) -> str:
233236
else:
234237
return "default"
235238

236-
async def check_rate_limit(self, user_id: str, endpoint_path: str = "") -> Tuple[bool, Dict[str, Any]]:
239+
async def check_rate_limit(
240+
self, user_id: str, endpoint_path: str = ""
241+
) -> Tuple[bool, Dict[str, Any]]:
237242
"""Check if user has exceeded rate limit"""
238243
if not self.rate_limit_enabled:
239244
return True, {}
240-
245+
241246
# Check if user is temporarily blocked
242247
if user_id in self.blocked_users:
243248
return False, {
244249
"reason": "Temporarily blocked due to excessive requests",
245-
"retry_after": 300 # 5 minutes
250+
"retry_after": 300, # 5 minutes
246251
}
247-
252+
248253
# Get appropriate limit for endpoint
249254
category = self._get_endpoint_category(endpoint_path)
250255
limit = self.endpoint_limits.get(category, self.endpoint_limits["default"])
251-
256+
252257
# Get current time window
253258
import time
259+
254260
current_time = time.time()
255261
window_start = int(current_time // 60) * 60 # Start of current minute
256-
262+
257263
# Initialize user request tracking
258264
if user_id not in self.user_requests:
259265
self.user_requests[user_id] = {}
260-
266+
261267
# Clean old windows (keep last 2 minutes for analysis)
262268
user_windows = self.user_requests[user_id]
263269
old_windows = [w for w in user_windows.keys() if w < window_start - 120]
264270
for old_window in old_windows:
265271
del user_windows[old_window]
266-
272+
267273
# Count requests in current window
268274
current_requests = user_windows.get(window_start, 0)
269-
275+
270276
if current_requests >= limit:
271277
# Check if user should be temporarily blocked
272278
recent_requests = sum(user_windows.values())
273279
if recent_requests >= limit * 3: # 3x the limit across windows
274280
self.blocked_users.add(user_id)
275-
logger.warning(f"User {user_id} temporarily blocked for excessive requests")
281+
logger.warning(
282+
f"User {user_id} temporarily blocked for excessive requests"
283+
)
276284
return False, {
277285
"reason": "Temporarily blocked due to excessive requests",
278-
"retry_after": 300
286+
"retry_after": 300,
279287
}
280-
288+
281289
return False, {
282290
"reason": "Rate limit exceeded",
283291
"limit": limit,
284292
"current": current_requests,
285-
"retry_after": 60
293+
"retry_after": 60,
286294
}
287-
295+
288296
# Record this request
289297
user_windows[window_start] = current_requests + 1
290-
298+
291299
return True, {
292300
"limit": limit,
293301
"current": current_requests + 1,
294-
"remaining": limit - current_requests - 1
302+
"remaining": limit - current_requests - 1,
295303
}
296304

297305
async def apply_rate_limit(
298-
self, current_user: Optional[UserInDB] = Depends(get_current_user_optional),
299-
request: Request = None
306+
self,
307+
current_user: Optional[UserInDB] = Depends(get_current_user_optional),
308+
request: Request = None,
300309
) -> bool:
301310
"""Apply rate limiting based on user"""
302311
if not current_user:
@@ -306,14 +315,14 @@ async def apply_rate_limit(
306315

307316
endpoint_path = str(request.url.path) if request else ""
308317
allowed, info = await self.check_rate_limit(str(current_user.id), endpoint_path)
309-
318+
310319
if not allowed:
311320
raise HTTPException(
312321
status_code=429,
313322
detail=info.get("reason", "Rate limit exceeded"),
314-
headers={"Retry-After": str(info.get("retry_after", 60))}
323+
headers={"Retry-After": str(info.get("retry_after", 60))},
315324
)
316-
325+
317326
return True
318327

319328

0 commit comments

Comments
 (0)