Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/backend-Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ jobs:
DATABASE_URL: postgresql+psycopg://cms:cmspass@localhost:5432/cmstest
JWT_SECRET: DH8kSxcflUVfNRdkEiJJCn2dOOKI3qfw
ALEMBIC_UPGRADE_HEAD_ON_START: false
AUTH_MODES: local
run: inv coverage --args "-vvv"

- name: Upload coverage report to codecov
Expand Down
1 change: 1 addition & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dynamic = ["version"]
api = [
"fastapi[all] == 0.115.2",
"PyJWT == 2.11.0",
"Werkzeug == 3.1.5",
"cryptography == 46.0.4"
]
scripts = [
Expand Down
19 changes: 17 additions & 2 deletions backend/src/cms_backend/api/context.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
from dataclasses import dataclass

from humanfriendly import parse_timespan

from cms_backend.context import parse_bool


@dataclass(kw_only=True)
class Context:
"""Class holding every contextual / configuration bits which can be moved

Expand All @@ -27,3 +27,18 @@ class Context:
create_new_oauth_account = parse_bool(
os.getenv("CREATE_NEW_OAUTH_ACCOUNT", default="true")
)
# List of authentication modes. Allowed values are "local", "oauth-session"
auth_modes: list[str] = os.getenv(
"AUTH_MODES",
default="oauth-session",
).split(",")

# Local Authentication JWT settings
jwt_secret: str = os.getenv("JWT_SECRET", default="")
jwt_token_issuer: str = os.getenv("JWT_TOKEN_ISSUER", default="cms_backend")
jwt_token_expiry_duration = parse_timespan(
os.getenv("JWT_TOKEN_EXPIRY_DURATION", default="1d")
)
refresh_token_expiry_duration = parse_timespan(
os.getenv("REFRESH_TOKEN_EXPIRY_DURATION", default="30d")
)
2 changes: 2 additions & 0 deletions backend/src/cms_backend/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from cms_backend.api.routes.healthcheck import router as healthcheck_router
from cms_backend.api.routes.http_errors import BadRequestError
from cms_backend.api.routes.titles import router as titles_router
from cms_backend.api.routes.user import router as user_router
from cms_backend.api.routes.zimfarm_notifications import (
router as zimfarm_notification_router,
)
Expand Down Expand Up @@ -62,6 +63,7 @@ def create_app(*, debug: bool = True):
main_router.include_router(router=books_router)
main_router.include_router(router=collection_router)
main_router.include_router(router=auth_router)
main_router.include_router(router=user_router)

app.include_router(router=main_router)

Expand Down
112 changes: 110 additions & 2 deletions backend/src/cms_backend/api/routes/auth.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,123 @@
import datetime
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Response
from sqlalchemy.orm import Session as OrmSession
from werkzeug.security import check_password_hash

from cms_backend.api.context import Context
from cms_backend.api.routes.dependencies import get_current_user
from cms_backend.api.routes.http_errors import UnauthorizedError
from cms_backend.api.token import generate_access_token
from cms_backend.db import gen_dbsession
from cms_backend.db.exceptions import RecordDoesNotExistError
from cms_backend.db.models import User
from cms_backend.db.user import create_user_schema
from cms_backend.db.refresh_token import (
create_refresh_token,
delete_refresh_token,
expire_refresh_tokens,
get_refresh_token,
)
from cms_backend.db.user import create_user_schema, get_user_by_username
from cms_backend.schemas import BaseModel
from cms_backend.schemas.orms import UserSchema
from cms_backend.utils.datetime import getnow

router = APIRouter(prefix="/auth", tags=["auth"])


class CredentialsIn(BaseModel):
username: str
password: str


class RefreshTokenIn(BaseModel):
refresh_token: UUID


class Token(BaseModel):
"""Access token on successful authentication."""

access_token: str
token_type: str = "Bearer"
expires_time: datetime.datetime
refresh_token: str


def _access_token_response(db_session: OrmSession, db_user: User, response: Response):
response.headers["Cache-Control"] = "no-store"
response.headers["Pragma"] = "no-cache"
issue_time = getnow()
return Token(
access_token=generate_access_token(
user_id=str(db_user.id),
issue_time=issue_time,
),
refresh_token=str(
create_refresh_token(session=db_session, user_id=db_user.id).token
),
expires_time=issue_time
+ datetime.timedelta(seconds=Context.jwt_token_expiry_duration),
)


def _auth_with_credentials(
db_session: OrmSession, credentials: CredentialsIn, response: Response
):
"""Authorize a user with username and password."""
try:
db_user = get_user_by_username(db_session, username=credentials.username)
except RecordDoesNotExistError as exc:
raise UnauthorizedError() from exc

if not (
db_user.password_hash
and check_password_hash(db_user.password_hash, credentials.password)
):
raise UnauthorizedError("Invalid credentials")

return _access_token_response(db_session, db_user, response)


def _refresh_access_token(
db_session: OrmSession, refresh_token: UUID, response: Response
):
"""Issue a new set of access and refresh tokens."""
try:
db_refresh_token = get_refresh_token(db_session, token=refresh_token)
except RecordDoesNotExistError as exc:
raise UnauthorizedError() from exc

now = getnow()
if db_refresh_token.expire_time < now:
raise UnauthorizedError("Refresh token expired")

delete_refresh_token(db_session, token=refresh_token)
expire_refresh_tokens(db_session, expire_time=now)

return _access_token_response(db_session, db_refresh_token.user, response)


@router.post("/authorize")
def auth_with_credentials(
credentials: CredentialsIn,
response: Response,
db_session: Annotated[OrmSession, Depends(gen_dbsession)],
) -> Token:
"""Authorize a user with username and password."""
return _auth_with_credentials(db_session, credentials, response)


@router.post("/refresh")
def refresh_access_token(
request: RefreshTokenIn,
db_session: Annotated[OrmSession, Depends(gen_dbsession)],
response: Response,
) -> Token:
return _refresh_access_token(db_session, request.refresh_token, response)


@router.get("/me")
def get_current_user_info(
current_user: Annotated[User, Depends(get_current_user)],
Expand Down
4 changes: 3 additions & 1 deletion backend/src/cms_backend/api/routes/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ def _get_current_user_or_none(
if claims is None:
return None
user = get_user_by_id_or_none(session, user_id=claims.sub)
# If this is a kiwix token, we create a new user account
# If this is a kiwix token (wilkl have a name), we create a new user account
if user is None and Context.create_new_oauth_account:
if not claims.name:
raise UnauthorizedError("Token is missing 'profile' scope")
create_user(
session,
username=claims.name,
Expand Down
106 changes: 106 additions & 0 deletions backend/src/cms_backend/api/routes/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from http import HTTPStatus
from typing import Annotated

from fastapi import APIRouter, Depends, Path, Response
from sqlalchemy.orm import Session as OrmSession
from werkzeug.security import check_password_hash, generate_password_hash

from cms_backend.api.routes.dependencies import get_current_user, require_permission
from cms_backend.api.routes.fields import NotEmptyString
from cms_backend.api.routes.http_errors import BadRequestError, UnauthorizedError
from cms_backend.db import gen_dbsession
from cms_backend.db.models import User
from cms_backend.db.user import (
check_user_permission,
create_user_schema,
get_user_by_username,
)
from cms_backend.db.user import create_user as db_create_user
from cms_backend.db.user import delete_user as db_delete_user
from cms_backend.db.user import update_user_password as db_update_user_password
from cms_backend.roles import RoleEnum
from cms_backend.schemas import BaseModel
from cms_backend.schemas.orms import UserSchema

router = APIRouter(prefix="/users", tags=["users"])


class UserCreateSchema(BaseModel):
"""
Schema for creating a user
"""

username: NotEmptyString
password: NotEmptyString
role: RoleEnum


class PasswordUpdateSchema(BaseModel):
"""
Schema for updating a user's password
"""

# users with elevated permissions can omit the current password
current: NotEmptyString | None = None
new: NotEmptyString


@router.post(
"", dependencies=[Depends(require_permission(namespace="user", name="create"))]
)
def create_user(
user_schema: UserCreateSchema,
db_session: Annotated[OrmSession, Depends(gen_dbsession)],
) -> UserSchema:
user = db_create_user(
db_session,
username=user_schema.username,
password_hash=generate_password_hash(user_schema.password),
role=user_schema.role,
)

return create_user_schema(user)


@router.delete(
"/{username}",
dependencies=[Depends(require_permission(namespace="user", name="delete"))],
)
def delete_user(
username: Annotated[str, Path()],
db_session: Annotated[OrmSession, Depends(gen_dbsession)],
) -> Response:
"""Delete a specific user"""
user = get_user_by_username(db_session, username=username)
db_delete_user(db_session, user_id=user.id)
return Response(status_code=HTTPStatus.NO_CONTENT)


@router.patch("/{username}/password")
def update_user_password(
username: Annotated[str, Path()],
password_update: PasswordUpdateSchema,
db_session: Annotated[OrmSession, Depends(gen_dbsession)],
current_user: Annotated[User, Depends(get_current_user)],
) -> Response:
"""Update a user's password"""
user = get_user_by_username(db_session, username=username)

if current_user.username == username:
if password_update.current is None:
raise BadRequestError("You must enter your current password.")

if not check_password_hash(
current_user.password_hash or "", password_update.current
):
raise BadRequestError()

elif not check_user_permission(current_user, namespace="user", name="update"):
raise UnauthorizedError("You are not allowed to access this resource")

db_update_user_password(
db_session,
user_id=user.id,
password_hash=generate_password_hash(password_update.new),
)
return Response(status_code=HTTPStatus.NO_CONTENT)
50 changes: 47 additions & 3 deletions backend/src/cms_backend/api/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class JWTClaims(BaseModel):
exp: datetime.datetime
iat: datetime.datetime
sub: uuid.UUID = Field(alias="subject")
name: str
name: str | None = Field(exclude=True, default=None)


class TokenDecoder(abc.ABC):
Expand Down Expand Up @@ -48,6 +48,29 @@ def can_decode(self) -> bool:
pass


class LocalTokenDecoder(TokenDecoder):
"""Decoder for local CMS JWT tokens."""

def __init__(self, secret: str = Context.jwt_secret, algorithm: str = "HS256"):
self.secret = secret
self.algorithm = algorithm

def decode(self, token: str) -> JWTClaims:
"""
Decode and validate a local CMS token.
"""
jwt_claims = jwt.decode(token, self.secret, algorithms=[self.algorithm])
return JWTClaims(**jwt_claims)

@property
def name(self) -> str:
return "local"

@property
def can_decode(self) -> bool:
return "local" in Context.auth_modes


class OAuthSessionTokenDecoder(TokenDecoder):
"""Decoder for OAuth Session JWT tokens."""

Expand Down Expand Up @@ -92,7 +115,7 @@ def name(self) -> str:

@property
def can_decode(self) -> bool:
return True
return "oauth-session" in Context.auth_modes


class TokenDecoderChain:
Expand Down Expand Up @@ -132,4 +155,25 @@ def decode(self, token: str) -> JWTClaims:
raise ValueError("Inavlid token")


token_decoder = TokenDecoderChain(decoders=[OAuthSessionTokenDecoder()])
token_decoder = TokenDecoderChain(
decoders=[OAuthSessionTokenDecoder(), LocalTokenDecoder()]
)


def generate_access_token(
*,
user_id: str,
issue_time: datetime.datetime,
) -> str:
"""Generate a JWT access token for the given user ID with configured expiry."""

expire_time = issue_time + datetime.timedelta(
seconds=Context.jwt_token_expiry_duration
)
payload = {
"iss": Context.jwt_token_issuer, # issuer
"exp": expire_time.timestamp(), # expiration time
"iat": issue_time.timestamp(), # issued at
"subject": user_id,
}
return jwt.encode(payload, key=Context.jwt_secret, algorithm="HS256")
Loading