Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions fastapi_sqlalchemy_toolkit/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.relationships import Relationship
from sqlalchemy.sql import Select
from sqlalchemy.sql.base import _entity_namespace
from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.sql.expression import BinaryExpression, ColumnElement
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.schema import ScalarElementColumnDefault
from sqlalchemy.sql.selectable import Exists
Expand Down Expand Up @@ -524,7 +526,9 @@ async def paginated_list(
self,
session: AsyncSession,
order_by: InstrumentedAttribute | UnaryExpression | None = ...,
filter_expressions: dict[InstrumentedAttribute | Callable, Any] | None = ...,
filter_expressions: dict[
InstrumentedAttribute | Callable | ColumnElement, Any
] | None = ...,
nullable_filter_expressions: (
dict[InstrumentedAttribute | Callable, Any] | None
) = ...,
Expand All @@ -540,7 +544,9 @@ async def paginated_list(
self,
session: AsyncSession,
order_by: InstrumentedAttribute | UnaryExpression | None = ...,
filter_expressions: dict[InstrumentedAttribute | Callable, Any] | None = ...,
filter_expressions: dict[
InstrumentedAttribute | Callable | ColumnElement, Any
] | None = ...,
nullable_filter_expressions: (
dict[InstrumentedAttribute | Callable, Any] | None
) = ...,
Expand All @@ -555,7 +561,9 @@ async def paginated_list(
self,
session: AsyncSession,
order_by: InstrumentedAttribute | UnaryExpression | None = None,
filter_expressions: dict[InstrumentedAttribute | Callable, Any] | None = None,
filter_expressions: dict[
InstrumentedAttribute | Callable | ColumnElement, Any
] | None = None,
nullable_filter_expressions: (
dict[InstrumentedAttribute | Callable, Any] | None
) = None,
Expand Down Expand Up @@ -618,7 +626,10 @@ async def paginated_list(
)

for filter_expression, value in filter_expressions.items():
if isinstance(filter_expression, InstrumentedAttribute | Function):
if isinstance(
filter_expression,
InstrumentedAttribute | Function | BinaryExpression | ColumnElement
):
stmt = stmt.filter(filter_expression == value)
else:
stmt = stmt.filter(filter_expression(value))
Expand Down Expand Up @@ -716,7 +727,9 @@ async def list(
self,
session: AsyncSession,
order_by: InstrumentedAttribute | UnaryExpression | None = ...,
filter_expressions: dict[InstrumentedAttribute | Callable, Any] | None = ...,
filter_expressions: dict[
InstrumentedAttribute | Callable | ColumnElement, Any
] | None = ...,
nullable_filter_expressions: (
dict[InstrumentedAttribute | Callable, Any] | None
) = ...,
Expand All @@ -735,7 +748,9 @@ async def list(
self,
session: AsyncSession,
order_by: InstrumentedAttribute | UnaryExpression | None = ...,
filter_expressions: dict[InstrumentedAttribute | Callable, Any] | None = ...,
filter_expressions: dict[
InstrumentedAttribute | Callable | ColumnElement, Any
] | None = ...,
nullable_filter_expressions: (
dict[InstrumentedAttribute | Callable, Any] | None
) = ...,
Expand All @@ -753,7 +768,9 @@ async def list(
self,
session: AsyncSession,
order_by: InstrumentedAttribute | UnaryExpression | None = None,
filter_expressions: dict[InstrumentedAttribute | Callable, Any] | None = None,
filter_expressions: dict[
InstrumentedAttribute | Callable | ColumnElement, Any
] | None = None,
nullable_filter_expressions: (
dict[InstrumentedAttribute | Callable, Any] | None
) = None,
Expand Down Expand Up @@ -830,7 +847,10 @@ async def list(
)

for filter_expression, value in filter_expressions.items():
if isinstance(filter_expression, InstrumentedAttribute):
if isinstance(
filter_expression,
InstrumentedAttribute | BinaryExpression | ColumnElement
):
stmt = stmt.filter(filter_expression == value)
else:
stmt = stmt.filter(filter_expression(value))
Expand Down Expand Up @@ -1075,7 +1095,9 @@ def get_select(self, base_stmt: Select | None = None, **_kwargs: Any) -> Select:
def get_joins(
self,
base_query: Select,
filter_expressions: dict[InstrumentedAttribute | Callable, Any],
filter_expressions: dict[
InstrumentedAttribute | Callable | ColumnElement, Any
],
options: List[Any] | None = None,
order_by: InstrumentedAttribute | UnaryExpression | None = None,
) -> Select:
Expand Down Expand Up @@ -1109,7 +1131,7 @@ def get_joins(
elif isinstance(filter_expression, Function):
model = filter_expression.entity_namespace
else:
model = filter_expression.__self__.parent._identity_class
model = _entity_namespace(filter_expression)
if model != self.model:
models_to_join.add(model)
for model in models_to_join:
Expand Down Expand Up @@ -1157,7 +1179,9 @@ def remove_optional_filter_bys(

@staticmethod
def handle_filter_expressions(
filter_expressions: dict[InstrumentedAttribute | Callable, Any],
filter_expressions: dict[
InstrumentedAttribute | Callable | ColumnElement, Any
],
) -> None:
for filter_expression, value in filter_expressions.copy().items():
if value is None:
Expand All @@ -1169,7 +1193,9 @@ def handle_filter_expressions(

@staticmethod
def handle_nullable_filter_expressions(
nullable_filter_expressions: dict[InstrumentedAttribute | Callable, Any],
nullable_filter_expressions: dict[
InstrumentedAttribute | Callable | ColumnElement, Any
],
) -> None:
for filter_expression, value in nullable_filter_expressions.copy().items():
if value in null_query_values:
Expand Down
Loading