diff --git a/iceaxe/__init__.py b/iceaxe/__init__.py index 72fb630..e4ad168 100644 --- a/iceaxe/__init__.py +++ b/iceaxe/__init__.py @@ -4,7 +4,10 @@ TableBase as TableBase, UniqueConstraint as UniqueConstraint, ) -from .exceptions import IceaxeQueryError as IceaxeQueryError +from .exceptions import ( + IceaxeQueryError as IceaxeQueryError, + NoObjectFound as NoObjectFound, +) from .field import Field as Field from .functions import func as func from .postgres import PostgresDateTime as PostgresDateTime, PostgresTime as PostgresTime diff --git a/iceaxe/__tests__/test_queries.py b/iceaxe/__tests__/test_queries.py index e14202c..8ff9b94 100644 --- a/iceaxe/__tests__/test_queries.py +++ b/iceaxe/__tests__/test_queries.py @@ -1,5 +1,5 @@ from enum import IntEnum, StrEnum -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Any, Literal, cast import pytest @@ -10,6 +10,7 @@ FunctionDemoModel, UserDemo, ) +from iceaxe.__tests__.helpers import pyright_raises from iceaxe.functions import func from iceaxe.queries import QueryBuilder, and_, or_, select @@ -29,6 +30,27 @@ def test_select(): ) +def test_select_one(): + new_query = select(UserDemo).one() + assert new_query.build() == ( + 'SELECT "userdemo"."id" AS "userdemo_id", "userdemo"."name" AS ' + '"userdemo_name", "userdemo"."email" AS "userdemo_email" FROM "userdemo" LIMIT 1', + [], + ) + + +def test_one_requires_single_full_model(): + with pytest.raises( + ValueError, match="one\\(\\) only supports selecting a single full table model" + ): + cast(Any, select(UserDemo.email)).one() + + with pytest.raises( + ValueError, match="one\\(\\) only supports selecting a single full table model" + ): + cast(Any, select((UserDemo,))).one() + + def test_select_single_field(): new_query = QueryBuilder().select(UserDemo.email) assert new_query.build() == ( @@ -583,6 +605,18 @@ def test_select_single_typehint(): _: QueryBuilder[UserDemo, Literal["SELECT"]] = query +def test_select_one_typehint(): + query = select(UserDemo).one() + if TYPE_CHECKING: + _: QueryBuilder[UserDemo, Literal["SELECT_ONE"]] = query + + +def test_select_one_partial_typehint_error(): + with pyright_raises("reportAttributeAccessIssue"): + if TYPE_CHECKING: + select(UserDemo.id).one() # type: ignore + + def test_select_multiple_typehints(): query = select((UserDemo, UserDemo.id, UserDemo.name)) if TYPE_CHECKING: diff --git a/iceaxe/__tests__/test_session.py b/iceaxe/__tests__/test_session.py index 9e639df..6bda421 100644 --- a/iceaxe/__tests__/test_session.py +++ b/iceaxe/__tests__/test_session.py @@ -18,10 +18,10 @@ UserDemo, ) from iceaxe.base import INTERNAL_TABLE_FIELDS, TableBase -from iceaxe.exceptions import IceaxeQueryError +from iceaxe.exceptions import IceaxeQueryError, NoObjectFound from iceaxe.field import Field from iceaxe.functions import func -from iceaxe.queries import QueryBuilder +from iceaxe.queries import QueryBuilder, select from iceaxe.schemas.cli import create_all from iceaxe.session import ( PG_MAX_PARAMETERS, @@ -220,6 +220,33 @@ async def test_select(db_connection: DBConnection): ] +@pytest.mark.asyncio +async def test_select_one(db_connection: DBConnection): + user = UserDemo(name="John Doe", email="john@example.com") + await db_connection.insert([user]) + + result = await db_connection.exec( + select(UserDemo).where(UserDemo.name == "John Doe").one() + ) + assert result == UserDemo(id=user.id, name="John Doe", email="john@example.com") + + +@pytest.mark.asyncio +async def test_select_one_missing_raises(db_connection: DBConnection): + query = select(UserDemo).where(UserDemo.id == 999).one() + + with pytest.raises(NoObjectFound, match="No UserDemo object found") as exc_info: + await db_connection.exec(query) + + assert exc_info.value.object_type is UserDemo + assert exc_info.value.sql_text == ( + 'SELECT "userdemo"."id" AS "userdemo_id", "userdemo"."name" AS ' + '"userdemo_name", "userdemo"."email" AS "userdemo_email" FROM "userdemo" ' + 'WHERE "userdemo"."id" = $1 LIMIT 1' + ) + assert exc_info.value.variables == (999,) + + @pytest.mark.asyncio async def test_is_null(db_connection: DBConnection): user = UserDemo(name="John Doe", email="john@example.com") diff --git a/iceaxe/exceptions.py b/iceaxe/exceptions.py index cfd417f..2427c31 100644 --- a/iceaxe/exceptions.py +++ b/iceaxe/exceptions.py @@ -1,8 +1,25 @@ from functools import lru_cache +from typing import Any import asyncpg +class NoObjectFound(LookupError): + """ + Query completed successfully but returned no rows for a `.one()` lookup. + """ + + def __init__( + self, object_type: type[Any], sql_text: str, variables: tuple[Any, ...] + ): + self.object_type = object_type + self.sql_text = sql_text + self.variables = variables + + context = f"\nQuery: {sql_text}\nVariables: {variables}" + super().__init__(f"No {object_type.__name__} object found.{context}") + + class IceaxeQueryError(asyncpg.PostgresError): """ Query error that subclasses the original asyncpg exception type, diff --git a/iceaxe/queries.py b/iceaxe/queries.py index 021aff2..2a3fc93 100644 --- a/iceaxe/queries.py +++ b/iceaxe/queries.py @@ -65,7 +65,11 @@ Ts = TypeVarTuple("Ts") -QueryType = TypeVar("QueryType", bound=Literal["SELECT", "INSERT", "UPDATE", "DELETE"]) +QueryType = TypeVar( + "QueryType", + bound=Literal["SELECT", "SELECT_ONE", "INSERT", "UPDATE", "DELETE"], +) +TableSelectType = TypeVar("TableSelectType", bound=TableBase) JoinType = Literal["INNER", "LEFT", "RIGHT", "FULL"] @@ -145,6 +149,7 @@ def __init__(self): self._main_model: Type[TableBase] | None = None self._return_typehint: P + self._is_single_select_expression = False self._where_conditions: list[FieldComparison | FieldComparisonGroup] = [] self._order_by_clauses: list[str] = [] @@ -429,8 +434,10 @@ def select( ] if not isinstance(fields, tuple): all_fields = (fields,) # type: ignore + self._is_single_select_expression = True else: all_fields = fields # type: ignore + self._is_single_select_expression = False # Verify the field type for field in all_fields: @@ -448,6 +455,33 @@ def select( return self # type: ignore + @allow_branching + def one( + self: QueryBuilder[TableSelectType, Literal["SELECT"]], + ) -> QueryBuilder[TableSelectType, Literal["SELECT_ONE"]]: + """ + Converts a full-model SELECT query into a single-object SELECT. + + This is currently limited to selects of a single full model expression and + always applies `LIMIT 1` at build time. + + :return: A QueryBuilder instance configured for a single-object SELECT + + """ + if self._query_type != "SELECT": + raise ValueError("one() is only valid on SELECT queries") + + if ( + not self._is_single_select_expression + or len(self._select_raw) != 1 + or not is_base_table(self._select_raw[0]) + ): + raise ValueError("one() only supports selecting a single full table model.") + + self._query_type = "SELECT_ONE" # type: ignore + self._limit_value = 1 + return self # type: ignore + def _select_inner( self, fields: tuple[DBFieldClassDefinition | Type[TableBase] | FunctionMetadata, ...], @@ -986,7 +1020,7 @@ def build(self) -> tuple[str, list[Any]]: query = "" variables: list[Any] = [] - if self._query_type == "SELECT": + if self._query_type in ("SELECT", "SELECT_ONE"): if not self._main_model: raise ValueError("No model selected for query") @@ -1055,8 +1089,9 @@ def build(self) -> tuple[str, list[Any]]: if self._order_by_clauses: query += " ORDER BY " + ", ".join(self._order_by_clauses) - if self._limit_value is not None: - query += f" LIMIT {self._limit_value}" + effective_limit = 1 if self._query_type == "SELECT_ONE" else self._limit_value + if effective_limit is not None: + query += f" LIMIT {effective_limit}" if self._offset_value is not None: query += f" OFFSET {self._offset_value}" diff --git a/iceaxe/session.py b/iceaxe/session.py index 0184bc1..b3f5d6f 100644 --- a/iceaxe/session.py +++ b/iceaxe/session.py @@ -18,7 +18,7 @@ from typing_extensions import TypeVarTuple from iceaxe.base import DBFieldClassDefinition, DBModelMetaclass, TableBase -from iceaxe.exceptions import wrap_query_error +from iceaxe.exceptions import NoObjectFound, wrap_query_error from iceaxe.logging import LOGGER from iceaxe.modifications import ModificationTracker from iceaxe.queries import ( @@ -261,6 +261,11 @@ async def transaction(self, *, ensure: bool = False): @overload async def exec(self, query: QueryBuilder[T, Literal["SELECT"]]) -> list[T]: ... + @overload + async def exec( + self, query: QueryBuilder[TableType, Literal["SELECT_ONE"]] + ) -> TableType: ... + @overload async def exec(self, query: QueryBuilder[T, Literal["INSERT"]]) -> None: ... @@ -273,10 +278,11 @@ async def exec(self, query: QueryBuilder[T, Literal["DELETE"]]) -> None: ... async def exec( self, query: QueryBuilder[T, Literal["SELECT"]] + | QueryBuilder[TableType, Literal["SELECT_ONE"]] | QueryBuilder[T, Literal["INSERT"]] | QueryBuilder[T, Literal["UPDATE"]] | QueryBuilder[T, Literal["DELETE"]], - ) -> list[T] | None: + ) -> list[T] | TableType | None: """ Execute a query built with QueryBuilder and return the results. @@ -304,7 +310,7 @@ async def exec( ``` :param query: A QueryBuilder instance representing the query to execute - :return: For SELECT queries, returns a list of results. For other queries, returns None + :return: For SELECT queries, returns results. For other queries, returns None """ sql_text, variables = query.build() @@ -314,7 +320,7 @@ async def exec( except asyncpg.PostgresError as e: raise wrap_query_error(e, sql_text, tuple(variables)) from e - if query._query_type == "SELECT": + if query._query_type in ("SELECT", "SELECT_ONE"): # Pre-cache the select types for better performance select_types = [ ( @@ -338,6 +344,15 @@ async def exec( self.modification_tracker.track_modification ) + if query._query_type == "SELECT_ONE": + if not result_all: + raise NoObjectFound( + cast(type[TableBase], query._select_raw[0]), + sql_text, + tuple(variables), + ) + return cast(TableType, result_all[0]) + return cast(list[T], result_all) return None