Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion iceaxe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
TableBase as TableBase,
UniqueConstraint as UniqueConstraint,
)
from .exceptions import IceaxeQueryError as IceaxeQueryError
from .exceptions import (
IceaxeQueryError as IceaxeQueryError,
NoObjectFound as NoObjectFound,
)
from .field import Field as Field
from .functions import func as func
from .postgres import PostgresDateTime as PostgresDateTime, PostgresTime as PostgresTime
Expand Down
36 changes: 35 additions & 1 deletion iceaxe/__tests__/test_queries.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import IntEnum, StrEnum
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Any, Literal, cast

import pytest

Expand All @@ -10,6 +10,7 @@
FunctionDemoModel,
UserDemo,
)
from iceaxe.__tests__.helpers import pyright_raises
from iceaxe.functions import func
from iceaxe.queries import QueryBuilder, and_, or_, select

Expand All @@ -29,6 +30,27 @@ def test_select():
)


def test_select_one():
new_query = select(UserDemo).one()
assert new_query.build() == (
'SELECT "userdemo"."id" AS "userdemo_id", "userdemo"."name" AS '
'"userdemo_name", "userdemo"."email" AS "userdemo_email" FROM "userdemo" LIMIT 1',
[],
)


def test_one_requires_single_full_model():
with pytest.raises(
ValueError, match="one\\(\\) only supports selecting a single full table model"
):
cast(Any, select(UserDemo.email)).one()

with pytest.raises(
ValueError, match="one\\(\\) only supports selecting a single full table model"
):
cast(Any, select((UserDemo,))).one()


def test_select_single_field():
new_query = QueryBuilder().select(UserDemo.email)
assert new_query.build() == (
Expand Down Expand Up @@ -583,6 +605,18 @@ def test_select_single_typehint():
_: QueryBuilder[UserDemo, Literal["SELECT"]] = query


def test_select_one_typehint():
query = select(UserDemo).one()
if TYPE_CHECKING:
_: QueryBuilder[UserDemo, Literal["SELECT_ONE"]] = query


def test_select_one_partial_typehint_error():
with pyright_raises("reportAttributeAccessIssue"):
if TYPE_CHECKING:
select(UserDemo.id).one() # type: ignore


def test_select_multiple_typehints():
query = select((UserDemo, UserDemo.id, UserDemo.name))
if TYPE_CHECKING:
Expand Down
31 changes: 29 additions & 2 deletions iceaxe/__tests__/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
UserDemo,
)
from iceaxe.base import INTERNAL_TABLE_FIELDS, TableBase
from iceaxe.exceptions import IceaxeQueryError
from iceaxe.exceptions import IceaxeQueryError, NoObjectFound
from iceaxe.field import Field
from iceaxe.functions import func
from iceaxe.queries import QueryBuilder
from iceaxe.queries import QueryBuilder, select
from iceaxe.schemas.cli import create_all
from iceaxe.session import (
PG_MAX_PARAMETERS,
Expand Down Expand Up @@ -220,6 +220,33 @@ async def test_select(db_connection: DBConnection):
]


@pytest.mark.asyncio
async def test_select_one(db_connection: DBConnection):
user = UserDemo(name="John Doe", email="john@example.com")
await db_connection.insert([user])

result = await db_connection.exec(
select(UserDemo).where(UserDemo.name == "John Doe").one()
)
assert result == UserDemo(id=user.id, name="John Doe", email="john@example.com")


@pytest.mark.asyncio
async def test_select_one_missing_raises(db_connection: DBConnection):
query = select(UserDemo).where(UserDemo.id == 999).one()

with pytest.raises(NoObjectFound, match="No UserDemo object found") as exc_info:
await db_connection.exec(query)

assert exc_info.value.object_type is UserDemo
assert exc_info.value.sql_text == (
'SELECT "userdemo"."id" AS "userdemo_id", "userdemo"."name" AS '
'"userdemo_name", "userdemo"."email" AS "userdemo_email" FROM "userdemo" '
'WHERE "userdemo"."id" = $1 LIMIT 1'
)
assert exc_info.value.variables == (999,)


@pytest.mark.asyncio
async def test_is_null(db_connection: DBConnection):
user = UserDemo(name="John Doe", email="john@example.com")
Expand Down
17 changes: 17 additions & 0 deletions iceaxe/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,25 @@
from functools import lru_cache
from typing import Any

import asyncpg


class NoObjectFound(LookupError):
"""
Query completed successfully but returned no rows for a `.one()` lookup.
"""

def __init__(
self, object_type: type[Any], sql_text: str, variables: tuple[Any, ...]
):
self.object_type = object_type
self.sql_text = sql_text
self.variables = variables

context = f"\nQuery: {sql_text}\nVariables: {variables}"
super().__init__(f"No {object_type.__name__} object found.{context}")


class IceaxeQueryError(asyncpg.PostgresError):
"""
Query error that subclasses the original asyncpg exception type,
Expand Down
43 changes: 39 additions & 4 deletions iceaxe/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@
Ts = TypeVarTuple("Ts")


QueryType = TypeVar("QueryType", bound=Literal["SELECT", "INSERT", "UPDATE", "DELETE"])
QueryType = TypeVar(
"QueryType",
bound=Literal["SELECT", "SELECT_ONE", "INSERT", "UPDATE", "DELETE"],
)
TableSelectType = TypeVar("TableSelectType", bound=TableBase)


JoinType = Literal["INNER", "LEFT", "RIGHT", "FULL"]
Expand Down Expand Up @@ -145,6 +149,7 @@ def __init__(self):
self._main_model: Type[TableBase] | None = None

self._return_typehint: P
self._is_single_select_expression = False

self._where_conditions: list[FieldComparison | FieldComparisonGroup] = []
self._order_by_clauses: list[str] = []
Expand Down Expand Up @@ -429,8 +434,10 @@ def select(
]
if not isinstance(fields, tuple):
all_fields = (fields,) # type: ignore
self._is_single_select_expression = True
else:
all_fields = fields # type: ignore
self._is_single_select_expression = False

# Verify the field type
for field in all_fields:
Expand All @@ -448,6 +455,33 @@ def select(

return self # type: ignore

@allow_branching
def one(
self: QueryBuilder[TableSelectType, Literal["SELECT"]],
) -> QueryBuilder[TableSelectType, Literal["SELECT_ONE"]]:
"""
Converts a full-model SELECT query into a single-object SELECT.

This is currently limited to selects of a single full model expression and
always applies `LIMIT 1` at build time.

:return: A QueryBuilder instance configured for a single-object SELECT

"""
if self._query_type != "SELECT":
raise ValueError("one() is only valid on SELECT queries")

if (
not self._is_single_select_expression
or len(self._select_raw) != 1
or not is_base_table(self._select_raw[0])
):
raise ValueError("one() only supports selecting a single full table model.")

self._query_type = "SELECT_ONE" # type: ignore
self._limit_value = 1
return self # type: ignore

def _select_inner(
self,
fields: tuple[DBFieldClassDefinition | Type[TableBase] | FunctionMetadata, ...],
Expand Down Expand Up @@ -986,7 +1020,7 @@ def build(self) -> tuple[str, list[Any]]:
query = ""
variables: list[Any] = []

if self._query_type == "SELECT":
if self._query_type in ("SELECT", "SELECT_ONE"):
if not self._main_model:
raise ValueError("No model selected for query")

Expand Down Expand Up @@ -1055,8 +1089,9 @@ def build(self) -> tuple[str, list[Any]]:
if self._order_by_clauses:
query += " ORDER BY " + ", ".join(self._order_by_clauses)

if self._limit_value is not None:
query += f" LIMIT {self._limit_value}"
effective_limit = 1 if self._query_type == "SELECT_ONE" else self._limit_value
if effective_limit is not None:
query += f" LIMIT {effective_limit}"

if self._offset_value is not None:
query += f" OFFSET {self._offset_value}"
Expand Down
23 changes: 19 additions & 4 deletions iceaxe/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing_extensions import TypeVarTuple

from iceaxe.base import DBFieldClassDefinition, DBModelMetaclass, TableBase
from iceaxe.exceptions import wrap_query_error
from iceaxe.exceptions import NoObjectFound, wrap_query_error
from iceaxe.logging import LOGGER
from iceaxe.modifications import ModificationTracker
from iceaxe.queries import (
Expand Down Expand Up @@ -261,6 +261,11 @@ async def transaction(self, *, ensure: bool = False):
@overload
async def exec(self, query: QueryBuilder[T, Literal["SELECT"]]) -> list[T]: ...

@overload
async def exec(
self, query: QueryBuilder[TableType, Literal["SELECT_ONE"]]
) -> TableType: ...

@overload
async def exec(self, query: QueryBuilder[T, Literal["INSERT"]]) -> None: ...

Expand All @@ -273,10 +278,11 @@ async def exec(self, query: QueryBuilder[T, Literal["DELETE"]]) -> None: ...
async def exec(
self,
query: QueryBuilder[T, Literal["SELECT"]]
| QueryBuilder[TableType, Literal["SELECT_ONE"]]
| QueryBuilder[T, Literal["INSERT"]]
| QueryBuilder[T, Literal["UPDATE"]]
| QueryBuilder[T, Literal["DELETE"]],
) -> list[T] | None:
) -> list[T] | TableType | None:
"""
Execute a query built with QueryBuilder and return the results.

Expand Down Expand Up @@ -304,7 +310,7 @@ async def exec(
```

:param query: A QueryBuilder instance representing the query to execute
:return: For SELECT queries, returns a list of results. For other queries, returns None
:return: For SELECT queries, returns results. For other queries, returns None

"""
sql_text, variables = query.build()
Expand All @@ -314,7 +320,7 @@ async def exec(
except asyncpg.PostgresError as e:
raise wrap_query_error(e, sql_text, tuple(variables)) from e

if query._query_type == "SELECT":
if query._query_type in ("SELECT", "SELECT_ONE"):
# Pre-cache the select types for better performance
select_types = [
(
Expand All @@ -338,6 +344,15 @@ async def exec(
self.modification_tracker.track_modification
)

if query._query_type == "SELECT_ONE":
if not result_all:
raise NoObjectFound(
cast(type[TableBase], query._select_raw[0]),
sql_text,
tuple(variables),
)
return cast(TableType, result_all[0])

return cast(list[T], result_all)

return None
Expand Down
Loading