diff --git a/fastapi_sqlalchemy_toolkit/model_manager.py b/fastapi_sqlalchemy_toolkit/model_manager.py index df7cc12..a50de71 100644 --- a/fastapi_sqlalchemy_toolkit/model_manager.py +++ b/fastapi_sqlalchemy_toolkit/model_manager.py @@ -25,7 +25,9 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.relationships import Relationship from sqlalchemy.sql import Select +from sqlalchemy.sql.base import _entity_namespace from sqlalchemy.sql.elements import UnaryExpression +from sqlalchemy.sql.expression import BinaryExpression, ColumnElement from sqlalchemy.sql.functions import Function from sqlalchemy.sql.schema import ScalarElementColumnDefault from sqlalchemy.sql.selectable import Exists @@ -524,7 +526,9 @@ async def paginated_list( self, session: AsyncSession, order_by: InstrumentedAttribute | UnaryExpression | None = ..., - filter_expressions: dict[InstrumentedAttribute | Callable, Any] | None = ..., + filter_expressions: dict[ + InstrumentedAttribute | Callable | ColumnElement, Any + ] | None = ..., nullable_filter_expressions: ( dict[InstrumentedAttribute | Callable, Any] | None ) = ..., @@ -540,7 +544,9 @@ async def paginated_list( self, session: AsyncSession, order_by: InstrumentedAttribute | UnaryExpression | None = ..., - filter_expressions: dict[InstrumentedAttribute | Callable, Any] | None = ..., + filter_expressions: dict[ + InstrumentedAttribute | Callable | ColumnElement, Any + ] | None = ..., nullable_filter_expressions: ( dict[InstrumentedAttribute | Callable, Any] | None ) = ..., @@ -555,7 +561,9 @@ async def paginated_list( self, session: AsyncSession, order_by: InstrumentedAttribute | UnaryExpression | None = None, - filter_expressions: dict[InstrumentedAttribute | Callable, Any] | None = None, + filter_expressions: dict[ + InstrumentedAttribute | Callable | ColumnElement, Any + ] | None = None, nullable_filter_expressions: ( dict[InstrumentedAttribute | Callable, Any] | None ) = None, @@ -618,7 +626,10 @@ async def paginated_list( ) for filter_expression, value in filter_expressions.items(): - if isinstance(filter_expression, InstrumentedAttribute | Function): + if isinstance( + filter_expression, + InstrumentedAttribute | Function | BinaryExpression | ColumnElement + ): stmt = stmt.filter(filter_expression == value) else: stmt = stmt.filter(filter_expression(value)) @@ -716,7 +727,9 @@ async def list( self, session: AsyncSession, order_by: InstrumentedAttribute | UnaryExpression | None = ..., - filter_expressions: dict[InstrumentedAttribute | Callable, Any] | None = ..., + filter_expressions: dict[ + InstrumentedAttribute | Callable | ColumnElement, Any + ] | None = ..., nullable_filter_expressions: ( dict[InstrumentedAttribute | Callable, Any] | None ) = ..., @@ -735,7 +748,9 @@ async def list( self, session: AsyncSession, order_by: InstrumentedAttribute | UnaryExpression | None = ..., - filter_expressions: dict[InstrumentedAttribute | Callable, Any] | None = ..., + filter_expressions: dict[ + InstrumentedAttribute | Callable | ColumnElement, Any + ] | None = ..., nullable_filter_expressions: ( dict[InstrumentedAttribute | Callable, Any] | None ) = ..., @@ -753,7 +768,9 @@ async def list( self, session: AsyncSession, order_by: InstrumentedAttribute | UnaryExpression | None = None, - filter_expressions: dict[InstrumentedAttribute | Callable, Any] | None = None, + filter_expressions: dict[ + InstrumentedAttribute | Callable | ColumnElement, Any + ] | None = None, nullable_filter_expressions: ( dict[InstrumentedAttribute | Callable, Any] | None ) = None, @@ -830,7 +847,10 @@ async def list( ) for filter_expression, value in filter_expressions.items(): - if isinstance(filter_expression, InstrumentedAttribute): + if isinstance( + filter_expression, + InstrumentedAttribute | BinaryExpression | ColumnElement + ): stmt = stmt.filter(filter_expression == value) else: stmt = stmt.filter(filter_expression(value)) @@ -1075,7 +1095,9 @@ def get_select(self, base_stmt: Select | None = None, **_kwargs: Any) -> Select: def get_joins( self, base_query: Select, - filter_expressions: dict[InstrumentedAttribute | Callable, Any], + filter_expressions: dict[ + InstrumentedAttribute | Callable | ColumnElement, Any + ], options: List[Any] | None = None, order_by: InstrumentedAttribute | UnaryExpression | None = None, ) -> Select: @@ -1109,7 +1131,7 @@ def get_joins( elif isinstance(filter_expression, Function): model = filter_expression.entity_namespace else: - model = filter_expression.__self__.parent._identity_class + model = _entity_namespace(filter_expression) if model != self.model: models_to_join.add(model) for model in models_to_join: @@ -1157,7 +1179,9 @@ def remove_optional_filter_bys( @staticmethod def handle_filter_expressions( - filter_expressions: dict[InstrumentedAttribute | Callable, Any], + filter_expressions: dict[ + InstrumentedAttribute | Callable | ColumnElement, Any + ], ) -> None: for filter_expression, value in filter_expressions.copy().items(): if value is None: @@ -1169,7 +1193,9 @@ def handle_filter_expressions( @staticmethod def handle_nullable_filter_expressions( - nullable_filter_expressions: dict[InstrumentedAttribute | Callable, Any], + nullable_filter_expressions: dict[ + InstrumentedAttribute | Callable | ColumnElement, Any + ], ) -> None: for filter_expression, value in nullable_filter_expressions.copy().items(): if value in null_query_values: