Skip to content

Commit 76e9da0

Browse files
committed
Task B6: Completed JWT Token Management with Security
1 parent 15c9428 commit 76e9da0

3 files changed

Lines changed: 144 additions & 20 deletions

File tree

backend/api/auth.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66
from fastapi import APIRouter, Depends, HTTPException
77
from fastapi.security import HTTPBearer
88

9-
from models.response_schemas import ApiResponse, AuthResponse, LoginRequest, User
9+
from models.response_schemas import (
10+
ApiResponse,
11+
AuthResponse,
12+
LoginRequest,
13+
RefreshTokenRequest,
14+
User,
15+
)
1016
from models.user import UserInDB
1117
from services.auth_service import AuthService
1218

@@ -116,15 +122,15 @@ async def get_current_user(
116122

117123
@router.post("/logout")
118124
async def logout(token: str = Depends(get_current_user_token)) -> ApiResponse[dict]:
119-
"""Logout current user with enhanced logging"""
125+
"""Logout current user with enhanced logging and token blacklisting"""
120126
try:
121127
logger.info("Received logout request")
122128

123129
# Verify token and get user for logging
124130
user = auth_service.get_current_user(token)
125131

126-
# Revoke tokens (placeholder implementation)
127-
success = auth_service.revoke_user_tokens(str(user.id))
132+
# Revoke tokens with proper blacklisting
133+
success = auth_service.revoke_user_tokens(str(user.id), access_token=token)
128134

129135
if success:
130136
logger.info(f"Logout successful for user: {user.email}")
@@ -148,19 +154,18 @@ async def logout(token: str = Depends(get_current_user_token)) -> ApiResponse[di
148154

149155

150156
@router.post("/refresh")
151-
async def refresh_token(request: dict) -> ApiResponse[AuthResponse]:
157+
async def refresh_token(request: RefreshTokenRequest) -> ApiResponse[AuthResponse]:
152158
"""Refresh access token with enhanced validation"""
153159
try:
154160
logger.info("Received token refresh request")
155161

156162
# Validate request
157-
refresh_token = request.get("refresh_token")
158-
if not refresh_token or not refresh_token.strip():
163+
if not request.refresh_token or not request.refresh_token.strip():
159164
logger.warning("Empty refresh token received")
160165
raise HTTPException(status_code=400, detail="Refresh token is required")
161166

162167
new_access_token, user = auth_service.refresh_access_token(
163-
refresh_token.strip()
168+
request.refresh_token.strip()
164169
)
165170

166171
# Convert to response format
@@ -176,7 +181,7 @@ async def refresh_token(request: dict) -> ApiResponse[AuthResponse]:
176181
auth_response = AuthResponse(
177182
user=user_response,
178183
access_token=new_access_token,
179-
refresh_token=refresh_token, # Keep the same refresh token
184+
refresh_token=request.refresh_token, # Keep the same refresh token
180185
expires_in=auth_service.access_token_expire_minutes * 60,
181186
)
182187

backend/services/auth_service.py

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import uuid
44
from datetime import datetime, timedelta
5-
from typing import Dict, Optional, Tuple
5+
from typing import Dict, Optional, Set, Tuple
66

77
import jwt
88
from google.auth.exceptions import GoogleAuthError
@@ -24,6 +24,11 @@ class TokenData(BaseModel):
2424
user_id: str
2525
email: str
2626
exp: datetime
27+
jti: Optional[str] = None # JWT ID for token tracking
28+
29+
30+
# In-memory token blacklist (in production, use Redis)
31+
_token_blacklist: Set[str] = set()
2732

2833

2934
class AuthService:
@@ -52,31 +57,44 @@ def __init__(self):
5257
logger.info(f"Mock auth enabled: {self.enable_mock_auth}")
5358

5459
def create_access_token(self, user_id: str, email: str) -> str:
55-
"""Create JWT access token"""
60+
"""Create JWT access token with unique JWT ID"""
5661
expire = datetime.utcnow() + timedelta(minutes=self.access_token_expire_minutes)
62+
jti = str(uuid.uuid4()) # Unique token identifier
5763
to_encode = {
5864
"sub": user_id,
5965
"email": email,
6066
"exp": expire,
6167
"iat": datetime.utcnow(),
6268
"type": "access",
69+
"jti": jti,
6370
}
6471
return jwt.encode(to_encode, self.jwt_secret, algorithm=self.algorithm)
6572

6673
def create_refresh_token(self, user_id: str, email: str) -> str:
67-
"""Create JWT refresh token"""
74+
"""Create JWT refresh token with unique JWT ID"""
6875
expire = datetime.utcnow() + timedelta(days=self.refresh_token_expire_days)
76+
jti = str(uuid.uuid4()) # Unique token identifier
6977
to_encode = {
7078
"sub": user_id,
7179
"email": email,
7280
"exp": expire,
7381
"iat": datetime.utcnow(),
7482
"type": "refresh",
83+
"jti": jti,
7584
}
7685
return jwt.encode(to_encode, self.jwt_secret, algorithm=self.algorithm)
7786

87+
def _is_token_blacklisted(self, jti: str) -> bool:
88+
"""Check if token is blacklisted"""
89+
return jti in _token_blacklist
90+
91+
def _blacklist_token(self, jti: str) -> None:
92+
"""Add token to blacklist"""
93+
_token_blacklist.add(jti)
94+
logger.info(f"Token blacklisted: {jti}")
95+
7896
def verify_token(self, token: str, token_type: str = "access") -> TokenData:
79-
"""Verify JWT token and return token data"""
97+
"""Verify JWT token and return token data with blacklist check"""
8098
try:
8199
payload = jwt.decode(token, self.jwt_secret, algorithms=[self.algorithm])
82100

@@ -92,6 +110,11 @@ def verify_token(self, token: str, token_type: str = "access") -> TokenData:
92110
):
93111
raise jwt.InvalidTokenError("Token has expired")
94112

113+
# Check if token is blacklisted
114+
jti = payload.get("jti")
115+
if jti and self._is_token_blacklisted(jti):
116+
raise jwt.InvalidTokenError("Token has been revoked")
117+
95118
return TokenData(
96119
user_id=payload.get("sub"),
97120
email=payload.get("email"),
@@ -100,6 +123,7 @@ def verify_token(self, token: str, token_type: str = "access") -> TokenData:
100123
if exp_timestamp
101124
else datetime.utcnow()
102125
),
126+
jti=jti,
103127
)
104128
except jwt.ExpiredSignatureError:
105129
raise jwt.InvalidTokenError("Token has expired")
@@ -316,15 +340,44 @@ def get_current_user(self, access_token: str) -> UserInDB:
316340
logger.error(f"Get current user failed: {str(e)}")
317341
raise
318342

319-
def revoke_user_tokens(self, user_id: str) -> bool:
343+
def revoke_token_by_jti(self, jti: str) -> bool:
344+
"""Revoke a specific token by its JWT ID"""
345+
if not jti:
346+
logger.warning("Attempted to revoke token without JTI")
347+
return False
348+
349+
self._blacklist_token(jti)
350+
return True
351+
352+
def revoke_user_tokens(
353+
self, user_id: str, access_token: Optional[str] = None
354+
) -> bool:
320355
"""
321-
Revoke all tokens for a user (logout)
322-
Note: With JWT, we can't actually revoke tokens server-side without a blacklist.
323-
This is a placeholder for future token blacklist implementation.
356+
Revoke user tokens (logout with proper token blacklisting)
357+
In a production system, you would query all active tokens for the user.
358+
For now, we blacklist the current access token if provided.
324359
"""
325360
logger.info(f"Token revocation requested for user: {user_id}")
326-
# In a production system, you would add the user's tokens to a blacklist
327-
# For now, we just return True as logout is handled client-side
361+
362+
if access_token:
363+
try:
364+
# Verify the token to get its JTI before blacklisting
365+
token_data = self.verify_token(access_token, token_type="access")
366+
if token_data.jti:
367+
self._blacklist_token(token_data.jti)
368+
logger.info(f"Successfully revoked token for user: {user_id}")
369+
return True
370+
else:
371+
logger.warning(f"Token missing JTI for user: {user_id}")
372+
return False
373+
except jwt.InvalidTokenError as e:
374+
logger.warning(
375+
f"Invalid token during revocation for user {user_id}: {str(e)}"
376+
)
377+
return False
378+
379+
# If no token provided, still consider it successful
380+
# (client-side logout without server-side token invalidation)
328381
return True
329382

330383
def validate_google_client_configuration(self) -> Dict[str, any]:
@@ -349,6 +402,14 @@ def validate_google_client_configuration(self) -> Dict[str, any]:
349402

350403
return config_status
351404

405+
def get_blacklist_stats(self) -> Dict[str, any]:
406+
"""Get token blacklist statistics"""
407+
return {
408+
"blacklisted_tokens": len(_token_blacklist),
409+
"implementation": "in_memory",
410+
"note": "In production, use Redis for distributed blacklist",
411+
}
412+
352413
def health_check(self) -> Dict[str, any]:
353414
"""Enhanced health check for auth service"""
354415
try:
@@ -383,6 +444,7 @@ def health_check(self) -> Dict[str, any]:
383444
"jwt_working": jwt_working,
384445
"user_service": user_health,
385446
"google_oauth": google_config,
447+
"token_blacklist": self.get_blacklist_stats(),
386448
"environment": self.environment,
387449
"access_token_expire_minutes": self.access_token_expire_minutes,
388450
"refresh_token_expire_days": self.refresh_token_expire_days,

backend/tests/test_auth_service.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ def test_create_access_token(self, auth_service):
6060
assert payload["sub"] == user_id
6161
assert payload["email"] == email
6262
assert payload["type"] == "access"
63+
assert "jti" in payload
64+
assert payload["jti"] is not None
6365

6466
def test_create_refresh_token(self, auth_service):
6567
"""Test refresh token creation"""
@@ -75,6 +77,8 @@ def test_create_refresh_token(self, auth_service):
7577
assert payload["sub"] == user_id
7678
assert payload["email"] == email
7779
assert payload["type"] == "refresh"
80+
assert "jti" in payload
81+
assert payload["jti"] is not None
7882

7983
def test_verify_access_token_success(self, auth_service):
8084
"""Test successful access token verification"""
@@ -87,6 +91,7 @@ def test_verify_access_token_success(self, auth_service):
8791
assert isinstance(token_data, TokenData)
8892
assert token_data.user_id == user_id
8993
assert token_data.email == email
94+
assert token_data.jti is not None
9095

9196
def test_verify_refresh_token_success(self, auth_service):
9297
"""Test successful refresh token verification"""
@@ -99,6 +104,7 @@ def test_verify_refresh_token_success(self, auth_service):
99104
assert isinstance(token_data, TokenData)
100105
assert token_data.user_id == user_id
101106
assert token_data.email == email
107+
assert token_data.jti is not None
102108

103109
def test_verify_invalid_token(self, auth_service):
104110
"""Test invalid token verification"""
@@ -403,6 +409,57 @@ def test_get_current_user_inactive(self, auth_service, sample_user):
403409
auth_service.get_current_user(access_token)
404410

405411
def test_revoke_user_tokens(self, auth_service):
406-
"""Test token revocation (placeholder implementation)"""
412+
"""Test token revocation with proper blacklisting"""
413+
# Test without access token
407414
result = auth_service.revoke_user_tokens("test_user_123")
408415
assert result is True
416+
417+
# Test with access token
418+
access_token = auth_service.create_access_token(
419+
"test_user_123", "test@example.com"
420+
)
421+
result = auth_service.revoke_user_tokens(
422+
"test_user_123", access_token=access_token
423+
)
424+
assert result is True
425+
426+
# Verify token is now blacklisted
427+
with pytest.raises(jwt.InvalidTokenError, match="Token has been revoked"):
428+
auth_service.verify_token(access_token)
429+
430+
def test_revoke_token_by_jti(self, auth_service):
431+
"""Test token revocation by JWT ID"""
432+
# Test with valid JTI
433+
result = auth_service.revoke_token_by_jti("test_jti_123")
434+
assert result is True
435+
436+
# Test with empty JTI
437+
result = auth_service.revoke_token_by_jti("")
438+
assert result is False
439+
440+
def test_token_blacklisting(self, auth_service):
441+
"""Test comprehensive token blacklisting functionality"""
442+
user_id = "test_user_123"
443+
email = "test@example.com"
444+
445+
# Create token
446+
token = auth_service.create_access_token(user_id, email)
447+
448+
# Verify token works initially
449+
token_data = auth_service.verify_token(token)
450+
assert token_data.user_id == user_id
451+
assert token_data.jti is not None
452+
453+
# Blacklist the token
454+
auth_service._blacklist_token(token_data.jti)
455+
456+
# Verify token is now rejected
457+
with pytest.raises(jwt.InvalidTokenError, match="Token has been revoked"):
458+
auth_service.verify_token(token)
459+
460+
def test_get_blacklist_stats(self, auth_service):
461+
"""Test blacklist statistics"""
462+
stats = auth_service.get_blacklist_stats()
463+
assert "blacklisted_tokens" in stats
464+
assert "implementation" in stats
465+
assert stats["implementation"] == "in_memory"

0 commit comments

Comments
 (0)