Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ RUN apt update \
COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
RUN useradd --no-create-home --gid root runner

ENV UV_PYTHON_PREFERENCE=only-system
ENV UV_NO_CACHE=true
ENV UV_PROJECT_ENVIRONMENT=/code/.venv \
UV_NO_MANAGED_PYTHON=1 \
UV_NO_CACHE=true \
UV_LINK_MODE=copy

WORKDIR /code

Expand Down
34 changes: 17 additions & 17 deletions app/api/decks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,26 @@
from starlette import status

from app import models, schemas
from app.repositories import CardsService, DecksService
from app.repositories import CardsRepository, DecksRepository


ROUTER: typing.Final = fastapi.APIRouter()


@ROUTER.get("/decks/")
async def list_decks(
decks_service: DecksService = FromDI(DecksService),
decks_repository: DecksRepository = FromDI(DecksRepository),
) -> schemas.Decks:
objects = await decks_service.list()
objects = await decks_repository.list()
return typing.cast("schemas.Decks", {"items": objects})


@ROUTER.get("/decks/{deck_id}/")
async def get_deck(
deck_id: int,
decks_service: DecksService = FromDI(DecksService),
decks_repository: DecksRepository = FromDI(DecksRepository),
) -> schemas.Deck:
instance = await decks_service.get_one_or_none(
instance = await decks_repository.get_one_or_none(
models.Deck.id == deck_id,
load=[orm.selectinload(models.Deck.cards)],
)
Expand All @@ -40,10 +40,10 @@ async def get_deck(
async def update_deck(
deck_id: int,
data: schemas.DeckCreate,
decks_service: DecksService = FromDI(DecksService),
decks_repository: DecksRepository = FromDI(DecksRepository),
) -> schemas.Deck:
try:
instance = await decks_service.update(data=data.model_dump(), item_id=deck_id)
instance = await decks_repository.update(data=data.model_dump(), item_id=deck_id)
except NotFoundError:
raise fastapi.HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Deck is not found") from None

Expand All @@ -53,27 +53,27 @@ async def update_deck(
@ROUTER.post("/decks/")
async def create_deck(
data: schemas.DeckCreate,
decks_service: DecksService = FromDI(DecksService),
decks_repository: DecksRepository = FromDI(DecksRepository),
) -> schemas.Deck:
instance = await decks_service.create(data.model_dump())
instance = await decks_repository.create(data.model_dump())
return typing.cast("schemas.Deck", instance)


@ROUTER.get("/decks/{deck_id}/cards/")
async def list_cards(
deck_id: int,
cards_service: CardsService = FromDI(CardsService),
cards_repository: CardsRepository = FromDI(CardsRepository),
) -> schemas.Cards:
objects = await cards_service.list(models.Card.deck_id == deck_id)
objects = await cards_repository.list(models.Card.deck_id == deck_id)
return typing.cast("schemas.Cards", {"items": objects})


@ROUTER.get("/cards/{card_id}/")
async def get_card(
card_id: int,
cards_service: CardsService = FromDI(CardsService),
cards_repository: CardsRepository = FromDI(CardsRepository),
) -> schemas.Card:
instance = await cards_service.get_one_or_none(models.Card.id == card_id)
instance = await cards_repository.get_one_or_none(models.Card.id == card_id)
if not instance:
raise fastapi.HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Card is not found")
return typing.cast("schemas.Card", instance)
Expand All @@ -83,9 +83,9 @@ async def get_card(
async def create_cards(
deck_id: int,
data: list[schemas.CardCreate],
cards_service: CardsService = FromDI(CardsService),
cards_repository: CardsRepository = FromDI(CardsRepository),
) -> schemas.Cards:
objects = await cards_service.create_many(
objects = await cards_repository.create_many(
data=[models.Card(**card.model_dump(), deck_id=deck_id) for card in data],
)
return typing.cast("schemas.Cards", {"items": objects})
Expand All @@ -95,9 +95,9 @@ async def create_cards(
async def update_cards(
deck_id: int,
data: list[schemas.Card],
cards_service: CardsService = FromDI(CardsService),
cards_repository: CardsRepository = FromDI(CardsRepository),
) -> schemas.Cards:
objects = await cards_service.upsert_many(
objects = await cards_repository.upsert_many(
data=[models.Card(**card.model_dump(exclude={"deck_id"}), deck_id=deck_id) for card in data],
)
return typing.cast("schemas.Cards", {"items": objects})
2 changes: 1 addition & 1 deletion app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def include_routers(app: fastapi.FastAPI) -> None:


def build_app() -> fastapi.FastAPI:
di_container = modern_di.AsyncContainer(groups=[ioc.Dependencies])
di_container = modern_di.Container(groups=[ioc.Dependencies])
bootstrap_config = dataclasses.replace(
settings.api_bootstrapper_config,
opentelemetry_instrumentors=[
Expand Down
28 changes: 22 additions & 6 deletions app/ioc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
from modern_di import Group, Scope, providers

from app import repositories
from app.resources.db import create_sa_engine, create_session
from app.repositories import CardsRepository, DecksRepository
from app.resources.db import close_sa_engine, close_session, create_sa_engine, create_session


class Dependencies(Group):
database_engine = providers.Resource(Scope.APP, create_sa_engine)
session = providers.Resource(Scope.REQUEST, create_session, engine=database_engine.cast)
database_engine = providers.Factory(
creator=create_sa_engine, cache_settings=providers.CacheSettings(finalizer=close_sa_engine)
)
session = providers.Factory(
scope=Scope.REQUEST, creator=create_session, cache_settings=providers.CacheSettings(finalizer=close_session)
)

decks_service = providers.Factory(Scope.REQUEST, repositories.DecksService, session=session.cast, auto_commit=True)
cards_service = providers.Factory(Scope.REQUEST, repositories.CardsService, session=session.cast, auto_commit=True)
decks_repository = providers.Factory(
scope=Scope.REQUEST,
creator=DecksRepository,
bound_type=DecksRepository,
kwargs={"session": session, "auto_commit": True},
skip_creator_parsing=True,
)
cards_repository = providers.Factory(
scope=Scope.REQUEST,
creator=CardsRepository,
bound_type=CardsRepository,
kwargs={"session": session, "auto_commit": True},
skip_creator_parsing=True,
)
18 changes: 8 additions & 10 deletions app/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,15 @@
from app import models


class DecksRepository(SQLAlchemyAsyncRepository[models.Deck]):
model_type = models.Deck
class DecksRepository(SQLAlchemyAsyncRepositoryService[models.Deck]):
class BaseRepository(SQLAlchemyAsyncRepository[models.Deck]):
model_type = models.Deck

repository_type = BaseRepository

class DecksService(SQLAlchemyAsyncRepositoryService[models.Deck]):
repository_type = DecksRepository

class CardsRepository(SQLAlchemyAsyncRepositoryService[models.Card]):
class BaseRepository(SQLAlchemyAsyncRepository[models.Card]):
model_type = models.Card

class CardsRepository(SQLAlchemyAsyncRepository[models.Card]):
model_type = models.Card


class CardsService(SQLAlchemyAsyncRepositoryService[models.Card]):
repository_type = CardsRepository
repository_type = BaseRepository
29 changes: 14 additions & 15 deletions app/resources/db.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
import typing

Expand All @@ -9,23 +10,19 @@
logger = logging.getLogger(__name__)


async def create_sa_engine() -> typing.AsyncIterator[sa.AsyncEngine]:
logger.info("Initializing SQLAlchemy engine")
engine = sa.create_async_engine(
def create_sa_engine() -> sa.AsyncEngine:
return sa.create_async_engine(
url=settings.db_dsn_parsed,
echo=settings.service_debug,
echo_pool=settings.service_debug,
pool_size=settings.db_pool_size,
pool_pre_ping=settings.db_pool_pre_ping,
max_overflow=settings.db_max_overflow,
)
engine.pool.status()
logger.info("SQLAlchemy engine has been initialized")
try:
yield engine
finally:
await engine.dispose()
logger.info("SQLAlchemy engine has been cleaned up")


async def close_sa_engine(engine: sa.AsyncEngine) -> None:
await engine.dispose()


class CustomAsyncSession(sa.AsyncSession):
Expand All @@ -36,8 +33,10 @@ async def close(self) -> None:
return await super().close()


async def create_session(engine: sa.AsyncEngine) -> typing.AsyncIterator[sa.AsyncSession]:
async with CustomAsyncSession(engine, expire_on_commit=False, autoflush=False) as session:
logger.info("session created")
yield session
logger.info("session closed")
def create_session(engine: sa.AsyncEngine) -> sa.AsyncSession:
return CustomAsyncSession(engine, expire_on_commit=False, autoflush=False)


async def close_session(session: sa.AsyncSession) -> None:
task: typing.Final = asyncio.create_task(session.close())
await asyncio.shield(task)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license = "MIT License"
dependencies = [
"fastapi>=0.76",
"lite-bootstrap[fastapi-all]",
"modern-di-fastapi>=1",
"modern-di-fastapi>=2",
"advanced-alchemy",
"pydantic-settings",
"granian[uvloop]",
Expand Down
33 changes: 25 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import pytest
from asgi_lifespan import LifespanManager
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession

from app import ioc
from app.application import build_app
from app.resources.db import create_sa_engine


if typing.TYPE_CHECKING:
Expand All @@ -32,17 +33,21 @@ async def client(app: fastapi.FastAPI) -> typing.AsyncIterator[AsyncClient]:


@pytest.fixture
def di_container(app: fastapi.FastAPI) -> modern_di.AsyncContainer:
return modern_di_fastapi.fetch_di_container(app)
async def di_container(app: fastapi.FastAPI) -> typing.AsyncIterator[modern_di.Container]:
container = modern_di_fastapi.fetch_di_container(app)
try:
yield container
finally:
await container.close_async()


@pytest.fixture(autouse=True)
async def db_session(di_container: modern_di.AsyncContainer) -> typing.AsyncIterator[AsyncSession]:
engine = await di_container.resolve_provider(ioc.Dependencies.database_engine)
@pytest.fixture
async def db_session(di_container: modern_di.Container) -> typing.AsyncIterator[AsyncSession]:
engine = create_sa_engine()
connection = await engine.connect()
transaction = await connection.begin()
await connection.begin_nested()
di_container.override(ioc.Dependencies.database_engine, connection)
di_container.override(dependency_type=AsyncEngine, mock=connection)

try:
yield AsyncSession(connection, expire_on_commit=False, autoflush=False)
Expand All @@ -51,3 +56,15 @@ async def db_session(di_container: modern_di.AsyncContainer) -> typing.AsyncIter
await transaction.rollback()
await connection.close()
await engine.dispose()
di_container.reset_override()


@pytest.fixture
async def set_async_session_in_base_sqlalchemy_factory(
db_session: AsyncSession,
) -> typing.AsyncIterator[None]:
try:
SQLAlchemyFactory.__async_session__ = db_session
yield
finally:
SQLAlchemyFactory.__async_session__ = None
23 changes: 9 additions & 14 deletions tests/test_cards.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from typing import TYPE_CHECKING

import pytest
from fastapi import status

from tests import factories


if TYPE_CHECKING:
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession


async def test_get_cards_empty(client: AsyncClient, db_session: AsyncSession) -> None:
factories.DeckModelFactory.__async_session__ = db_session
pytestmark = [pytest.mark.usefixtures("set_async_session_in_base_sqlalchemy_factory")]


async def test_get_cards_empty(client: AsyncClient) -> None:
deck = await factories.DeckModelFactory.create_async()

response = await client.get(f"/api/decks/{deck.id}/cards/")
Expand All @@ -22,9 +24,7 @@ async def test_get_cards_empty(client: AsyncClient, db_session: AsyncSession) ->
assert response.status_code == status.HTTP_404_NOT_FOUND


async def test_get_cards(client: AsyncClient, db_session: AsyncSession) -> None:
factories.DeckModelFactory.__async_session__ = db_session
factories.CardModelFactory.__async_session__ = db_session
async def test_get_cards(client: AsyncClient) -> None:
deck = await factories.DeckModelFactory.create_async()
card = await factories.CardModelFactory.create_async(deck_id=deck.id)

Expand All @@ -36,9 +36,7 @@ async def test_get_cards(client: AsyncClient, db_session: AsyncSession) -> None:
assert v == getattr(card, k)


async def test_get_card(client: AsyncClient, db_session: AsyncSession) -> None:
factories.DeckModelFactory.__async_session__ = db_session
factories.CardModelFactory.__async_session__ = db_session
async def test_get_card(client: AsyncClient) -> None:
deck = await factories.DeckModelFactory.create_async()
card = await factories.CardModelFactory.create_async(deck_id=deck.id)

Expand All @@ -53,8 +51,7 @@ async def test_get_card_not_exist(client: AsyncClient) -> None:
assert response.status_code == status.HTTP_404_NOT_FOUND


async def test_create_cards(client: AsyncClient, db_session: AsyncSession) -> None:
factories.DeckModelFactory.__async_session__ = db_session
async def test_create_cards(client: AsyncClient) -> None:
deck = await factories.DeckModelFactory.create_async()

cards_to_create = [factories.CardCreateSchemaFactory.build(), factories.CardCreateSchemaFactory.build()]
Expand Down Expand Up @@ -86,9 +83,7 @@ async def test_create_cards(client: AsyncClient, db_session: AsyncSession) -> No
assert data["detail"] == "A record matching the supplied data already exists."


async def test_update_cards(client: AsyncClient, db_session: AsyncSession) -> None:
factories.DeckModelFactory.__async_session__ = db_session
factories.CardModelFactory.__async_session__ = db_session
async def test_update_cards(client: AsyncClient) -> None:
deck = await factories.DeckModelFactory.create_async()
card1, card2 = await factories.CardModelFactory.create_batch_async(size=2, deck_id=deck.id)

Expand Down
Loading