diff --git a/README.md b/README.md index 5b489a8..412aaf6 100644 --- a/README.md +++ b/README.md @@ -15,13 +15,14 @@ The goal of this project is to provide a Python-based starter API, which comes p ## Table of Contents 1. [Running the Project Locally](#running-the-project-locally) -2. [Running with Docker](#running-with-docker) -3. [Running Unit Tests](#running-unit-tests) -4. [Running Code Quality Checks](#running-code-quality-checks) -5. [Running Code Formatting](#running-code-formatting) -6. [Publishing Updated Docs](#publishing-updated-docs) -7. [Contributing](#contributing) -8. [Next Steps](#next-steps) +2. [Initializing PostgreSQL Database](#initializing-postgresql-database) +3. [Running with Docker](#running-with-docker) +4. [Running Unit Tests](#running-unit-tests) +5. [Running Code Quality Checks](#running-code-quality-checks) +6. [Running Code Formatting](#running-code-formatting) +7. [Publishing Updated Docs](#publishing-updated-docs) +8. [Contributing](#contributing) +9. [Next Steps](#next-steps) ## Running the Project Locally @@ -49,7 +50,8 @@ pip install -e ".[dev]" ``` API_PREFIX=[SOME_ROUTE] # Ex: '/api' DATABASE_URL=[SOME_URL] # Ex: 'postgresql://username:password@localhost:5432/database_name' -OIDC_CONFIG_URL=[SOME_URL] # Ex: 'https://token.actions.githubusercontent.com/.well-known/openid-configuration' +OIDC_CONFIG_URL=[SOME_URL] # Ex: 'https://keycloak.auth.metrostar.cloud/auth/realms/dev/.well-known/openid-configuration' +LOG_LEVEL=[LOG_LEVEL] # Ex: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' (Default: 'INFO') ``` 5. To start the app, run the following: @@ -60,6 +62,47 @@ uvicorn app.main:app --reload --host=0.0.0.0 --port=5000 6. Access the swagger docs by navigating to: `http://0.0.0.0:5000/docs` +## Initializing PostgreSQL Database + +If you're using PostgreSQL instead of SQLite, you can use the provided initialization script to set up your database: + +1. Ensure your `.env` file contains a PostgreSQL DATABASE_URL: + +``` +DATABASE_URL=postgresql://username:password@localhost:5432/database_name +``` + +2. Run the initialization script using either method: + +**Using the shell script:** + +```sh +./scripts/init_db.sh +``` + +**Or using Python directly:** + +```sh +python scripts/init_postgres.py +``` + +3. To seed initial data along with the schema (optional): + +```sh +./scripts/init_db.sh --seed +``` + +**Script Options:** + +- `--seed`: Seed initial data after running migrations +- `--skip-create`: Skip database creation (only run migrations) + +The script will: + +- Create the database if it doesn't exist +- Run all Alembic migrations to set up the schema +- Optionally seed initial data + ## Running with Docker 1. To build the image, run the following: @@ -136,7 +179,6 @@ The following provides a short list of tasks which are potential next steps for - [ ] Add/Update existing endpoints with more applicable entities and/or columns - [ ] Update applicable endpoints to require JWT -- [ ] Add Admin endpoints to support password reset - [ ] Replace default database with external database (Ex. Postgres) - [ ] Deploy to cloud infrastructure - [ ] Automate doc publishing process diff --git a/app/admin/router.py b/app/admin/router.py index fa939b7..0b408ed 100644 --- a/app/admin/router.py +++ b/app/admin/router.py @@ -1,7 +1,6 @@ from typing import Annotated from fastapi import APIRouter, Depends -from fastapi.security import HTTPBearer from starlette import status from app.auth import validate_jwt @@ -15,8 +14,15 @@ @router.get( "/current-user", - dependencies=[Depends(HTTPBearer())], status_code=status.HTTP_200_OK, ) async def get_current_user(current_user: Annotated[dict, Depends(validate_jwt)]): + """Get the current authenticated user information. + + Args: + current_user: Validated JWT payload containing user information. + + Returns: + dict: User information from the JWT token. + """ return {"user": current_user} diff --git a/app/applicants/models.py b/app/applicants/models.py index 284e495..d3e8f96 100644 --- a/app/applicants/models.py +++ b/app/applicants/models.py @@ -5,6 +5,12 @@ class DBApplicant(Base): + """SQLAlchemy model for applicant data. + + Represents an applicant in the database with personal information, + contact details, and address information. + """ + __tablename__ = "applicants" id = Column(Integer, primary_key=True, index=True) diff --git a/app/applicants/router.py b/app/applicants/router.py index bdd38f3..acde63b 100644 --- a/app/applicants/router.py +++ b/app/applicants/router.py @@ -26,11 +26,30 @@ @router.get("/", status_code=status.HTTP_200_OK, response_model=ApplicantListResponse) async def get_applicants(db: db_session, page_number: int = 0, page_size: int = 100): + """Retrieve a paginated list of all applicants. + + Args: + db: Database session. + page_number: Page number for pagination (default: 0). + page_size: Number of items per page (default: 100). + + Returns: + ApplicantListResponse: Paginated list of applicants. + """ return service.get_items(db, page_number, page_size) @router.post("/", status_code=status.HTTP_201_CREATED, response_model=ApplicantResponse) async def create_applicant(applicant: ApplicantCreate, db: db_session): + """Create a new applicant. + + Args: + applicant: Applicant data to create. + db: Database session. + + Returns: + ApplicantResponse: The created applicant. + """ db_applicant = service.create_item(db, applicant) return db_applicant @@ -39,6 +58,15 @@ async def create_applicant(applicant: ApplicantCreate, db: db_session): "/{applicant_id}", status_code=status.HTTP_200_OK, response_model=ApplicantResponse ) async def get_applicant(applicant_id: int, db: db_session): + """Retrieve a single applicant by ID. + + Args: + applicant_id: ID of the applicant to retrieve. + db: Database session. + + Returns: + ApplicantResponse: The requested applicant. + """ return service.get_item(db, applicant_id) @@ -48,10 +76,26 @@ async def get_applicant(applicant_id: int, db: db_session): async def update_applicant( applicant_id: int, applicant: ApplicantUpdate, db: db_session ): + """Update an existing applicant. + + Args: + applicant_id: ID of the applicant to update. + applicant: Updated applicant data. + db: Database session. + + Returns: + ApplicantResponse: The updated applicant. + """ db_applicant = service.update_item(db, applicant_id, applicant) return db_applicant @router.delete("/{applicant_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_applicant(applicant_id: int, db: db_session): + """Delete an applicant. + + Args: + applicant_id: ID of the applicant to delete. + db: Database session. + """ service.delete_item(db, applicant_id) diff --git a/app/applicants/schemas.py b/app/applicants/schemas.py index b100438..6fb43b3 100644 --- a/app/applicants/schemas.py +++ b/app/applicants/schemas.py @@ -5,6 +5,11 @@ # Pydantic Models class ApplicantBase(BaseModel): + """Base Pydantic model for applicant data. + + Contains common fields shared across create, update, and response models. + """ + first_name: str = Field(..., min_length=1, max_length=50) last_name: str = Field(..., min_length=1, max_length=50) middle_name: str | None = Field(None, min_length=1, max_length=50) @@ -22,10 +27,15 @@ class ApplicantBase(BaseModel): class ApplicantCreate(ApplicantBase): - pass + """Pydantic model for creating a new applicant.""" class ApplicantUpdate(BaseModel): + """Pydantic model for updating an existing applicant. + + All fields are optional to support partial updates. + """ + first_name: str | None = None last_name: str | None = None middle_name: str | None = None @@ -43,6 +53,11 @@ class ApplicantUpdate(BaseModel): class ApplicantResponse(ApplicantBase): + """Pydantic model for applicant API responses. + + Includes database-generated fields like id, created_at, and updated_at. + """ + model_config = ConfigDict(from_attributes=True) id: int created_at: datetime @@ -50,6 +65,11 @@ class ApplicantResponse(ApplicantBase): class ApplicantListResponse(BaseModel): + """Pydantic model for paginated list of applicants. + + Contains pagination metadata along with the list of applicants. + """ + items: list[ApplicantResponse] item_count: int = 0 page_count: int = 0 diff --git a/app/applicants/services.py b/app/applicants/services.py index 752bf48..060089a 100644 --- a/app/applicants/services.py +++ b/app/applicants/services.py @@ -1,3 +1,5 @@ +import logging + from fastapi import HTTPException from sqlalchemy.orm import Session @@ -5,10 +7,24 @@ from app.applicants.schemas import ApplicantCreate, ApplicantUpdate from app.utils import get_next_page, get_page_count, get_prev_page +logger = logging.getLogger(__name__) + def get_items(db: Session, page_number: int, page_size: int): + """Retrieve a paginated list of applicants. + + Args: + db: Database session. + page_number: Current page number (0-indexed). + page_size: Number of items per page. + + Returns: + dict: Paginated response containing applicants and pagination metadata. + """ + logger.debug("Fetching applicants - page: %s, size: %s", page_number, page_size) item_count = db.query(DBApplicant).count() items = db.query(DBApplicant).limit(page_size).offset(page_number * page_size).all() + logger.info("Retrieved %s applicants (total: %s)", len(items), item_count) return { "items": items, @@ -20,42 +36,98 @@ def get_items(db: Session, page_number: int, page_size: int): def create_item(db: Session, applicant: ApplicantCreate): + """Create a new applicant in the database. + + Args: + db: Database session. + applicant: Applicant data to create. + + Returns: + DBApplicant: The created applicant record. + """ + logger.debug("Creating new applicant with email: %s", applicant.email) db_applicant = DBApplicant(**applicant.model_dump()) db.add(db_applicant) db.commit() db.refresh(db_applicant) + logger.info("Created applicant with id: %s", db_applicant.id) return db_applicant def get_item(db: Session, applicant_id: int): - return db.query(DBApplicant).where(DBApplicant.id == applicant_id).first() + """Retrieve a single applicant by ID. + + Args: + db: Database session. + applicant_id: ID of the applicant to retrieve. + + Returns: + DBApplicant | None: The applicant record if found, None otherwise. + """ + logger.debug("Fetching applicant with id: %s", applicant_id) + applicant = db.query(DBApplicant).where(DBApplicant.id == applicant_id).first() + if applicant: + logger.info("Retrieved applicant with id: %s", applicant_id) + else: + logger.warning("Applicant not found with id: %s", applicant_id) + return applicant def update_item(db: Session, id: int, applicant: ApplicantUpdate): + """Update an existing applicant. + + Args: + db: Database session. + id: ID of the applicant to update. + applicant: Updated applicant data. + + Returns: + DBApplicant: The updated applicant record. + + Raises: + HTTPException: If applicant is not found (404). + """ + logger.debug("Updating applicant with id: %s", id) db_applicant = db.query(DBApplicant).filter(DBApplicant.id == id).first() if db_applicant is None: + logger.warning("Applicant not found for update with id: %s", id) raise HTTPException(status_code=404, detail="Applicant not found") - # Only update fields that are provided (not None) + # Update only the fields that are explicitly set in the request update_data = applicant.model_dump(exclude_unset=True) for field, value in update_data.items(): - if value is not None: - setattr(db_applicant, field, value) + setattr(db_applicant, field, value) db.add(db_applicant) db.commit() db.refresh(db_applicant) + logger.info("Updated applicant with id: %s", id) return db_applicant def delete_item(db: Session, id: int): + """Delete an applicant from the database. + + Args: + db: Database session. + id: ID of the applicant to delete. + + Returns: + None + + Raises: + HTTPException: If applicant is not found (404). + """ + logger.debug("Deleting applicant with id: %s", id) db_applicant = db.query(DBApplicant).filter(DBApplicant.id == id).first() if db_applicant is None: + logger.warning("Applicant not found for deletion with id: %s", id) raise HTTPException(status_code=404, detail="Applicant not found") db.query(DBApplicant).filter(DBApplicant.id == id).delete() db.commit() + logger.info("Deleted applicant with id: %s", id) return None diff --git a/app/auth.py b/app/auth.py index 55669af..f7ecea7 100644 --- a/app/auth.py +++ b/app/auth.py @@ -1,40 +1,151 @@ +import time + import requests from fastapi import Depends, HTTPException -from jose.backends import RSAKey +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from jose.jwt import decode, get_unverified_header from app.config import settings +security = HTTPBearer() + +# JWKS cache with TTL (30 minutes = 1800 seconds) +_jwks_cache: dict | None = None +_jwks_cache_timestamp: float | None = None +JWKS_CACHE_TTL = 1800 + def get_keycloak_jwks(): + """Retrieve JSON Web Key Set (JWKS) from Keycloak OIDC provider. + + Fetches the well-known OIDC configuration from the configured URL + and retrieves the JWKS containing public keys for JWT validation. + JWKS is cached for 30 minutes to reduce redundant HTTP requests. + + Returns: + list: List of JSON Web Keys from the OIDC provider. + + Raises: + HTTPException: If OIDC_CONFIG_URL is not configured or request fails. + """ + global _jwks_cache, _jwks_cache_timestamp + + # Check if cache is valid + current_time = time.time() + if ( + _jwks_cache is not None + and _jwks_cache_timestamp is not None + and current_time - _jwks_cache_timestamp < JWKS_CACHE_TTL + ): + return _jwks_cache + keycloak_well_known_url = settings.OIDC_CONFIG_URL - response = requests.get(keycloak_well_known_url) - well_known_config = response.json() - jwks_url = well_known_config["jwks_uri"] - jwks_response = requests.get(jwks_url) - jwks = jwks_response.json() - return jwks["keys"] + if not keycloak_well_known_url: + raise HTTPException( + status_code=500, + detail="OIDC_CONFIG_URL is not configured", + ) + + try: + response = requests.get(keycloak_well_known_url, timeout=5) + response.raise_for_status() + well_known_config = response.json() + except requests.exceptions.RequestException as e: + raise HTTPException( + status_code=500, + detail=f"Failed to fetch OIDC configuration: {e!s}", + ) from e + except ValueError as e: + raise HTTPException( + status_code=500, + detail=f"Invalid JSON response from OIDC provider: {e!s}", + ) from e + + try: + jwks_url = well_known_config["jwks_uri"] + jwks_response = requests.get(jwks_url, timeout=5) + jwks_response.raise_for_status() + jwks = jwks_response.json() + jwks_keys = jwks["keys"] + + # Update cache + _jwks_cache = jwks_keys + _jwks_cache_timestamp = current_time + + return jwks_keys + except requests.exceptions.RequestException as e: + raise HTTPException( + status_code=500, + detail=f"Failed to fetch JWKS: {e!s}", + ) from e + except (ValueError, KeyError) as e: + raise HTTPException( + status_code=500, + detail=f"Invalid JWKS response: {e!s}", + ) from e + + +def validate_jwt( + credentials: HTTPAuthorizationCredentials = Depends(security), +) -> dict: + """Validate a JSON Web Token (JWT) using JWKS from the OIDC provider. + + Extracts the key ID from the JWT header, finds the matching RSA key + in the JWKS, and decodes/validates the token. + + Args: + credentials: HTTP Authorization credentials containing the JWT token. + + Returns: + dict: Decoded JWT payload containing user claims. + + Raises: + HTTPException: If RSA key is not found (401) or JWT is invalid (401). + """ + token = credentials.credentials + + try: + header = get_unverified_header(token) + except Exception as e: + raise HTTPException( + status_code=401, + detail=f"Invalid JWT token format: {e!s}", + ) from e -def validate_jwt(token: str = Depends(get_keycloak_jwks)): - header = get_unverified_header(token) jwks = get_keycloak_jwks() # Find the RSA key with the matching kid in the JWKS rsa_key = None for key in jwks: - if key["kid"] == header["kid"]: - rsa_key = RSAKey( - key=key["n"], alg=key["alg"], use=key["use"], kid=key["kid"] - ) + if key.get("kid") == header.get("kid"): + rsa_key = { + "kty": key.get("kty"), + "kid": key.get("kid"), + "use": key.get("use"), + "n": key.get("n"), + "e": key.get("e"), + } break if rsa_key is None: - raise HTTPException(status_code=401, detail="RSA Key not found in JWKS") + # Provide helpful debugging information + available_kids = [key.get("kid") for key in jwks] + token_kid = header.get("kid") + raise HTTPException( + status_code=401, + detail=( + f"RSA Key not found in JWKS. Token kid: {token_kid}, " + f"Available kids: {available_kids}" + ), + ) try: payload = decode(token, rsa_key, algorithms=[header["alg"]]) - except Exception: - raise HTTPException(status_code=401, detail="Invalid JWT token") from Exception + except Exception as e: + raise HTTPException( + status_code=401, + detail=f"Invalid JWT token: {e!s}", + ) from e return payload diff --git a/app/cases/models.py b/app/cases/models.py index 59e9147..12128f7 100644 --- a/app/cases/models.py +++ b/app/cases/models.py @@ -6,6 +6,12 @@ # SQLAlchemy Model class DBCase(Base): + """SQLAlchemy model for case data. + + Represents a case in the database with status tracking, + assignment information, and a relationship to an applicant. + """ + __tablename__ = "cases" id = Column(Integer, primary_key=True, index=True) diff --git a/app/cases/router.py b/app/cases/router.py index 44f9606..162b82f 100644 --- a/app/cases/router.py +++ b/app/cases/router.py @@ -27,11 +27,30 @@ @router.get("/", status_code=status.HTTP_200_OK, response_model=CaseListResponse) async def get_cases(db: db_session, page_number: int = 0, page_size: int = 100): + """Retrieve a paginated list of all cases. + + Args: + db: Database session. + page_number: Page number for pagination (default: 0). + page_size: Number of items per page (default: 100). + + Returns: + CaseListResponse: Paginated list of cases. + """ return service.get_items(db, page_number, page_size) @router.post("/", status_code=status.HTTP_201_CREATED, response_model=CaseResponse) async def create_case(case: CaseCreate, db: db_session): + """Create a new case. + + Args: + case: Case data to create. + db: Database session. + + Returns: + CaseResponse: The created case. + """ db_case = service.create_item(db, case) return db_case @@ -40,15 +59,40 @@ async def create_case(case: CaseCreate, db: db_session): "/{case_id}", status_code=status.HTTP_200_OK, response_model=CaseWithApplicant ) async def get_case(case_id: int, db: db_session): + """Retrieve a single case by ID with applicant details. + + Args: + case_id: ID of the case to retrieve. + db: Database session. + + Returns: + CaseWithApplicant: The requested case with applicant information. + """ return service.get_item(db, case_id) @router.put("/{case_id}", status_code=status.HTTP_200_OK, response_model=CaseResponse) async def update_case(case_id: int, case: CaseUpdate, db: db_session): + """Update an existing case. + + Args: + case_id: ID of the case to update. + case: Updated case data. + db: Database session. + + Returns: + CaseResponse: The updated case. + """ db_case = service.update_item(db, case_id, case) return db_case @router.delete("/{case_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_case(case_id: int, db: db_session): + """Delete a case. + + Args: + case_id: ID of the case to delete. + db: Database session. + """ service.delete_item(db, case_id) diff --git a/app/cases/schemas.py b/app/cases/schemas.py index 9f4b838..60c4af5 100644 --- a/app/cases/schemas.py +++ b/app/cases/schemas.py @@ -11,20 +11,41 @@ # Pydantic Models class CaseBase(BaseModel): + """Base Pydantic model for case data. + + Contains common fields shared across create, update, and response models. + """ + status: CASE_STATUS assigned_to: str | None = Field(None, min_length=1, max_length=255) class CaseCreate(CaseBase): + """Pydantic model for creating a new case. + + Requires an applicant_id to associate the case with an applicant. + """ + applicant_id: int class CaseUpdate(BaseModel): + """Pydantic model for updating an existing case. + + All fields are optional to support partial updates. + """ + status: CASE_STATUS | None = None assigned_to: str | None = None class CaseResponse(CaseBase): + """Pydantic model for case API responses. + + Includes database-generated fields like id, applicant_id, + created_at, and updated_at. + """ + model_config = ConfigDict(from_attributes=True) id: int applicant_id: int @@ -33,10 +54,20 @@ class CaseResponse(CaseBase): class CaseWithApplicant(CaseResponse): + """Pydantic model for case responses with applicant details. + + Extends CaseResponse to include the full applicant information. + """ + applicant: ApplicantResponse | None = None class CaseListResponse(BaseModel): + """Pydantic model for paginated list of cases. + + Contains pagination metadata along with the list of cases. + """ + items: list[CaseWithApplicant] item_count: int page_count: int diff --git a/app/cases/services.py b/app/cases/services.py index 41bb63b..56f34c3 100644 --- a/app/cases/services.py +++ b/app/cases/services.py @@ -1,3 +1,5 @@ +import logging + from fastapi import HTTPException from sqlalchemy.orm import Session, joinedload @@ -5,10 +7,24 @@ from app.cases.schemas import CaseCreate, CaseUpdate from app.utils import get_next_page, get_page_count, get_prev_page +logger = logging.getLogger(__name__) + def get_items(db: Session, page_number: int, page_size: int): + """Retrieve a paginated list of cases. + + Args: + db: Database session. + page_number: Current page number (0-indexed). + page_size: Number of items per page. + + Returns: + dict: Paginated response containing cases and pagination metadata. + """ + logger.debug("Fetching cases - page: %s, size: %s", page_number, page_size) item_count = db.query(DBCase).count() items = db.query(DBCase).limit(page_size).offset(page_number * page_size).all() + logger.info("Retrieved %s cases (total: %s)", len(items), item_count) return { "items": items, @@ -20,15 +36,39 @@ def get_items(db: Session, page_number: int, page_size: int): def create_item(db: Session, case: CaseCreate): + """Create a new case in the database. + + Args: + db: Database session. + case: Case data to create. + + Returns: + DBCase: The created case record. + """ + logger.debug("Creating new case with applicant_id: %s", case.applicant_id) db_case = DBCase(**case.model_dump()) db.add(db_case) db.commit() db.refresh(db_case) + logger.info("Created case with id: %s", db_case.id) return db_case def get_item(db: Session, case_id: int): + """Retrieve a single case by ID with its associated applicant. + + Args: + db: Database session. + case_id: ID of the case to retrieve. + + Returns: + dict: Case data with applicant information. + + Raises: + HTTPException: If case is not found (404). + """ + logger.debug("Fetching case with id: %s", case_id) case = ( db.query(DBCase) .options(joinedload(DBCase.applicant)) @@ -37,8 +77,10 @@ def get_item(db: Session, case_id: int): ) if case is None: + logger.warning("Case not found with id: %s", case_id) raise HTTPException(status_code=404, detail="Case not found") + logger.info("Retrieved case with id: %s", case_id) # Handle case where applicant might be None applicant_data = None if case.applicant: @@ -74,28 +116,59 @@ def get_item(db: Session, case_id: int): def update_item(db: Session, id: int, case: CaseUpdate): + """Update an existing case. + + Args: + db: Database session. + id: ID of the case to update. + case: Updated case data. + + Returns: + DBCase: The updated case record. + + Raises: + HTTPException: If case is not found (404). + """ + logger.debug("Updating case with id: %s", id) db_case = db.query(DBCase).filter(DBCase.id == id).first() if db_case is None: + logger.warning("Case not found for update with id: %s", id) raise HTTPException(status_code=404, detail="Case not found") - if case.status is not None: - db_case.status = case.status - if case.assigned_to is not None: - db_case.assigned_to = case.assigned_to + # Update only the fields that are explicitly set in the request + update_data = case.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(db_case, field, value) db.add(db_case) db.commit() db.refresh(db_case) + logger.info("Updated case with id: %s", id) return db_case def delete_item(db: Session, id: int): + """Delete a case from the database. + + Args: + db: Database session. + id: ID of the case to delete. + + Returns: + None + + Raises: + HTTPException: If case is not found (404). + """ + logger.debug("Deleting case with id: %s", id) db_case = db.query(DBCase).filter(DBCase.id == id).first() if db_case is None: + logger.warning("Case not found for deletion with id: %s", id) raise HTTPException(status_code=404, detail="Case not found") db.query(DBCase).filter(DBCase.id == id).delete() db.commit() + logger.info("Deleted case with id: %s", id) return None diff --git a/app/config.py b/app/config.py index 65cf799..cd7c784 100644 --- a/app/config.py +++ b/app/config.py @@ -2,11 +2,18 @@ class Settings(BaseSettings): + """Application configuration settings. + + Loads configuration from environment variables or .env file. + Provides settings for API prefix, database URL, and OIDC configuration. + """ + model_config = SettingsConfigDict(env_file=".env") API_PREFIX: str = "" DATABASE_URL: str = "sqlite:///./db.sqlite3" OIDC_CONFIG_URL: str | None = None + LOG_LEVEL: str = "INFO" settings = Settings() diff --git a/app/db.py b/app/db.py index 3c877e4..bc792a5 100644 --- a/app/db.py +++ b/app/db.py @@ -3,8 +3,28 @@ from app.config import settings +# SqlAlchemy Setup +SQLALCHEMY_DATABASE_URL = settings.DATABASE_URL +# Use in-memory SQLite database for testing +if "sqlite" in SQLALCHEMY_DATABASE_URL: + engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} + ) +else: # For other databases, use the default connection settings + engine = create_engine( + SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, pool_recycle=300 + ) + def get_db(): + """Provide a database session for dependency injection. + + Creates a database session, yields it for use in the request, + and ensures it is properly closed after the request completes. + + Yields: + Session: SQLAlchemy database session. + """ db = SessionLocal() try: yield db @@ -13,6 +33,5 @@ def get_db(): # SqlAlchemy Setup -engine = create_engine(settings.DATABASE_URL) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() diff --git a/app/health/router.py b/app/health/router.py index 7650511..27a40bc 100644 --- a/app/health/router.py +++ b/app/health/router.py @@ -11,4 +11,9 @@ @router.get("/", status_code=status.HTTP_200_OK) def get_health(): + """Health check endpoint. + + Returns: + dict: Health status indicator. + """ return {"health": "healthy"} diff --git a/app/main.py b/app/main.py index e8e1ef6..d7eefed 100644 --- a/app/main.py +++ b/app/main.py @@ -1,19 +1,41 @@ +import logging + from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware from app.admin.router import router as admin_router from app.applicants.router import router as applicants_router from app.cases.router import router as cases_router from app.db import Base, engine from app.health.router import router as health_router -from app.users.router import router as users_router +from app.utils import setup_logging + +# Configure logging +setup_logging() +logger = logging.getLogger(__name__) # Create the app app = FastAPI() +logger.info("FastAPI application initialized") + +# Set up CORS middleware +# TODO: Restrict origins for production use +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) +logger.info("CORS middleware configured") + # Create database Base.metadata.create_all(bind=engine) +logger.info("Database tables created") + # Add routes app.include_router(cases_router) app.include_router(applicants_router) -app.include_router(users_router) app.include_router(admin_router) app.include_router(health_router) +logger.info("API routes registered") diff --git a/app/users/__init__.py b/app/users/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/app/users/models.py b/app/users/models.py deleted file mode 100644 index f3ed166..0000000 --- a/app/users/models.py +++ /dev/null @@ -1,26 +0,0 @@ -from sqlalchemy import Boolean, Column, DateTime, Integer, String, func - -from app.db import Base - - -# SQLAlchemy Model -class DBUser(Base): - __tablename__ = "users" - - id = Column(Integer, primary_key=True, index=True) - user_id = Column(String(100), nullable=False) - first_name = Column(String(100), nullable=False) - last_name = Column(String(100), nullable=False) - display_name = Column(String(200), nullable=False) - email = Column(String(254), unique=True, index=True) - hashed_password = Column(String) - is_active = Column(Boolean, default=True) - created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - created_by = Column(String(100), nullable=False) - modified = Column( - DateTime(timezone=True), - server_default=func.now(), - onupdate=func.now(), - nullable=False, - ) - modified_by = Column(String(100), nullable=False) diff --git a/app/users/router.py b/app/users/router.py deleted file mode 100644 index d649d3e..0000000 --- a/app/users/router.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Annotated - -from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session -from starlette import status - -import app.users.services as service -from app.config import settings -from app.db import get_db -from app.users.schemas import UserCreate, UserListResponse, UserResponse, UserUpdate - -router = APIRouter( - prefix=f"{settings.API_PREFIX}/users", - tags=["Users"], - responses={404: {"description": "Endpoint not found"}}, -) - -# Database dependency injection session -db_session = Annotated[Session, Depends(get_db)] - - -@router.get( - "/", - status_code=status.HTTP_200_OK, - response_model=UserListResponse, -) -async def get_items(db: db_session, page_number: int = 0, page_size: int = 100): - return service.get_items(db, page_number, page_size) - - -@router.post("/", status_code=status.HTTP_201_CREATED, response_model=UserResponse) -async def create_item(item: UserCreate, db: db_session): - db_item = service.create_item(db, item) - return db_item - - -@router.get("/{id}", status_code=status.HTTP_200_OK, response_model=UserResponse) -async def get_item(id: int, db: db_session): - return service.get_item(db, id) - - -@router.put("/{id}", status_code=status.HTTP_200_OK, response_model=UserResponse) -async def update_item(id: int, item: UserUpdate, db: db_session): - db_item = service.update_item(db, id, item) - return db_item - - -@router.delete("/{id}", status_code=status.HTTP_204_NO_CONTENT) -async def delete_item(id: int, db: db_session): - service.delete_item(db, id) diff --git a/app/users/schemas.py b/app/users/schemas.py deleted file mode 100644 index caf42e8..0000000 --- a/app/users/schemas.py +++ /dev/null @@ -1,44 +0,0 @@ -from datetime import datetime - -from pydantic import BaseModel, ConfigDict, Field - - -# Pydantic Models -class UserBase(BaseModel): - user_id: str = Field(..., min_length=1, max_length=100) - first_name: str = Field(..., min_length=1, max_length=100) - last_name: str = Field(..., min_length=1, max_length=100) - display_name: str = Field(..., min_length=1, max_length=200) - email: str = Field(..., min_length=1, max_length=254) - is_active: bool = True - created_by: str = Field(..., min_length=1, max_length=100) - modified_by: str = Field(..., min_length=1, max_length=100) - - -class UserCreate(UserBase): - hashed_password: str | None = None - - -class UserUpdate(BaseModel): - user_id: str | None = Field(None, min_length=1, max_length=100) - first_name: str | None = Field(None, min_length=1, max_length=100) - last_name: str | None = Field(None, min_length=1, max_length=100) - display_name: str | None = Field(None, min_length=1, max_length=200) - email: str | None = Field(None, min_length=1, max_length=254) - is_active: bool | None = None - modified_by: str | None = Field(None, min_length=1, max_length=100) - - -class UserResponse(UserBase): - model_config = ConfigDict(from_attributes=True) - id: int - created: datetime - modified: datetime - - -class UserListResponse(BaseModel): - items: list[UserResponse] - item_count: int = 0 - page_count: int = 0 - prev_page: int | None = None - next_page: int | None = None diff --git a/app/users/services.py b/app/users/services.py deleted file mode 100644 index 98173cc..0000000 --- a/app/users/services.py +++ /dev/null @@ -1,63 +0,0 @@ -from fastapi import HTTPException -from sqlalchemy.orm import Session - -from app.users.models import DBUser -from app.users.schemas import UserCreate, UserUpdate -from app.utils import get_next_page, get_page_count, get_prev_page - - -def get_items(db: Session, page_number: int, page_size: int): - item_count = db.query(DBUser).count() - items = db.query(DBUser).limit(page_size).offset(page_number * page_size).all() - - return { - "items": items, - "item_count": item_count, - "page_count": get_page_count(item_count, page_size), - "prev_page": get_prev_page(page_number), - "next_page": get_next_page(item_count, page_number, page_size), - } - - -def create_item(db: Session, item: UserCreate): - db_item = DBUser(**item.model_dump(exclude={"hashed_password"})) - if item.hashed_password: - db_item.hashed_password = item.hashed_password - db.add(db_item) - db.commit() - db.refresh(db_item) - - return db_item - - -def get_item(db: Session, item_id: int): - return db.query(DBUser).where(DBUser.id == item_id).first() - - -def update_item(db: Session, id: int, item: UserUpdate): - db_item = db.query(DBUser).filter(DBUser.id == id).first() - if db_item is None: - raise HTTPException(status_code=404, detail="Item not found") - - # Only update fields that are provided (not None) - update_data = item.model_dump(exclude_unset=True) - for field, value in update_data.items(): - if value is not None: - setattr(db_item, field, value) - - db.add(db_item) - db.commit() - db.refresh(db_item) - - return db_item - - -def delete_item(db: Session, id: int): - db_item = db.query(DBUser).filter(DBUser.id == id).first() - if db_item is None: - raise HTTPException(status_code=404, detail="Item not found") - - db.query(DBUser).filter(DBUser.id == id).delete() - db.commit() - - return None diff --git a/app/utils.py b/app/utils.py index 67ff80b..3267ea7 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,12 +1,61 @@ +import logging + +from app.config import settings + + +def setup_logging(): + """Configure application-wide logging. + + Sets up logging with the configured log level and format. + This should be called once at application startup. + """ + logging.basicConfig( + level=settings.LOG_LEVEL, + format="%(levelname)s: %(message)s - %(name)s:%(lineno)d", + handlers=[logging.StreamHandler()], + ) + logger = logging.getLogger(__name__) + logger.info("Logging configured with level: %s", settings.LOG_LEVEL) + + def get_page_count(item_count: int, page_size: int): - return round(item_count / page_size) + """Calculate the total number of pages based on item count and page size. + + Args: + item_count: Total number of items to paginate. + page_size: Number of items per page. + + Returns: + int: The total number of pages. + """ + if item_count == 0: + return 0 + return (item_count + page_size - 1) // page_size def get_prev_page(page_number: int): # noqa: E501 + """Get the previous page number if it exists. + + Args: + page_number: Current page number (0-indexed). + + Returns: + int | None: The previous page number, or None if on the first page. + """ return page_number - 1 if page_number > 0 else None def get_next_page(item_count: int, page_number: int, page_size: int): + """Get the next page number if it exists. + + Args: + item_count: Total number of items to paginate. + page_number: Current page number (0-indexed). + page_size: Number of items per page. + + Returns: + int | None: The next page number, or None if on the last page. + """ page_count = get_page_count(item_count, page_size) if page_number >= page_count: return None diff --git a/pyproject.toml b/pyproject.toml index 3921d6d..437da45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [ "alembic", "fastapi", "httpx", + "psycopg2-binary", "pydantic-settings", "python-dotenv", "python-jose[cryptography]", @@ -37,6 +38,7 @@ exclude = [ ".venv", "venv", "docs", + "scripts", "__pycache", "**/migrations/*", ] diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..1907f80 --- /dev/null +++ b/scripts/__init__.py @@ -0,0 +1 @@ +"""Scripts for database initialization and maintenance.""" diff --git a/scripts/init_db.sh b/scripts/init_db.sh new file mode 100755 index 0000000..0cbea42 --- /dev/null +++ b/scripts/init_db.sh @@ -0,0 +1,18 @@ +#!/bin/bash +# Initialize PostgreSQL database for Comet API +# Usage: ./scripts/init_db.sh [--seed] + +set -e + +# Change to the project root directory +cd "$(dirname "$0")/.." + +echo "Initializing PostgreSQL database..." + +# Activate virtual environment if it exists +if [ -d "venv" ]; then + source venv/bin/activate +fi + +# Run the Python initialization script +python scripts/init_postgres.py "$@" diff --git a/scripts/init_postgres.py b/scripts/init_postgres.py new file mode 100755 index 0000000..a27a1f4 --- /dev/null +++ b/scripts/init_postgres.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +"""Initialize PostgreSQL database for Comet API. + +This script: +1. Creates the database if it doesn't exist +2. Runs Alembic migrations to set up the schema +3. Optionally seeds initial data + +Usage: + python scripts/init_postgres.py [--seed] +""" + +import argparse +import sys +from pathlib import Path + +import psycopg2 +from psycopg2 import sql +from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT +from sqlalchemy.engine import make_url + +# Add the project root to the path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from app.config import settings + + +def parse_database_url(url: str) -> dict: + """Parse PostgreSQL database URL using SQLAlchemy's make_url. + + Args: + url: Database URL in format postgresql://user:password@host:port/dbname + + Returns: + Dictionary with connection parameters. + """ + parsed = make_url(url) + + return { + "user": parsed.username, + "password": parsed.password, + "host": parsed.host, + "port": parsed.port or 5432, + "dbname": parsed.database, + } + + +def create_database(db_params: dict) -> None: + """Create database if it doesn't exist. + + Args: + db_params: Database connection parameters. + """ + dbname = db_params["dbname"] + conn_params = db_params.copy() + conn_params["dbname"] = "postgres" # Connect to default database + + print(f"Checking if database '{dbname}' exists...") + + try: + # Connect to PostgreSQL server + conn = psycopg2.connect(**conn_params) + conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) + cursor = conn.cursor() + + # Check if database exists + cursor.execute( + "SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s", (dbname,) + ) + exists = cursor.fetchone() + + if exists: + print(f"Database '{dbname}' already exists.") + else: + # Create database + cursor.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(dbname))) + print(f"Database '{dbname}' created successfully.") + + cursor.close() + conn.close() + + except psycopg2.Error as e: + print(f"Error creating database: {e}") + sys.exit(1) + + +def run_migrations() -> None: + """Run Alembic migrations to set up database schema.""" + import subprocess + + print("\nRunning Alembic migrations...") + + try: + result = subprocess.run( + ["alembic", "upgrade", "head"], + capture_output=True, + text=True, + check=True, + ) + print(result.stdout) + print("Migrations completed successfully.") + + except subprocess.CalledProcessError as e: + print(f"Error running migrations: {e}") + print(f"STDOUT: {e.stdout}") + print(f"STDERR: {e.stderr}") + sys.exit(1) + except FileNotFoundError: + print("Error: Alembic not found. Please install it with: pip install alembic") + sys.exit(1) + + +def seed_initial_data() -> None: + """Seed initial data into the database.""" + from sqlalchemy.orm import Session + + from app.db import SessionLocal + + print("\nSeeding initial data...") + + db: Session = SessionLocal() + + try: + # Add your initial data seeding logic here + # Example: + # from app.users.models import DBUser + # admin_user = DBUser( + # username="admin", + # email="admin@example.com", + # is_active=True, + # ) + # db.add(admin_user) + # db.commit() + + print("Initial data seeded successfully.") + + except Exception as e: + print(f"Error seeding data: {e}") + db.rollback() + sys.exit(1) + finally: + db.close() + + +def main(): + """Main function to initialize PostgreSQL database.""" + parser = argparse.ArgumentParser( + description="Initialize PostgreSQL database for Comet API" + ) + parser.add_argument( + "--seed", + action="store_true", + help="Seed initial data after running migrations", + ) + parser.add_argument( + "--skip-create", + action="store_true", + help="Skip database creation (only run migrations)", + ) + + args = parser.parse_args() + + # Check if using PostgreSQL + if not settings.DATABASE_URL.startswith("postgresql"): + print("Error: This script is only for PostgreSQL databases.") + print(f"Current DATABASE_URL: {settings.DATABASE_URL}") + sys.exit(1) + + print("=" * 60) + print("Comet API - PostgreSQL Database Initialization") + print("=" * 60) + print(f"\nDatabase URL: {settings.DATABASE_URL}") + + # Parse database URL + try: + db_params = parse_database_url(settings.DATABASE_URL) + except ValueError as e: + print(f"Error parsing database URL: {e}") + sys.exit(1) + + # Create database + if not args.skip_create: + create_database(db_params) + + # Run migrations + run_migrations() + + # Seed data if requested + if args.seed: + seed_initial_data() + + print("\n" + "=" * 60) + print("Database initialization completed successfully!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/tests/conftest.py b/tests/conftest.py index 9cec1ce..cb13022 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,14 +17,12 @@ from app.cases.router import router as cases_router from app.db import Base, get_db from app.health.router import router as health_router -from app.users.router import router as users_router def start_application(): app = FastAPI() app.include_router(cases_router) app.include_router(applicants_router) - app.include_router(users_router) app.include_router(admin_router) app.include_router(health_router) return app @@ -62,7 +60,8 @@ def db_session(app: FastAPI) -> Generator[SessionTesting, Any, None]: # type: i @pytest.fixture(scope="function") def client( - app: FastAPI, db_session: SessionTesting # type: ignore + app: FastAPI, + db_session: SessionTesting, # type: ignore ) -> Generator[TestClient, Any, None]: """ Create a new FastAPI TestClient that uses the `db_session` fixture to override diff --git a/tests/test_users.py b/tests/test_users.py deleted file mode 100644 index 5d054b3..0000000 --- a/tests/test_users.py +++ /dev/null @@ -1,95 +0,0 @@ -import pytest - -base_date = "2021-01-01T00:00:00.000000" -base_user = { - "id": 1, - "user_id": "testuser", - "first_name": "Test", - "last_name": "User", - "display_name": "Test User", - "email": "testuser1@test.com", - "is_active": True, - "created": base_date, - "created_by": "System Account", - "modified": base_date, - "modified_by": "System Account", -} - - -async def seed_data(client): - client.post("/users/", json=base_user) - - -@pytest.mark.asyncio -async def test_create_user(client): - response = client.post("/users/", json=base_user) - response_json = response.json() - response_json["created"] = base_date - response_json["modified"] = base_date - - assert response.status_code == 201 - assert response_json == base_user - - -@pytest.mark.asyncio -async def test_get_all_users(client): - await seed_data(client) - response = client.get("/users") - assert response.status_code == 200 - assert len(response.json()) > 0 - - -@pytest.mark.asyncio -async def test_get_users_paged(client): - await seed_data(client) - response = client.get("/users?page_number=0&page_size=10") - assert response.status_code == 200 - assert len(response.json()) > 0 - - -@pytest.mark.asyncio -async def test_get_user(client): - await seed_data(client) - response = client.get("/users/1") - response_json = response.json() - response_json["created"] = base_date - response_json["modified"] = base_date - - assert response.status_code == 200 - assert response_json == base_user - - -@pytest.mark.asyncio -async def test_update_user(client): - await seed_data(client) - updated_user = base_user.copy() - updated_user["is_active"] = False - - response = client.put("/users/1", json=updated_user) - response_json = response.json() - response_json["created"] = base_date - response_json["modified"] = base_date - - assert response.status_code == 200 - assert response_json == updated_user - - -@pytest.mark.asyncio -async def test_update_user_invalid_id(client): - await seed_data(client) - response = client.put("/users/-1", json=base_user) - assert response.status_code == 404 - - -@pytest.mark.asyncio -async def test_delete_user(client): - await seed_data(client) - response = client.delete("/users/1") - assert response.status_code == 204 - - -@pytest.mark.asyncio -async def test_delete_user_invalid_id(client): - await seed_data(client) - response = client.delete("/users/-1") - assert response.status_code == 404 diff --git a/tests/test_utils.py b/tests/test_utils.py index b9fa54a..726dba9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,9 +4,9 @@ def test_get_page_count(): assert get_page_count(0, 10) == 0 assert get_page_count(10, 10) == 1 - assert get_page_count(11, 10) == 1 + assert get_page_count(11, 10) == 2 assert get_page_count(20, 10) == 2 - assert get_page_count(21, 10) == 2 + assert get_page_count(21, 10) == 3 assert get_page_count(30, 10) == 3