diff --git a/server/mergin/sync/public_api_v2_controller.py b/server/mergin/sync/public_api_v2_controller.py index 3e28aa40..2ad68b44 100644 --- a/server/mergin/sync/public_api_v2_controller.py +++ b/server/mergin/sync/public_api_v2_controller.py @@ -49,7 +49,7 @@ from .storages.disk import move_to_tmp, save_to_file from .utils import get_device_id, get_ip, get_user_agent, get_chunk_location from .workspace import WorkspaceRole -from ..utils import parse_order_params +from ..utils import parse_order_params, get_schema_fields_map @auth_required @@ -437,7 +437,9 @@ def list_workspace_projects(workspace_id, page, per_page, order_params=None, q=N projects = projects.filter(Project.name.ilike(f"%{q}%")) if order_params: - order_by_params = parse_order_params(Project, order_params) + order_by_params = parse_order_params( + Project, order_params, field_map=ProjectSchemaV2.field_map + ) projects = projects.order_by(*order_by_params) result = projects.paginate(page, per_page).items diff --git a/server/mergin/sync/schemas_v2.py b/server/mergin/sync/schemas_v2.py index d6b781ee..3d2ce9af 100644 --- a/server/mergin/sync/schemas_v2.py +++ b/server/mergin/sync/schemas_v2.py @@ -11,6 +11,7 @@ Project, ProjectVersion, ) +from ..utils import get_schema_fields_map class ProjectSchema(ma.SQLAlchemyAutoSchema): @@ -46,3 +47,6 @@ class Meta: "workspace", "role", ) + + +ProjectSchema.field_map = get_schema_fields_map(ProjectSchema) diff --git a/server/mergin/tests/test_public_api_v2.py b/server/mergin/tests/test_public_api_v2.py index e058c589..f3131cae 100644 --- a/server/mergin/tests/test_public_api_v2.py +++ b/server/mergin/tests/test_public_api_v2.py @@ -643,6 +643,13 @@ def test_list_workspace_projects(client): url + f"?page={page}&per_page={per_page}&q=1&order_params=created DESC" ) assert response.json["projects"][0]["name"] == "project_10" + # using field name instead column names for sorting + p4 = Project.query.filter(Project.name == project_name).first() + p4.disk_usage = 1234567 + db.session.commit() + response = client.get(url + f"?page=1&per_page=10&order_params=size DESC") + resp_data = json.loads(response.data) + assert resp_data["projects"][0]["name"] == project_name # no permissions to workspace user2 = add_user("user", "password") diff --git a/server/mergin/utils.py b/server/mergin/utils.py index 9acc6124..39609abd 100644 --- a/server/mergin/utils.py +++ b/server/mergin/utils.py @@ -1,13 +1,16 @@ # Copyright (C) Lutra Consulting Limited # # SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-MerginMaps-Commercial +import logging + import math from collections import namedtuple from datetime import datetime, timedelta, timezone from enum import Enum import os -from flask import current_app +from flask import current_app, abort from flask_sqlalchemy import Model +from marshmallow import Schema from pathvalidate import sanitize_filename from sqlalchemy import Column, JSON from sqlalchemy.sql.elements import UnaryExpression @@ -33,7 +36,7 @@ def split_order_param(order_param: str) -> Optional[OrderParam]: def get_order_param( - cls: Model, order_param: OrderParam, json_sort: dict = None + cls: Model, order_param: OrderParam, json_sort: dict = None, field_map: dict = None ) -> Optional[UnaryExpression]: """Return order by clause parameter for SQL query @@ -43,15 +46,22 @@ def get_order_param( :type order_param: OrderParam :param json_sort: type mapping for sort by json field, e.g. '{"storage": "int"}', defaults to None :type json_sort: dict + :param field_map: mapping for translating public field names to internal DB columns, e.g. '{"size": "disk_usage"}' + :type field_map: dict """ + # translate field name to column name + db_column_name = order_param.name + if field_map and order_param.name in field_map: + db_column_name = field_map[order_param.name] # find candidate for nested json sort - if "." in order_param.name: - col, attr = order_param.name.split(".") + if "." in db_column_name: + col, attr = db_column_name.split(".") else: - col = order_param.name + col = db_column_name attr = None order_attr = cls.__table__.c.get(col, None) if not isinstance(order_attr, Column): + logging.warning("Ignoring invalid order parameter.") return # sort by key in JSON field if attr: @@ -80,7 +90,9 @@ def get_order_param( return order_attr.desc() -def parse_order_params(cls: Model, order_params: str, json_sort: dict = None): +def parse_order_params( + cls: Model, order_params: str, json_sort: dict = None, field_map: dict = None +) -> list[UnaryExpression]: """Convert order parameters in query string to list of order by clauses. :param cls: Db model class @@ -89,6 +101,8 @@ def parse_order_params(cls: Model, order_params: str, json_sort: dict = None): :type order_params: str :param json_sort: type mapping for sort by json field, e.g. '{"storage": "int"}', defaults to None :type json_sort: dict + :param field_map: mapping response fields to database column names, e.g. '{"size": "disk_usage"}' + :type field_map: dict :rtype: List[Column] """ @@ -97,7 +111,7 @@ def parse_order_params(cls: Model, order_params: str, json_sort: dict = None): order_param = split_order_param(p) if not order_param: continue - order_attr = get_order_param(cls, order_param, json_sort) + order_attr = get_order_param(cls, order_param, json_sort, field_map) if order_attr is not None: order_by_params.append(order_attr) return order_by_params @@ -135,3 +149,16 @@ def save_diagnostic_log_file(app: str, username: str, body: bytes) -> str: f.write(content) return file_name + + +def get_schema_fields_map(schema: Schema) -> dict: + """ + Creates a mapping of schema field names to corresponding DB columns. + This allows sorting by the API field name (e.g. 'size') while + actually sorting by the database column (e.g. 'disk_usage'). + """ + mapping = {} + for name, field in schema._declared_fields.items(): + if field and field.attribute: + mapping[name] = field.attribute + return mapping