-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsecurity.py
More file actions
125 lines (95 loc) · 3.85 KB
/
security.py
File metadata and controls
125 lines (95 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Security utilities and shared auth models for the MariaDB REST API."""
import datetime
from typing import Optional

from fastapi import Depends, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel

from db_utils import get_db_connection
from settings import SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES, logger
# Password hashing policy shared by hash/verify helpers below: bcrypt is the
# only configured scheme; deprecated="auto" (per passlib) deprecates every
# scheme except the default one.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])

# Extracts the bearer token from the Authorization header for FastAPI
# dependencies; "/login" is the token URL advertised in the OpenAPI schema.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
class Token(BaseModel):
    """OAuth2-style bearer token response (``access_token`` + ``token_type``).

    Not used inside this module; presumably returned by the login route
    (``tokenUrl="/login"``) — confirm against the caller.
    """

    access_token: str  # encoded JWT, e.g. as produced by create_access_token
    token_type: str  # presumably "bearer" — TODO confirm in the login route
class TokenData(BaseModel):
    """Claims extracted from a decoded access token.

    ``get_current_user`` fills ``username`` from the ``sub`` claim and
    ``role`` from the ``role`` claim.
    """

    username: Optional[str] = None  # value of the JWT "sub" claim
    role: Optional[str] = None  # value of the JWT "role" claim
class User(BaseModel):
    """User record as returned by the lookup helpers in this module.

    Built from the ``users`` table's ``id``, ``username`` and ``role``
    columns; the password hash is deliberately not part of this model.
    """

    id: int
    username: str
    role: str  # compared case-insensitively against "admin" by get_current_admin
class UserCreate(BaseModel):
    """Request body for creating a user account.

    ``password`` is plain text here; presumably hashed with
    ``get_password_hash`` before storage — confirm in the route.
    """

    username: str
    password: str
    # Role is free-form text, so both admins and regular users can be created.
    # NOTE(review): the default is "admin", so a request that omits ``role``
    # creates an administrator — confirm this is intended and that the
    # creating endpoint is itself admin-only.
    role: str = "admin"
class PasswordChange(BaseModel):
    """Request body for changing the current user's password.

    Both fields are plain text; presumably ``old_password`` is checked with
    ``verify_password`` before ``new_password`` is hashed and stored —
    confirm in the route.
    """

    old_password: str
    new_password: str
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if *plain_password* matches the stored *hashed_password*.

    Delegates to the module-level ``pwd_context``. NOTE(review): per the
    passlib docs, ``CryptContext.verify`` raises ``ValueError`` when the
    stored hash is in an unrecognised format — callers currently let that
    propagate.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password: str) -> str:
    """Hash *password* with ``pwd_context``'s default scheme (bcrypt) for storage."""
    digest = pwd_context.hash(password)
    return digest
def create_access_token(data: dict, expires_delta: Optional[datetime.timedelta] = None) -> str:
    """Create a signed JWT carrying *data* plus an ``exp`` (expiry) claim.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object is
            not modified; an ``exp`` key in it is overwritten.
        expires_delta: Token lifetime measured from now. ``None`` means "use
            the ``ACCESS_TOKEN_EXPIRE_MINUTES`` setting"; any explicit value,
            including ``timedelta(0)``, is used as given.

    Returns:
        The token string produced by ``jwt.encode`` with ``SECRET_KEY`` and
        ``ALGORITHM``.
    """
    # Compare against None explicitly: timedelta(0) is falsy and must not be
    # mistaken for "no lifetime given".
    if expires_delta is None:
        expires_delta = datetime.timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    # Timezone-aware UTC "now" (datetime.utcnow() is deprecated since 3.12);
    # python-jose converts aware datetimes to a UTC epoch for "exp".
    expire = datetime.datetime.now(datetime.timezone.utc) + expires_delta

    to_encode = data.copy()
    to_encode["exp"] = expire
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def get_user_from_db(username: str) -> Optional[User]:
    """Look up a user by username in the ``users`` table.

    Args:
        username: Value matched with ``WHERE username = %s`` (whether that is
            case-sensitive depends on the column's collation).

    Returns:
        The matching :class:`User`, or ``None`` when no row matches.
    """
    conn = get_db_connection()
    try:
        cur = conn.cursor(dictionary=True)
        # Fetch only the columns User needs; the password hash is not used on
        # this path, so it is not read out of the database at all.
        cur.execute("SELECT id, username, role FROM users WHERE username = %s", (username,))
        row = cur.fetchone()
    finally:
        conn.close()

    if not row:
        return None
    return User(id=row["id"], username=row["username"], role=row["role"])
def authenticate_user(username: str, password: str) -> Optional[User]:
    """Check a username/password pair against the ``users`` table.

    Args:
        username: Login name matched with ``WHERE username = %s``.
        password: Plain-text password supplied by the client.

    Returns:
        The authenticated :class:`User` (without its hash), or ``None`` when
        no such user exists or the password does not match.
    """
    conn = get_db_connection()
    try:
        cur = conn.cursor(dictionary=True)
        cur.execute(
            "SELECT id, username, password_hash, role FROM users WHERE username = %s",
            (username,),
        )
        row = cur.fetchone()
    finally:
        # Release the connection before the deliberately slow bcrypt check.
        conn.close()

    if not row:
        # Spend comparable hashing time on unknown usernames so response
        # timing does not reveal which accounts exist.
        pwd_context.dummy_verify()
        return None
    if not verify_password(password, row["password_hash"]):
        return None
    return User(id=row["id"], username=row["username"], role=row["role"])
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
    """FastAPI dependency that resolves the bearer token to a :class:`User`.

    The token is decoded and verified with ``SECRET_KEY``/``ALGORITHM`` and
    must carry both a ``sub`` (username) and a ``role`` claim. The user is
    then re-read from the database, so the returned role is the stored one,
    not the value embedded in the token.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when the token
            fails to decode/verify (python-jose reports this, including
            expiry, as ``JWTError``), lacks ``sub`` or ``role``, or names a
            user that no longer exists.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception from None

    username: Optional[str] = payload.get("sub")
    role: Optional[str] = payload.get("role")
    if username is None or role is None:
        raise credentials_exception
    token_data = TokenData(username=username, role=role)

    # The lookup is a synchronous DB round-trip; run it in the threadpool so
    # this async dependency does not stall the event loop.
    user = await run_in_threadpool(get_user_from_db, token_data.username)
    if user is None:
        raise credentials_exception
    return user
async def get_current_admin(current_user: User = Depends(get_current_user)) -> User:
    """FastAPI dependency that admits only users whose role is "admin".

    The role comparison is case-insensitive (``role.lower()``).

    Raises:
        HTTPException: 403 when the authenticated user is not an admin.
    """
    if current_user.role.lower() == "admin":
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin privileges required")