22import os
33import uuid
44from datetime import datetime , timedelta
5- from typing import Dict , Optional , Tuple
5+ from typing import Dict , Optional , Set , Tuple
66
77import jwt
88from google .auth .exceptions import GoogleAuthError
@@ -24,6 +24,11 @@ class TokenData(BaseModel):
2424 user_id : str
2525 email : str
2626 exp : datetime
27+ jti : Optional [str ] = None # JWT ID for token tracking
28+
29+
30+ # In-memory token blacklist (in production, use Redis)
31+ _token_blacklist : Set [str ] = set ()
2732
2833
2934class AuthService :
@@ -52,31 +57,44 @@ def __init__(self):
5257 logger .info (f"Mock auth enabled: { self .enable_mock_auth } " )
5358
5459 def create_access_token (self , user_id : str , email : str ) -> str :
55- """Create JWT access token"""
60+ """Create JWT access token with unique JWT ID """
5661 expire = datetime .utcnow () + timedelta (minutes = self .access_token_expire_minutes )
62+ jti = str (uuid .uuid4 ()) # Unique token identifier
5763 to_encode = {
5864 "sub" : user_id ,
5965 "email" : email ,
6066 "exp" : expire ,
6167 "iat" : datetime .utcnow (),
6268 "type" : "access" ,
69+ "jti" : jti ,
6370 }
6471 return jwt .encode (to_encode , self .jwt_secret , algorithm = self .algorithm )
6572
6673 def create_refresh_token (self , user_id : str , email : str ) -> str :
67- """Create JWT refresh token"""
74+ """Create JWT refresh token with unique JWT ID """
6875 expire = datetime .utcnow () + timedelta (days = self .refresh_token_expire_days )
76+ jti = str (uuid .uuid4 ()) # Unique token identifier
6977 to_encode = {
7078 "sub" : user_id ,
7179 "email" : email ,
7280 "exp" : expire ,
7381 "iat" : datetime .utcnow (),
7482 "type" : "refresh" ,
83+ "jti" : jti ,
7584 }
7685 return jwt .encode (to_encode , self .jwt_secret , algorithm = self .algorithm )
7786
87+ def _is_token_blacklisted (self , jti : str ) -> bool :
88+ """Check if token is blacklisted"""
89+ return jti in _token_blacklist
90+
91+ def _blacklist_token (self , jti : str ) -> None :
92+ """Add token to blacklist"""
93+ _token_blacklist .add (jti )
94+ logger .info (f"Token blacklisted: { jti } " )
95+
7896 def verify_token (self , token : str , token_type : str = "access" ) -> TokenData :
79- """Verify JWT token and return token data"""
97+ """Verify JWT token and return token data with blacklist check """
8098 try :
8199 payload = jwt .decode (token , self .jwt_secret , algorithms = [self .algorithm ])
82100
@@ -92,6 +110,11 @@ def verify_token(self, token: str, token_type: str = "access") -> TokenData:
92110 ):
93111 raise jwt .InvalidTokenError ("Token has expired" )
94112
113+ # Check if token is blacklisted
114+ jti = payload .get ("jti" )
115+ if jti and self ._is_token_blacklisted (jti ):
116+ raise jwt .InvalidTokenError ("Token has been revoked" )
117+
95118 return TokenData (
96119 user_id = payload .get ("sub" ),
97120 email = payload .get ("email" ),
@@ -100,6 +123,7 @@ def verify_token(self, token: str, token_type: str = "access") -> TokenData:
100123 if exp_timestamp
101124 else datetime .utcnow ()
102125 ),
126+ jti = jti ,
103127 )
104128 except jwt .ExpiredSignatureError :
105129 raise jwt .InvalidTokenError ("Token has expired" )
@@ -316,15 +340,44 @@ def get_current_user(self, access_token: str) -> UserInDB:
316340 logger .error (f"Get current user failed: { str (e )} " )
317341 raise
318342
319- def revoke_user_tokens (self , user_id : str ) -> bool :
343+ def revoke_token_by_jti (self , jti : str ) -> bool :
344+ """Revoke a specific token by its JWT ID"""
345+ if not jti :
346+ logger .warning ("Attempted to revoke token without JTI" )
347+ return False
348+
349+ self ._blacklist_token (jti )
350+ return True
351+
352+ def revoke_user_tokens (
353+ self , user_id : str , access_token : Optional [str ] = None
354+ ) -> bool :
320355 """
321- Revoke all tokens for a user (logout )
322- Note: With JWT, we can't actually revoke tokens server-side without a blacklist .
323- This is a placeholder for future token blacklist implementation .
356+ Revoke user tokens (logout with proper token blacklisting )
357+ In a production system, you would query all active tokens for the user .
358+ For now, we blacklist the current access token if provided .
324359 """
325360 logger .info (f"Token revocation requested for user: { user_id } " )
326- # In a production system, you would add the user's tokens to a blacklist
327- # For now, we just return True as logout is handled client-side
361+
362+ if access_token :
363+ try :
364+ # Verify the token to get its JTI before blacklisting
365+ token_data = self .verify_token (access_token , token_type = "access" )
366+ if token_data .jti :
367+ self ._blacklist_token (token_data .jti )
368+ logger .info (f"Successfully revoked token for user: { user_id } " )
369+ return True
370+ else :
371+ logger .warning (f"Token missing JTI for user: { user_id } " )
372+ return False
373+ except jwt .InvalidTokenError as e :
374+ logger .warning (
375+ f"Invalid token during revocation for user { user_id } : { str (e )} "
376+ )
377+ return False
378+
379+ # If no token provided, still consider it successful
380+ # (client-side logout without server-side token invalidation)
328381 return True
329382
330383 def validate_google_client_configuration (self ) -> Dict [str , any ]:
@@ -349,6 +402,14 @@ def validate_google_client_configuration(self) -> Dict[str, any]:
349402
350403 return config_status
351404
405+ def get_blacklist_stats (self ) -> Dict [str , any ]:
406+ """Get token blacklist statistics"""
407+ return {
408+ "blacklisted_tokens" : len (_token_blacklist ),
409+ "implementation" : "in_memory" ,
410+ "note" : "In production, use Redis for distributed blacklist" ,
411+ }
412+
352413 def health_check (self ) -> Dict [str , any ]:
353414 """Enhanced health check for auth service"""
354415 try :
@@ -383,6 +444,7 @@ def health_check(self) -> Dict[str, any]:
383444 "jwt_working" : jwt_working ,
384445 "user_service" : user_health ,
385446 "google_oauth" : google_config ,
447+ "token_blacklist" : self .get_blacklist_stats (),
386448 "environment" : self .environment ,
387449 "access_token_expire_minutes" : self .access_token_expire_minutes ,
388450 "refresh_token_expire_days" : self .refresh_token_expire_days ,
0 commit comments