diff --git a/.github/workflows/backend-Tests.yml b/.github/workflows/backend-Tests.yml index efc4d02..1089a91 100644 --- a/.github/workflows/backend-Tests.yml +++ b/.github/workflows/backend-Tests.yml @@ -90,6 +90,7 @@ jobs: DATABASE_URL: postgresql+psycopg://cms:cmspass@localhost:5432/cmstest JWT_SECRET: DH8kSxcflUVfNRdkEiJJCn2dOOKI3qfw ALEMBIC_UPGRADE_HEAD_ON_START: false + AUTH_MODES: local run: inv coverage --args "-vvv" - name: Upload coverage report to codecov diff --git a/backend/pyproject.toml b/backend/pyproject.toml index aac0469..a930329 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -31,6 +31,7 @@ dynamic = ["version"] api = [ "fastapi[all] == 0.115.2", "PyJWT == 2.11.0", + "Werkzeug == 3.1.5", "cryptography == 46.0.4" ] scripts = [ diff --git a/backend/src/cms_backend/api/context.py b/backend/src/cms_backend/api/context.py index 97ca57e..c0cb826 100644 --- a/backend/src/cms_backend/api/context.py +++ b/backend/src/cms_backend/api/context.py @@ -1,10 +1,10 @@ import os -from dataclasses import dataclass + +from humanfriendly import parse_timespan from cms_backend.context import parse_bool -@dataclass(kw_only=True) class Context: """Class holding every contextual / configuration bits which can be moved @@ -27,3 +27,18 @@ class Context: create_new_oauth_account = parse_bool( os.getenv("CREATE_NEW_OAUTH_ACCOUNT", default="true") ) + # List of authentication modes. Allowed values are "local", "oauth-session" + auth_modes: list[str] = os.getenv( + "AUTH_MODES", + default="oauth-session", + ).split(",") + + # Local Authentication JWT settings + jwt_secret: str = os.getenv("JWT_SECRET", default="") + jwt_token_issuer: str = os.getenv("JWT_TOKEN_ISSUER", default="cms_backend") + jwt_token_expiry_duration = parse_timespan( + os.getenv("JWT_TOKEN_EXPIRY_DURATION", default="1d") + ) + refresh_token_expiry_duration = parse_timespan( + os.getenv("REFRESH_TOKEN_EXPIRY_DURATION", default="30d") + ) diff --git a/backend/src/cms_backend/api/main.py b/backend/src/cms_backend/api/main.py index 6d0e3ad..06c82a9 100644 --- a/backend/src/cms_backend/api/main.py +++ b/backend/src/cms_backend/api/main.py @@ -15,6 +15,7 @@ from cms_backend.api.routes.healthcheck import router as healthcheck_router from cms_backend.api.routes.http_errors import BadRequestError from cms_backend.api.routes.titles import router as titles_router +from cms_backend.api.routes.user import router as user_router from cms_backend.api.routes.zimfarm_notifications import ( router as zimfarm_notification_router, ) @@ -62,6 +63,7 @@ def create_app(*, debug: bool = True): main_router.include_router(router=books_router) main_router.include_router(router=collection_router) main_router.include_router(router=auth_router) + main_router.include_router(router=user_router) app.include_router(router=main_router) diff --git a/backend/src/cms_backend/api/routes/auth.py b/backend/src/cms_backend/api/routes/auth.py index 79f5906..c7e4d8d 100644 --- a/backend/src/cms_backend/api/routes/auth.py +++ b/backend/src/cms_backend/api/routes/auth.py @@ -1,15 +1,123 @@ +import datetime from typing import Annotated +from uuid import UUID -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Response +from sqlalchemy.orm import Session as OrmSession +from werkzeug.security import check_password_hash +from cms_backend.api.context import Context from cms_backend.api.routes.dependencies import get_current_user +from cms_backend.api.routes.http_errors import UnauthorizedError +from cms_backend.api.token import generate_access_token +from cms_backend.db import gen_dbsession +from cms_backend.db.exceptions import RecordDoesNotExistError from cms_backend.db.models import User -from cms_backend.db.user import create_user_schema +from cms_backend.db.refresh_token import ( + create_refresh_token, + delete_refresh_token, + expire_refresh_tokens, + get_refresh_token, +) +from cms_backend.db.user import create_user_schema, get_user_by_username +from cms_backend.schemas import BaseModel from cms_backend.schemas.orms import UserSchema +from cms_backend.utils.datetime import getnow router = APIRouter(prefix="/auth", tags=["auth"]) +class CredentialsIn(BaseModel): + username: str + password: str + + +class RefreshTokenIn(BaseModel): + refresh_token: UUID + + +class Token(BaseModel): + """Access token on successful authentication.""" + + access_token: str + token_type: str = "Bearer" + expires_time: datetime.datetime + refresh_token: str + + +def _access_token_response(db_session: OrmSession, db_user: User, response: Response): + response.headers["Cache-Control"] = "no-store" + response.headers["Pragma"] = "no-cache" + issue_time = getnow() + return Token( + access_token=generate_access_token( + user_id=str(db_user.id), + issue_time=issue_time, + ), + refresh_token=str( + create_refresh_token(session=db_session, user_id=db_user.id).token + ), + expires_time=issue_time + + datetime.timedelta(seconds=Context.jwt_token_expiry_duration), + ) + + +def _auth_with_credentials( + db_session: OrmSession, credentials: CredentialsIn, response: Response +): + """Authorize a user with username and password.""" + try: + db_user = get_user_by_username(db_session, username=credentials.username) + except RecordDoesNotExistError as exc: + raise UnauthorizedError() from exc + + if not ( + db_user.password_hash + and check_password_hash(db_user.password_hash, credentials.password) + ): + raise UnauthorizedError("Invalid credentials") + + return _access_token_response(db_session, db_user, response) + + +def _refresh_access_token( + db_session: OrmSession, refresh_token: UUID, response: Response +): + """Issue a new set of access and refresh tokens.""" + try: + db_refresh_token = get_refresh_token(db_session, token=refresh_token) + except RecordDoesNotExistError as exc: + raise UnauthorizedError() from exc + + now = getnow() + if db_refresh_token.expire_time < now: + raise UnauthorizedError("Refresh token expired") + + delete_refresh_token(db_session, token=refresh_token) + expire_refresh_tokens(db_session, expire_time=now) + + return _access_token_response(db_session, db_refresh_token.user, response) + + +@router.post("/authorize") +def auth_with_credentials( + credentials: CredentialsIn, + response: Response, + db_session: Annotated[OrmSession, Depends(gen_dbsession)], +) -> Token: + """Authorize a user with username and password.""" + return _auth_with_credentials(db_session, credentials, response) + + +@router.post("/refresh") +def refresh_access_token( + request: RefreshTokenIn, + db_session: Annotated[OrmSession, Depends(gen_dbsession)], + response: Response, +) -> Token: + return _refresh_access_token(db_session, request.refresh_token, response) + + @router.get("/me") def get_current_user_info( current_user: Annotated[User, Depends(get_current_user)], diff --git a/backend/src/cms_backend/api/routes/dependencies.py b/backend/src/cms_backend/api/routes/dependencies.py index 70b6899..435f64d 100644 --- a/backend/src/cms_backend/api/routes/dependencies.py +++ b/backend/src/cms_backend/api/routes/dependencies.py @@ -58,8 +58,10 @@ def _get_current_user_or_none( if claims is None: return None user = get_user_by_id_or_none(session, user_id=claims.sub) - # If this is a kiwix token, we create a new user account + # If this is a kiwix token (wilkl have a name), we create a new user account if user is None and Context.create_new_oauth_account: + if not claims.name: + raise UnauthorizedError("Token is missing 'profile' scope") create_user( session, username=claims.name, diff --git a/backend/src/cms_backend/api/routes/user.py b/backend/src/cms_backend/api/routes/user.py new file mode 100644 index 0000000..7bc22fc --- /dev/null +++ b/backend/src/cms_backend/api/routes/user.py @@ -0,0 +1,106 @@ +from http import HTTPStatus +from typing import Annotated + +from fastapi import APIRouter, Depends, Path, Response +from sqlalchemy.orm import Session as OrmSession +from werkzeug.security import check_password_hash, generate_password_hash + +from cms_backend.api.routes.dependencies import get_current_user, require_permission +from cms_backend.api.routes.fields import NotEmptyString +from cms_backend.api.routes.http_errors import BadRequestError, UnauthorizedError +from cms_backend.db import gen_dbsession +from cms_backend.db.models import User +from cms_backend.db.user import ( + check_user_permission, + create_user_schema, + get_user_by_username, +) +from cms_backend.db.user import create_user as db_create_user +from cms_backend.db.user import delete_user as db_delete_user +from cms_backend.db.user import update_user_password as db_update_user_password +from cms_backend.roles import RoleEnum +from cms_backend.schemas import BaseModel +from cms_backend.schemas.orms import UserSchema + +router = APIRouter(prefix="/users", tags=["users"]) + + +class UserCreateSchema(BaseModel): + """ + Schema for creating a user + """ + + username: NotEmptyString + password: NotEmptyString + role: RoleEnum + + +class PasswordUpdateSchema(BaseModel): + """ + Schema for updating a user's password + """ + + # users with elevated permissions can omit the current password + current: NotEmptyString | None = None + new: NotEmptyString + + +@router.post( + "", dependencies=[Depends(require_permission(namespace="user", name="create"))] +) +def create_user( + user_schema: UserCreateSchema, + db_session: Annotated[OrmSession, Depends(gen_dbsession)], +) -> UserSchema: + user = db_create_user( + db_session, + username=user_schema.username, + password_hash=generate_password_hash(user_schema.password), + role=user_schema.role, + ) + + return create_user_schema(user) + + +@router.delete( + "/{username}", + dependencies=[Depends(require_permission(namespace="user", name="delete"))], +) +def delete_user( + username: Annotated[str, Path()], + db_session: Annotated[OrmSession, Depends(gen_dbsession)], +) -> Response: + """Delete a specific user""" + user = get_user_by_username(db_session, username=username) + db_delete_user(db_session, user_id=user.id) + return Response(status_code=HTTPStatus.NO_CONTENT) + + +@router.patch("/{username}/password") +def update_user_password( + username: Annotated[str, Path()], + password_update: PasswordUpdateSchema, + db_session: Annotated[OrmSession, Depends(gen_dbsession)], + current_user: Annotated[User, Depends(get_current_user)], +) -> Response: + """Update a user's password""" + user = get_user_by_username(db_session, username=username) + + if current_user.username == username: + if password_update.current is None: + raise BadRequestError("You must enter your current password.") + + if not check_password_hash( + current_user.password_hash or "", password_update.current + ): + raise BadRequestError() + + elif not check_user_permission(current_user, namespace="user", name="update"): + raise UnauthorizedError("You are not allowed to access this resource") + + db_update_user_password( + db_session, + user_id=user.id, + password_hash=generate_password_hash(password_update.new), + ) + return Response(status_code=HTTPStatus.NO_CONTENT) diff --git a/backend/src/cms_backend/api/token.py b/backend/src/cms_backend/api/token.py index a92c9c9..cf41ba6 100644 --- a/backend/src/cms_backend/api/token.py +++ b/backend/src/cms_backend/api/token.py @@ -18,7 +18,7 @@ class JWTClaims(BaseModel): exp: datetime.datetime iat: datetime.datetime sub: uuid.UUID = Field(alias="subject") - name: str + name: str | None = Field(exclude=True, default=None) class TokenDecoder(abc.ABC): @@ -48,6 +48,29 @@ def can_decode(self) -> bool: pass +class LocalTokenDecoder(TokenDecoder): + """Decoder for local CMS JWT tokens.""" + + def __init__(self, secret: str = Context.jwt_secret, algorithm: str = "HS256"): + self.secret = secret + self.algorithm = algorithm + + def decode(self, token: str) -> JWTClaims: + """ + Decode and validate a local CMS token. + """ + jwt_claims = jwt.decode(token, self.secret, algorithms=[self.algorithm]) + return JWTClaims(**jwt_claims) + + @property + def name(self) -> str: + return "local" + + @property + def can_decode(self) -> bool: + return "local" in Context.auth_modes + + class OAuthSessionTokenDecoder(TokenDecoder): """Decoder for OAuth Session JWT tokens.""" @@ -92,7 +115,7 @@ def name(self) -> str: @property def can_decode(self) -> bool: - return True + return "oauth-session" in Context.auth_modes class TokenDecoderChain: @@ -132,4 +155,25 @@ def decode(self, token: str) -> JWTClaims: raise ValueError("Inavlid token") -token_decoder = TokenDecoderChain(decoders=[OAuthSessionTokenDecoder()]) +token_decoder = TokenDecoderChain( + decoders=[OAuthSessionTokenDecoder(), LocalTokenDecoder()] +) + + +def generate_access_token( + *, + user_id: str, + issue_time: datetime.datetime, +) -> str: + """Generate a JWT access token for the given user ID with configured expiry.""" + + expire_time = issue_time + datetime.timedelta( + seconds=Context.jwt_token_expiry_duration + ) + payload = { + "iss": Context.jwt_token_issuer, # issuer + "exp": expire_time.timestamp(), # expiration time + "iat": issue_time.timestamp(), # issued at + "subject": user_id, + } + return jwt.encode(payload, key=Context.jwt_secret, algorithm="HS256") diff --git a/backend/src/cms_backend/db/models.py b/backend/src/cms_backend/db/models.py index 4f80fa8..1e27530 100644 --- a/backend/src/cms_backend/db/models.py +++ b/backend/src/cms_backend/db/models.py @@ -281,7 +281,29 @@ def full_str(self) -> str: class User(Base): __tablename__ = "user" - idp_sub: Mapped[UUID] = mapped_column(primary_key=True) + id: Mapped[UUID] = mapped_column( + init=False, primary_key=True, server_default=text("uuid_generate_v4()") + ) + idp_sub: Mapped[UUID | None] username: Mapped[str] = mapped_column(unique=True, index=True) role: Mapped[str] + password_hash: Mapped[str | None] deleted: Mapped[bool] = mapped_column(default=False, server_default=false()) + + refresh_tokens: Mapped[list["Refreshtoken"]] = relationship( + back_populates="user", cascade="all, delete-orphan", init=False + ) + + +class Refreshtoken(Base): + __tablename__ = "refresh_token" + id: Mapped[UUID] = mapped_column( + init=False, primary_key=True, server_default=text("uuid_generate_v4()") + ) + token: Mapped[UUID] = mapped_column(server_default=text("uuid_generate_v4()")) + expire_time: Mapped[datetime] + user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), init=False) + + user: Mapped["User"] = relationship(back_populates="refresh_tokens", init=False) + + __table__args = (Index("user_id", "token", unique=True),) diff --git a/backend/src/cms_backend/db/refresh_token.py b/backend/src/cms_backend/db/refresh_token.py new file mode 100644 index 0000000..f563e9a --- /dev/null +++ b/backend/src/cms_backend/db/refresh_token.py @@ -0,0 +1,57 @@ +import datetime +from uuid import UUID, uuid4 + +from sqlalchemy import select +from sqlalchemy.orm import Session as OrmSession + +from cms_backend.api.context import Context +from cms_backend.db.exceptions import RecordDoesNotExistError +from cms_backend.db.models import Refreshtoken +from cms_backend.db.user import get_user_by_id +from cms_backend.utils.datetime import getnow + + +def get_refresh_token_or_none(session: OrmSession, token: UUID) -> Refreshtoken | None: + """Get a refresh token by token""" + return session.scalars( + select(Refreshtoken).where(Refreshtoken.token == token) + ).one_or_none() + + +def get_refresh_token(session: OrmSession, token: UUID) -> Refreshtoken: + """Get a refresh token by token""" + db_refresh_token = get_refresh_token_or_none(session, token) + if db_refresh_token is None: + raise RecordDoesNotExistError("Refresh token not found") + return db_refresh_token + + +def create_refresh_token(session: OrmSession, user_id: UUID) -> Refreshtoken: + """Create a refresh token for a user""" + refresh_token = Refreshtoken( + token=uuid4(), + expire_time=getnow() + + datetime.timedelta(seconds=Context.refresh_token_expiry_duration), + ) + refresh_token.user = get_user_by_id(session, user_id=user_id) + session.add(refresh_token) + session.flush() + return refresh_token + + +def delete_refresh_token(session: OrmSession, token: UUID) -> None: + """Delete a refresh token by token""" + db_refresh_token = get_refresh_token_or_none(session, token) + if db_refresh_token is None: + raise RecordDoesNotExistError("Refresh token not found") + session.delete(db_refresh_token) + session.flush() + + +def expire_refresh_tokens(session: OrmSession, expire_time: datetime.datetime) -> None: + """Expire all refresh tokens before a given time""" + for db_refresh_token in session.scalars( + select(Refreshtoken).where(Refreshtoken.expire_time < expire_time) + ): + session.delete(db_refresh_token) + session.flush() diff --git a/backend/src/cms_backend/db/user.py b/backend/src/cms_backend/db/user.py index e290015..eb1e703 100644 --- a/backend/src/cms_backend/db/user.py +++ b/backend/src/cms_backend/db/user.py @@ -1,6 +1,6 @@ from uuid import UUID -from sqlalchemy import select +from sqlalchemy import select, update from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session as OrmSession @@ -27,7 +27,9 @@ def get_user_by_username(session: OrmSession, *, username: str) -> User: def get_user_by_id_or_none(session: OrmSession, *, user_id: UUID) -> User | None: """Get a user by id or return None if the user does not exist""" - return session.scalars(select(User).where(User.idp_sub == user_id)).one_or_none() + return session.scalars( + select(User).where((User.idp_sub == user_id) | (User.id == user_id)) + ).one_or_none() def get_user_by_id(session: OrmSession, *, user_id: UUID) -> User: @@ -55,7 +57,6 @@ def create_user_schema(user: User) -> UserSchema: return UserSchema( username=user.username, role=user.role, - idp_sub=user.idp_sub, scope=merge_scopes(ROLES.get(user.role, {}), ROLES[RoleEnum.EDITOR]), ) @@ -65,7 +66,8 @@ def create_user( *, username: str, role: str, - idp_sub: UUID, + idp_sub: UUID | None = None, + password_hash: str | None = None, ) -> User: """Create a new user""" user = User( @@ -73,6 +75,7 @@ def create_user( role=role, deleted=False, idp_sub=idp_sub, + password_hash=password_hash, ) session.add(user) try: @@ -80,3 +83,24 @@ def create_user( except IntegrityError as exc: raise RecordAlreadyExistsError("User already exists") from exc return user + + +def update_user_password( + session: OrmSession, + *, + user_id: UUID, + password_hash: str, +) -> None: + """Update a user's password""" + session.execute( + update(User).where(User.id == user_id).values(password_hash=password_hash) + ) + + +def delete_user( + session: OrmSession, + *, + user_id: UUID, +) -> None: + """Delete a user""" + session.execute(update(User).where(User.id == user_id).values(deleted=True)) diff --git a/backend/src/cms_backend/migrations/versions/f133e17aa945_add_local_user_details.py b/backend/src/cms_backend/migrations/versions/f133e17aa945_add_local_user_details.py new file mode 100644 index 0000000..b03dce9 --- /dev/null +++ b/backend/src/cms_backend/migrations/versions/f133e17aa945_add_local_user_details.py @@ -0,0 +1,65 @@ +"""add local user details + +Revision ID: f133e17aa945 +Revises: a5f67b148119 +Create Date: 2026-02-10 11:35:10.652021 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "f133e17aa945" +down_revision = "a5f67b148119" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "user", + sa.Column( + "id", + sa.Uuid(), + server_default=sa.text("uuid_generate_v4()"), + nullable=False, + ), + ) + op.add_column("user", sa.Column("password_hash", sa.String(), nullable=True)) + op.drop_constraint(op.f("pk_user"), "user", type_="primary") + op.alter_column("user", "idp_sub", existing_type=sa.UUID(), nullable=True) + op.create_primary_key(op.f("pk_user"), "user", ["id"]) + + op.create_table( + "refresh_token", + sa.Column( + "id", + sa.Uuid(), + server_default=sa.text("uuid_generate_v4()"), + nullable=False, + ), + sa.Column( + "token", + sa.Uuid(), + server_default=sa.text("uuid_generate_v4()"), + nullable=False, + ), + sa.Column("expire_time", sa.DateTime(), nullable=False), + sa.Column("user_id", sa.Uuid(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], ["user.id"], name=op.f("fk_refresh_token_user_id_user") + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_refresh_token")), + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column("user", "idp_sub", existing_type=sa.UUID(), nullable=False) + op.drop_column("user", "password_hash") + # op.drop_column("user", "id") uncommented as this should be the pk regardless + op.drop_table("refresh_token") + # ### end Alembic commands ### diff --git a/backend/src/cms_backend/roles.py b/backend/src/cms_backend/roles.py index d681f0d..631a5cd 100644 --- a/backend/src/cms_backend/roles.py +++ b/backend/src/cms_backend/roles.py @@ -30,6 +30,7 @@ class RoleEnum(StrEnum): "book": ResourcePermissions.get_all(), "title": ResourcePermissions.get_all(), "zimfarm_notification": ResourcePermissions.get_all(), + "user": ResourcePermissions.get_all(), }, RoleEnum.ZIMFARM: { "zimfarm_notification": ResourcePermissions.get(read=True, create=True), diff --git a/backend/src/cms_backend/schemas/orms.py b/backend/src/cms_backend/schemas/orms.py index 400838d..3e12433 100644 --- a/backend/src/cms_backend/schemas/orms.py +++ b/backend/src/cms_backend/schemas/orms.py @@ -124,5 +124,4 @@ class UserSchema(BaseModel): username: str role: str - idp_sub: UUID scope: dict[str, dict[str, bool]] diff --git a/backend/tests/api/routes/conftest.py b/backend/tests/api/routes/conftest.py index 0b65a2a..c9737b6 100644 --- a/backend/tests/api/routes/conftest.py +++ b/backend/tests/api/routes/conftest.py @@ -1,4 +1,4 @@ -from collections.abc import Callable, Generator +from collections.abc import Generator import pytest from fastapi.testclient import TestClient @@ -6,13 +6,10 @@ from cms_backend.api.main import app from cms_backend.db import gen_dbsession, gen_manual_dbsession -from cms_backend.db.models import User @pytest.fixture -def client( - dbsession: OrmSession, user: User, mock_token_for_user: Callable[[User], None] -) -> TestClient: +def client(dbsession: OrmSession) -> TestClient: def test_dbsession() -> Generator[OrmSession]: yield dbsession @@ -20,10 +17,4 @@ def test_dbsession() -> Generator[OrmSession]: app.dependency_overrides[gen_dbsession] = test_dbsession app.dependency_overrides[gen_manual_dbsession] = test_dbsession - # Set up default authentication for the default user - mock_token_for_user(user) - - client = TestClient(app=app) - client.headers["Authorization"] = "Bearer test-token" - - return client + return TestClient(app=app) diff --git a/backend/tests/api/routes/test_titles.py b/backend/tests/api/routes/test_titles.py index 7ef59fa..f55371a 100644 --- a/backend/tests/api/routes/test_titles.py +++ b/backend/tests/api/routes/test_titles.py @@ -6,8 +6,10 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session as OrmSession +from cms_backend.api.token import generate_access_token from cms_backend.db.models import Book, Collection, Title, User from cms_backend.roles import RoleEnum +from cms_backend.utils.datetime import getnow def test_get_titles_empty(client: TestClient): @@ -58,7 +60,6 @@ def test_get_titles( def test_create_title_required_permissions( client: TestClient, create_user: Callable[..., User], - mock_token_for_user: Callable[[User], None], permission: RoleEnum, expected_status_code: HTTPStatus, ): @@ -68,22 +69,30 @@ def test_create_title_required_permissions( } user = create_user(permission=permission) - mock_token_for_user(user) - - response = client.post("/v1/titles", json=title_data) + access_token = generate_access_token(user_id=str(user.id), issue_time=getnow()) + response = client.post( + "/v1/titles", + json=title_data, + headers={"Authorization": f"Bearer {access_token}"}, + ) assert response.status_code == expected_status_code def test_create_title_required_fields_only( client: TestClient, dbsession: OrmSession, + access_token: str, ): """Test creating a title with only required fields""" title_data = { "name": "wikipedia_en_test", } - response = client.post("/v1/titles", json=title_data) + response = client.post( + "/v1/titles", + json=title_data, + headers={"Authorization": f"Bearer {access_token}"}, + ) assert response.status_code == HTTPStatus.OK data = response.json() @@ -101,6 +110,7 @@ def test_create_title_all_fields( client: TestClient, dbsession: OrmSession, create_collection: Callable[..., Collection], + access_token: str, ): """Test creating a title with all fields""" collection = create_collection(name="wikipedia") @@ -115,7 +125,11 @@ def test_create_title_all_fields( ], } - response = client.post("/v1/titles", json=title_data) + response = client.post( + "/v1/titles", + json=title_data, + headers={"Authorization": f"Bearer {access_token}"}, + ) assert response.status_code == HTTPStatus.OK data = response.json() @@ -135,6 +149,7 @@ def test_create_title_all_fields( def test_create_title_with_duplicate_collection_name( client: TestClient, + access_token: str, ): """Test creating a title with the same collection repeated.""" title_data = { @@ -149,12 +164,17 @@ def test_create_title_with_duplicate_collection_name( ], } - response = client.post("/v1/titles", json=title_data) + response = client.post( + "/v1/titles", + json=title_data, + headers={"Authorization": f"Bearer {access_token}"}, + ) assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY def test_create_title_duplicate_name( client: TestClient, + access_token: str, ): """Test creating a title with duplicate name returns conflict error""" title_data = { @@ -162,11 +182,19 @@ def test_create_title_duplicate_name( } # Create the first title - response = client.post("/v1/titles", json=title_data) + response = client.post( + "/v1/titles", + json=title_data, + headers={"Authorization": f"Bearer {access_token}"}, + ) assert response.status_code == HTTPStatus.OK # Try to create another title with the same name - response = client.post("/v1/titles", json=title_data) + response = client.post( + "/v1/titles", + json=title_data, + headers={"Authorization": f"Bearer {access_token}"}, + ) assert response.status_code == HTTPStatus.CONFLICT assert "already exists" in response.json()["message"].lower() diff --git a/backend/tests/api/routes/test_user.py b/backend/tests/api/routes/test_user.py new file mode 100644 index 0000000..6040e5e --- /dev/null +++ b/backend/tests/api/routes/test_user.py @@ -0,0 +1,75 @@ +from collections.abc import Callable +from http import HTTPStatus + +import pytest +from fastapi.testclient import TestClient + +from cms_backend.api.token import generate_access_token +from cms_backend.db.models import User +from cms_backend.utils.datetime import getnow + + +def test_create_user(client: TestClient, user: User): + url = "/v1/users/" + access_token = generate_access_token( + issue_time=getnow(), + user_id=str(user.id), + ) + response = client.post( + url, + headers={"Authorization": f"Bearer {access_token}"}, + json={ + "username": "test", + "password": "test", + "role": "viewer", + }, + ) + assert response.status_code == HTTPStatus.OK + + +def test_create_user_duplicate(client: TestClient, user: User): + url = "/v1/users/" + access_token = generate_access_token( + issue_time=getnow(), + user_id=str(user.id), + ) + response = client.post( + url, + headers={"Authorization": f"Bearer {access_token}"}, + json={ + "username": user.username, + "password": "test", + "role": "viewer", + }, + ) + assert response.status_code == HTTPStatus.CONFLICT + + +@pytest.mark.parametrize( + "current,new,expected", + [ + ("invalid", "test2", HTTPStatus.BAD_REQUEST), + (None, "test2", HTTPStatus.BAD_REQUEST), + ("testpassword", "test2", HTTPStatus.NO_CONTENT), + ], +) +def test_update_user_password_invalid( + client: TestClient, + create_user: Callable[..., User], + current: str, + new: str, + expected: HTTPStatus, +): + """Test updating a user's password with an invalid current password""" + user = create_user(password="testpassword") + + access_token = generate_access_token( + issue_time=getnow(), + user_id=str(user.id), + ) + response = client.patch( + f"/v1/users/{user.username}/password", + headers={"Authorization": f"Bearer {access_token}"}, + json={"current": current, "new": new}, + ) + assert response.status_code == expected diff --git a/backend/tests/api/routes/test_zimfarm_notification.py b/backend/tests/api/routes/test_zimfarm_notification.py index 5954926..66bd311 100644 --- a/backend/tests/api/routes/test_zimfarm_notification.py +++ b/backend/tests/api/routes/test_zimfarm_notification.py @@ -46,6 +46,7 @@ ) def test_create_zimfarm_notification( client: TestClient, + access_token: str, payload: dict[str, Any], expected_status_code: HTTPStatus, ): @@ -54,6 +55,7 @@ def test_create_zimfarm_notification( response = client.post( "/v1/zimfarm-notifications", json=payload, + headers={"Authorization": f"Bearer {access_token}"}, ) assert response.status_code == expected_status_code if expected_status_code == HTTPStatus.ACCEPTED: @@ -81,10 +83,14 @@ def test_create_zimfarm_notification( def test_create_zimfarm_notification_is_idempotent( client: TestClient, zimfarm_notification: ZimfarmNotification, + access_token: str, ): """Test create zimfarm_notification endpoint""" - response = client.get(f"/v1/zimfarm-notifications/{zimfarm_notification.id}") + response = client.get( + f"/v1/zimfarm-notifications/{zimfarm_notification.id}", + headers={"Authorization": f"Bearer {access_token}"}, + ) assert response.status_code == HTTPStatus.OK # try to recreate same Zimfarm notification with different data (to check this @@ -98,6 +104,7 @@ def test_create_zimfarm_notification_is_idempotent( response = client.post( "/v1/zimfarm-notifications", json=payload, + headers={"Authorization": f"Bearer {access_token}"}, ) assert response.status_code == HTTPStatus.ACCEPTED diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 0e559b0..be38f79 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -8,8 +8,9 @@ import pytest from faker import Faker from sqlalchemy.orm import Session as OrmSession +from werkzeug.security import generate_password_hash -from cms_backend.api.token import JWTClaims +from cms_backend.api.token import generate_access_token from cms_backend.db import Session from cms_backend.db.models import ( Base, @@ -325,11 +326,15 @@ def _create_user( *, username: str | None = None, permission: RoleEnum = RoleEnum.EDITOR, + password: str | None = None, ): user = User( username=username or faker.first_name(), role=permission, idp_sub=uuid4(), + password_hash=( + None if password is None else generate_password_hash(password) + ), ) dbsession.add(user) @@ -346,19 +351,8 @@ def user(create_user: Callable[..., User]): @pytest.fixture -def mock_token_for_user(monkeypatch: pytest.MonkeyPatch) -> Callable[[User], None]: - def _mock_for_user(user: User) -> None: - def mock_decode(_: str) -> JWTClaims: - return JWTClaims( - iss="https://test.kiwix.org", - subject=user.idp_sub, - name=user.username, - iat=getnow(), - exp=getnow(), - ) - - monkeypatch.setattr( - "cms_backend.api.routes.dependencies.token_decoder.decode", mock_decode - ) - - return _mock_for_user +def access_token(user: User) -> str: + return generate_access_token( + issue_time=getnow(), + user_id=str(user.id), + ) diff --git a/backend/tests/db/test_refresh_token.py b/backend/tests/db/test_refresh_token.py new file mode 100644 index 0000000..03b17cf --- /dev/null +++ b/backend/tests/db/test_refresh_token.py @@ -0,0 +1,72 @@ +import datetime +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session as OrmSession + +from cms_backend.db.exceptions import RecordDoesNotExistError +from cms_backend.db.models import Refreshtoken, User +from cms_backend.db.refresh_token import ( + create_refresh_token, + delete_refresh_token, + expire_refresh_tokens, + get_refresh_token, + get_refresh_token_or_none, +) +from cms_backend.utils.datetime import getnow + + +@pytest.fixture +def refresh_token(dbsession: OrmSession, user: User) -> Refreshtoken: + """Create a refresh token for a user""" + token = Refreshtoken( + token=uuid4(), + expire_time=getnow() + datetime.timedelta(seconds=1_000), + ) + token.user = user + dbsession.add(token) + dbsession.flush() + return token + + +def test_get_refresh_token_or_none(dbsession: OrmSession): + """Test db returns None if the refresh token does not exist""" + refresh_token = get_refresh_token_or_none(dbsession, uuid4()) + assert refresh_token is None + + +def test_create_refresh_token(dbsession: OrmSession, user: User): + """Test that create_refresh_token creates a refresh token""" + refresh_token = create_refresh_token(dbsession, user.id) + assert refresh_token is not None + assert refresh_token.token is not None + assert refresh_token.user_id == user.id + assert refresh_token.expire_time is not None + + +def test_get_refresh_token(dbsession: OrmSession, refresh_token: Refreshtoken): + """Test that get_refresh_token returns the refresh token""" + db_refresh_token = get_refresh_token(dbsession, refresh_token.token) + assert db_refresh_token is not None + assert db_refresh_token.token == refresh_token.token + + +def test_delete_refresh_token(dbsession: OrmSession, refresh_token: Refreshtoken): + """Test that delete_refresh_token deletes the refresh token""" + delete_refresh_token(dbsession, refresh_token.token) + assert get_refresh_token_or_none(dbsession, refresh_token.token) is None + + +def test_expire_refresh_tokens(dbsession: OrmSession, refresh_token: Refreshtoken): + """Test that expire_refresh_tokens expires the refresh tokens""" + expire_refresh_tokens( + dbsession, + getnow() + datetime.timedelta(seconds=2_000), + ) + assert get_refresh_token_or_none(dbsession, refresh_token.token) is None + + +def test_delete_refresh_token_not_found(dbsession: OrmSession): + """Test aises an exception if the refresh token does not exist""" + with pytest.raises(RecordDoesNotExistError): + delete_refresh_token(dbsession, uuid4()) diff --git a/backend/tests/db/test_user.py b/backend/tests/db/test_user.py index 576bf1f..eb6ab85 100644 --- a/backend/tests/db/test_user.py +++ b/backend/tests/db/test_user.py @@ -8,6 +8,7 @@ ) from cms_backend.db.models import User from cms_backend.db.user import ( + delete_user, get_user_by_id, get_user_by_id_or_none, get_user_by_username, @@ -65,9 +66,9 @@ def test_get_user_by_id_not_found(dbsession: OrmSession): def test_get_user_by_id(dbsession: OrmSession, user: User): """Test that get_user_by_id returns the user if the user exists""" - db_user = get_user_by_id(dbsession, user_id=user.idp_sub) + db_user = get_user_by_id(dbsession, user_id=user.id) assert db_user is not None - assert db_user.idp_sub == user.idp_sub + assert db_user.id == user.id def test_get_user_by_username_or_none(dbsession: OrmSession): @@ -80,3 +81,10 @@ def test_get_user_by_username_not_found(dbsession: OrmSession): """Test that get_user_by_username raises an exception if the user does not exist""" with pytest.raises(RecordDoesNotExistError): get_user_by_username(dbsession, username="doesnotexist") + + +def test_delete_user(dbsession: OrmSession, user: User): + """Test that delete_user marks user as deleted""" + delete_user(dbsession, user_id=user.id) + dbsession.refresh(user) + assert user.deleted diff --git a/dev/docker-compose.yml b/dev/docker-compose.yml index 966a9f0..c4a2bd3 100644 --- a/dev/docker-compose.yml +++ b/dev/docker-compose.yml @@ -41,6 +41,8 @@ services: OAUTH_SESSION_AUDIENCE_ID: 309693e7-ad5e-4379-bf93-ba89314230fd OAUTH_SESSION_LOGIN_REQUIRE_2FA: true CREATE_NEW_OAUTH_ACCOUNT: true + AUTH_MODES: local + JWT_SECRET: DH8kSxcflUVfNRdkEiJJCn2dOOKI3qfw command: - uvicorn - cms_backend.api.main:app @@ -103,6 +105,8 @@ services: environment: DATABASE_URL: postgresql+psycopg://cms:cmspass@postgresdb:5432/cmstest ALEMBIC_UPGRADE_HEAD_ON_START: false + AUTH_MODES: local + JWT_SECRET: DH8kSxcflUVfNRdkEiJJCn2dOOKI3qfw depends_on: - postgresdb frontend: diff --git a/dev/frontend-dev/config.json b/dev/frontend-dev/config.json index f26b264..c031138 100644 --- a/dev/frontend-dev/config.json +++ b/dev/frontend-dev/config.json @@ -1,4 +1,5 @@ { "CMS_API": "http://localhost:37601/v1", - "OAUTH_BASE_URL": "https://login-staging.kiwix.org" + "OAUTH_BASE_URL": "https://login-staging.kiwix.org", + "LOGIN_MODES": ["local", "oauth"] } diff --git a/frontend/public/config.json b/frontend/public/config.json index 6411ad5..2d5e951 100644 --- a/frontend/public/config.json +++ b/frontend/public/config.json @@ -4,5 +4,6 @@ "MATOMO_HOST": "https://stats.kiwix.org", "MATOMO_SITE_ID": 12, "MATOMO_TRACKER_FILE_NAME": "matomo", - "OAUTH_BASE_URL": "https://tender-wiles-i8ance55ra.projects.oryapis.com" + "OAUTH_BASE_URL": "https://login.kiwix.org", + "LOGIN_MODES": ["local", "oauth"] } diff --git a/frontend/src/config.ts b/frontend/src/config.ts index 2d3d4ec..b2ab241 100644 --- a/frontend/src/config.ts +++ b/frontend/src/config.ts @@ -10,6 +10,7 @@ export interface Config { MATOMO_SITE_ID: number MATOMO_TRACKER_FILE_NAME: string OAUTH_BASE_URL: string + LOGIN_MODES: Array } export const ConfigService = { diff --git a/frontend/src/constants.ts b/frontend/src/constants.ts index 16a75b3..3340943 100644 --- a/frontend/src/constants.ts +++ b/frontend/src/constants.ts @@ -4,6 +4,7 @@ import type { InjectionKey } from 'vue' export default { config: Symbol() as InjectionKey, COOKIE_LIFETIME_EXPIRY: '10y', // 10 years + TOKEN_STORAGE_KEY: 'cms-auth', // Notification constants NOTIFICATION_DEFAULT_DURATION: 5000, // 5 seconds NOTIFICATION_ERROR_DURATION: 8000, // 8 seconds for errors diff --git a/frontend/src/services/auth/LocalAuthProvider.ts b/frontend/src/services/auth/LocalAuthProvider.ts new file mode 100644 index 0000000..7a7d247 --- /dev/null +++ b/frontend/src/services/auth/LocalAuthProvider.ts @@ -0,0 +1,111 @@ +import type { StoredToken } from '@/types/auth' +import constants from '@/constants' +import { AuthProvider } from '@/services/auth/base' +import httpRequest from '@/utils/httpRequest' + +/** + * Local authentication provider for username/password authentication + * Uses the CMS API's /auth endpoints + */ +export class LocalAuthProvider extends AuthProvider { + private cmsApiBaseUrl: string + + constructor(cmsApiBaseUrl: string) { + super() + this.cmsApiBaseUrl = cmsApiBaseUrl + } + + /** + * Initiates local login - not applicable for local auth + * Local auth uses direct username/password authentication via authenticate method + */ + async initiateLogin(username?: string, password?: string): Promise { + const service = httpRequest({ + baseURL: `${this.cmsApiBaseUrl}/auth`, + }) + const response = await service.post< + { username: string; password: string }, + { access_token: string; refresh_token: string; expires_time: string } + >('/authorize', { + username, + password, + }) + + const newToken: StoredToken = { + access_token: response.access_token, + refresh_token: response.refresh_token, + token_type: 'local', + expires_time: response.expires_time, + } + this.saveToken(newToken) + } + + saveToken(token: StoredToken): null { + localStorage.setItem(constants.TOKEN_STORAGE_KEY, JSON.stringify(token)) + return null + } + + removeToken(): void { + localStorage.removeItem(constants.TOKEN_STORAGE_KEY) + } + + async loadToken(): Promise { + const storedValue = localStorage.getItem(constants.TOKEN_STORAGE_KEY) + if (!storedValue) return null + let storedToken: StoredToken + try { + storedToken = JSON.parse(storedValue) + + // Validate token structure + if (!storedToken.access_token || !storedToken.refresh_token || !storedToken.token_type) { + throw new Error('Invalid token structure in localStorage') + } + } catch (error) { + console.error('Error parsing localStorage value', error) + // Incorrect token payload + this.removeToken() + return null + } + return storedToken + } + + /** + * Logout from local auth + * For local auth, we just clear client-side state (no server-side revocation) + */ + async logout(): Promise { + this.removeToken() + } + + /** + * Refresh access token using refresh token for local auth + */ + async refreshAuth(refreshToken: string): Promise { + const service = httpRequest({ + baseURL: `${this.cmsApiBaseUrl}/auth`, + }) + const response = await service.post< + { refresh_token: string }, + { access_token: string; refresh_token: string; expires_time: string } + >('/refresh', { + refresh_token: refreshToken, + }) + + const newToken: StoredToken = { + access_token: response.access_token, + refresh_token: response.refresh_token, + token_type: 'local', + expires_time: response.expires_time, + } + + this.saveToken(newToken) + return newToken + } + + /** + * Callback handling not applicable for local auth + */ + async onCallback(): Promise { + throw new Error('onCallback not applicable for local username/password authentication') + } +} diff --git a/frontend/src/stores/auth.ts b/frontend/src/stores/auth.ts index f6038d5..e3a1ce3 100644 --- a/frontend/src/stores/auth.ts +++ b/frontend/src/stores/auth.ts @@ -5,9 +5,11 @@ import httpRequest from '@/utils/httpRequest' import type { ErrorResponse, OAuth2ErrorResponse } from '@/types/errors' import { defineStore } from 'pinia' import { inject, ref, computed } from 'vue' -import type { StoredToken } from '@/types/auth' +import type { StoredToken, AuthProviderType } from '@/types/auth' import { getOAuthConfig } from '@/services/auth/base' import { OAuthSessionProvider } from '@/services/auth/OAuthSessionProvider' +import { LocalAuthProvider } from '@/services/auth/LocalAuthProvider' +import type { AuthProvider } from '@/services/auth/base' import type { User } from '@/types/user' export const useAuthStore = defineStore('auth', () => { @@ -22,6 +24,21 @@ export const useAuthStore = defineStore('auth', () => { } const oauthProvider = new OAuthSessionProvider(getOAuthConfig(config)) + const localauthProvider = new LocalAuthProvider(config.CMS_API) + + const getAuthProvider = (providerType: AuthProviderType): AuthProvider => { + switch (providerType) { + case 'oauth': + if (!oauthProvider) { + throw new Error('No oauth provider configured.') + } + return oauthProvider + case 'local': + return localauthProvider + default: + throw new Error(`Unknown auth provider type: ${providerType}`) + } + } // Track refresh state to prevent duplicate requests const isRefreshFailed = ref(false) @@ -88,13 +105,18 @@ export const useAuthStore = defineStore('auth', () => { return false } - const authenticate = async () => { + const authenticate = async ( + providerType: AuthProviderType, + username?: string, + password?: string, + ) => { try { - await oauthProvider.initiateLogin() + const provider = getAuthProvider(providerType) + await provider.initiateLogin(username, password) // Oauth providers typically redirect to a new url as part of the // login process. If we are still here, it means this is from the local // provider which has stored the token - const newToken = await oauthProvider.loadToken() + const newToken = await provider.loadToken() if (!newToken) { throw new Error('Invalid authentication token') } @@ -102,7 +124,7 @@ export const useAuthStore = defineStore('auth', () => { await fetchUserInfo(newToken.access_token) errors.value = [] - oauthProvider.saveToken(newToken) + provider.saveToken(newToken) isRefreshFailed.value = false @@ -155,8 +177,15 @@ export const useAuthStore = defineStore('auth', () => { } let storedToken: StoredToken | null = null + // Try to load from kiwx/local providers as we don't know which try { - storedToken = await oauthProvider.loadToken() + if (oauthProvider) { + storedToken = await oauthProvider.loadToken() + } + + if (!storedToken) { + storedToken = await localauthProvider.loadToken() + } } catch (error: unknown) { console.error('Failed to load token:', error) await logout() @@ -196,8 +225,9 @@ export const useAuthStore = defineStore('auth', () => { const renewToken = async (storedToken: StoredToken): Promise => { // If refresh has already failed permanently, don't retry + const provider = getAuthProvider(storedToken.token_type) if (isRefreshFailed.value) { - oauthProvider.removeToken() + provider.removeToken() return null } @@ -212,7 +242,7 @@ export const useAuthStore = defineStore('auth', () => { } // Create and store the refresh promise to prevent duplicate requests - refreshPromise.value = oauthProvider.refreshAuth() + refreshPromise.value = provider.refreshAuth(storedToken.refresh_token) try { const newToken = await refreshPromise.value @@ -228,7 +258,7 @@ export const useAuthStore = defineStore('auth', () => { // Check if this is a permanent failure if (isPermanentRefreshFailure(error)) { isRefreshFailed.value = true - oauthProvider.removeToken() + provider.removeToken() } token.value = null @@ -245,7 +275,8 @@ export const useAuthStore = defineStore('auth', () => { // If we have a Kiwix token, revoke it if (token.value?.token_type) { try { - await oauthProvider.logout() + const provider = getAuthProvider(token.value?.token_type) + await provider.logout() } catch (error) { console.error('Error revoking token:', error) } @@ -259,16 +290,17 @@ export const useAuthStore = defineStore('auth', () => { refreshPromise.value = null } - const handleCallBack = async () => { + const handleCallBack = async (providerType: AuthProviderType, callbackUrl: string) => { try { - const newToken = await oauthProvider.onCallback() + const provider = getAuthProvider(providerType) + const newToken = await provider.onCallback(callbackUrl) token.value = newToken // Fetch user info from backend using the Kiwix token await fetchUserInfo(newToken.access_token) errors.value = [] - oauthProvider.saveToken(newToken) + provider.saveToken(newToken) // Reset refresh failure state on successful login isRefreshFailed.value = false @@ -281,6 +313,7 @@ export const useAuthStore = defineStore('auth', () => { return false } } + return { // State errors, diff --git a/frontend/src/types/auth.ts b/frontend/src/types/auth.ts index c55ce45..21dffd0 100644 --- a/frontend/src/types/auth.ts +++ b/frontend/src/types/auth.ts @@ -1,4 +1,4 @@ -export type AuthProviderType = 'oauth' +export type AuthProviderType = 'local' | 'oauth' export interface StoredToken { access_token: string diff --git a/frontend/src/views/OAuthCallbackView.vue b/frontend/src/views/OAuthCallbackView.vue index 5ed968d..ca94b7d 100644 --- a/frontend/src/views/OAuthCallbackView.vue +++ b/frontend/src/views/OAuthCallbackView.vue @@ -48,7 +48,7 @@ const error = ref(null) onMounted(async () => { try { - const success = await authStore.handleCallBack() + const success = await authStore.handleCallBack('oauth', window.location.href) if (!success) { throw new Error(authStore.errors.join(', ') || 'Authentication failed') } diff --git a/frontend/src/views/SignInView.vue b/frontend/src/views/SignInView.vue index c21b3cc..ff72812 100644 --- a/frontend/src/views/SignInView.vue +++ b/frontend/src/views/SignInView.vue @@ -19,7 +19,81 @@

Please sign in

+ + + + + {{ error }} + + + + + + Signing you in... + + + + + + + + + + Sign in with Username + + + + + OR + +