diff --git a/backend/api/decorators/__init__.py b/backend/api/decorators/__init__.py index 5c8500d2..ef535125 100644 --- a/backend/api/decorators/__init__.py +++ b/backend/api/decorators/__init__.py @@ -1,3 +1,20 @@ +from api.decorators.filter_sort import ( + apply_filtering, + apply_sorting, + filtered_and_sorted_query, +) +from api.decorators.full_text_search import ( + apply_full_text_search, + full_text_search_query, +) from api.decorators.pagination import apply_pagination, paginated_query -__all__ = ["apply_pagination", "paginated_query"] +__all__ = [ + "apply_pagination", + "paginated_query", + "apply_sorting", + "apply_filtering", + "filtered_and_sorted_query", + "apply_full_text_search", + "full_text_search_query", +] diff --git a/backend/api/decorators/filter_sort.py b/backend/api/decorators/filter_sort.py new file mode 100644 index 00000000..c38cacb0 --- /dev/null +++ b/backend/api/decorators/filter_sort.py @@ -0,0 +1,694 @@ +from datetime import date as date_type +from functools import wraps +from typing import Any, Callable, TypeVar + +import strawberry +from api.decorators.pagination import apply_pagination +from api.inputs import ( + ColumnType, + FilterInput, + FilterOperator, + PaginationInput, + SortDirection, + SortInput, +) +from database import models +from database.models.base import Base +from sqlalchemy import Select, and_, func, or_, select +from sqlalchemy.orm import aliased + +T = TypeVar("T") + + +def detect_entity_type(model_class: type[Base]) -> str | None: + if model_class == models.Patient: + return "patient" + if model_class == models.Task: + return "task" + return None + + +def get_property_value_column(field_type: str) -> str: + field_type_mapping = { + "FIELD_TYPE_TEXT": "text_value", + "FIELD_TYPE_NUMBER": "number_value", + "FIELD_TYPE_CHECKBOX": "boolean_value", + "FIELD_TYPE_DATE": "date_value", + "FIELD_TYPE_DATE_TIME": "date_time_value", + "FIELD_TYPE_SELECT": "select_value", + "FIELD_TYPE_MULTI_SELECT": "multi_select_values", + } + return field_type_mapping.get(field_type, "text_value") + + +def get_property_join_alias( + query: Select[Any], + model_class: type[Base], + property_definition_id: str, + field_type: str, +) -> Any: + entity_type = detect_entity_type(model_class) + if not entity_type: + raise ValueError( + f"Unsupported entity type for property filtering: {model_class}" + ) + + property_alias = aliased(models.PropertyValue) + value_column = get_property_value_column(field_type) + + if entity_type == "patient": + join_condition = and_( + property_alias.patient_id == model_class.id, + property_alias.definition_id == property_definition_id, + ) + else: + join_condition = and_( + property_alias.task_id == model_class.id, + property_alias.definition_id == property_definition_id, + ) + + query = query.outerjoin(property_alias, join_condition) + return query, property_alias, getattr(property_alias, value_column) + + +def apply_sorting( + query: Select[Any], + sorting: list[SortInput] | None, + model_class: type[Base], + property_field_types: dict[str, str] | None = None, +) -> Select[Any]: + if not sorting: + return query + + order_by_clauses = [] + property_field_types = property_field_types or {} + + for sort_input in sorting: + if sort_input.column_type == ColumnType.DIRECT_ATTRIBUTE: + try: + column = getattr(model_class, sort_input.column) + if sort_input.direction == SortDirection.DESC: + order_by_clauses.append(column.desc()) + else: + order_by_clauses.append(column.asc()) + except AttributeError: + continue + + elif sort_input.column_type == ColumnType.PROPERTY: + if not sort_input.property_definition_id: + continue + + field_type = property_field_types.get( + sort_input.property_definition_id, + "FIELD_TYPE_TEXT" + ) + query, property_alias, value_column = ( + get_property_join_alias( + query, + model_class, + sort_input.property_definition_id, + field_type, + ) + ) + + if sort_input.direction == SortDirection.DESC: + order_by_clauses.append(value_column.desc().nulls_last()) + else: + order_by_clauses.append(value_column.asc().nulls_first()) + + if order_by_clauses: + query = query.order_by(*order_by_clauses) + + return query + + +def apply_text_filter( + column: Any, operator: FilterOperator, parameter: Any +) -> Any: + search_text = parameter.search_text + if search_text is None: + return None + + is_case_sensitive = parameter.is_case_sensitive + + if is_case_sensitive: + if operator == FilterOperator.TEXT_EQUALS: + return column.like(search_text) + if operator == FilterOperator.TEXT_NOT_EQUALS: + return ~column.like(search_text) + if operator == FilterOperator.TEXT_NOT_WHITESPACE: + return func.trim(column) != "" + if operator == FilterOperator.TEXT_CONTAINS: + return column.like(f"%{search_text}%") + if operator == FilterOperator.TEXT_NOT_CONTAINS: + return ~column.like(f"%{search_text}%") + if operator == FilterOperator.TEXT_STARTS_WITH: + return column.like(f"{search_text}%") + if operator == FilterOperator.TEXT_ENDS_WITH: + return column.like(f"%{search_text}") + else: + if operator == FilterOperator.TEXT_EQUALS: + return column.ilike(search_text) + if operator == FilterOperator.TEXT_NOT_EQUALS: + return ~column.ilike(search_text) + if operator == FilterOperator.TEXT_NOT_WHITESPACE: + return func.trim(column) != "" + if operator == FilterOperator.TEXT_CONTAINS: + return column.ilike(f"%{search_text}%") + if operator == FilterOperator.TEXT_NOT_CONTAINS: + return ~column.ilike(f"%{search_text}%") + if operator == FilterOperator.TEXT_STARTS_WITH: + return column.ilike(f"{search_text}%") + if operator == FilterOperator.TEXT_ENDS_WITH: + return column.ilike(f"%{search_text}") + + return None + + +def apply_number_filter( + column: Any, operator: FilterOperator, parameter: Any +) -> Any: + compare_value = parameter.compare_value + min_value = parameter.min + max_value = parameter.max + + if operator == FilterOperator.NUMBER_EQUALS: + if compare_value is not None: + return column == compare_value + elif operator == FilterOperator.NUMBER_NOT_EQUALS: + if compare_value is not None: + return column != compare_value + elif operator == FilterOperator.NUMBER_GREATER_THAN: + if compare_value is not None: + return column > compare_value + elif operator == FilterOperator.NUMBER_GREATER_THAN_OR_EQUAL: + if compare_value is not None: + return column >= compare_value + elif operator == FilterOperator.NUMBER_LESS_THAN: + if compare_value is not None: + return column < compare_value + elif operator == FilterOperator.NUMBER_LESS_THAN_OR_EQUAL: + if compare_value is not None: + return column <= compare_value + elif operator == FilterOperator.NUMBER_BETWEEN: + if min_value is not None and max_value is not None: + return column.between(min_value, max_value) + elif operator == FilterOperator.NUMBER_NOT_BETWEEN: + if min_value is not None and max_value is not None: + return ~column.between(min_value, max_value) + + return None + + +def normalize_date_for_comparison(date_value: Any) -> Any: + return date_value + + +def apply_date_filter( + column: Any, operator: FilterOperator, parameter: Any +) -> Any: + compare_date = parameter.compare_date + min_date = parameter.min_date + max_date = parameter.max_date + + if operator == FilterOperator.DATE_EQUALS: + if compare_date is not None: + if isinstance(compare_date, date_type): + return func.date(column) == compare_date + return column == compare_date + elif operator == FilterOperator.DATE_NOT_EQUALS: + if compare_date is not None: + if isinstance(compare_date, date_type): + return func.date(column) != compare_date + return column != compare_date + elif operator == FilterOperator.DATE_GREATER_THAN: + if compare_date is not None: + if isinstance(compare_date, date_type): + return func.date(column) > compare_date + return column > compare_date + elif operator == FilterOperator.DATE_GREATER_THAN_OR_EQUAL: + if compare_date is not None: + if isinstance(compare_date, date_type): + return func.date(column) >= compare_date + return column >= compare_date + elif operator == FilterOperator.DATE_LESS_THAN: + if compare_date is not None: + if isinstance(compare_date, date_type): + return func.date(column) < compare_date + return column < compare_date + elif operator == FilterOperator.DATE_LESS_THAN_OR_EQUAL: + if compare_date is not None: + if isinstance(compare_date, date_type): + return func.date(column) <= compare_date + return column <= compare_date + elif operator == FilterOperator.DATE_BETWEEN: + if min_date is not None and max_date is not None: + if isinstance(min_date, date_type) and isinstance(max_date, date_type): + return func.date(column).between(min_date, max_date) + return column.between(min_date, max_date) + elif operator == FilterOperator.DATE_NOT_BETWEEN: + if min_date is not None and max_date is not None: + if isinstance(min_date, date_type) and isinstance(max_date, date_type): + return ~func.date(column).between(min_date, max_date) + return ~column.between(min_date, max_date) + + return None + + +def apply_datetime_filter( + column: Any, operator: FilterOperator, parameter: Any +) -> Any: + compare_date_time = parameter.compare_date_time + min_date_time = parameter.min_date_time + max_date_time = parameter.max_date_time + + if operator == FilterOperator.DATETIME_EQUALS: + if compare_date_time is not None: + return column == compare_date_time + elif operator == FilterOperator.DATETIME_NOT_EQUALS: + if compare_date_time is not None: + return column != compare_date_time + elif operator == FilterOperator.DATETIME_GREATER_THAN: + if compare_date_time is not None: + return column > compare_date_time + elif operator == FilterOperator.DATETIME_GREATER_THAN_OR_EQUAL: + if compare_date_time is not None: + return column >= compare_date_time + elif operator == FilterOperator.DATETIME_LESS_THAN: + if compare_date_time is not None: + return column < compare_date_time + elif operator == FilterOperator.DATETIME_LESS_THAN_OR_EQUAL: + if compare_date_time is not None: + return column <= compare_date_time + elif operator == FilterOperator.DATETIME_BETWEEN: + if min_date_time is not None and max_date_time is not None: + return column.between(min_date_time, max_date_time) + elif operator == FilterOperator.DATETIME_NOT_BETWEEN: + if min_date_time is not None and max_date_time is not None: + return ~column.between(min_date_time, max_date_time) + + return None + + +def apply_boolean_filter( + column: Any, operator: FilterOperator, parameter: Any +) -> Any: + if operator == FilterOperator.BOOLEAN_IS_TRUE: + return column.is_(True) + if operator == FilterOperator.BOOLEAN_IS_FALSE: + return column.is_(False) + return None + + +def apply_tags_filter(column: Any, operator: FilterOperator, parameter: Any) -> Any: + search_tags = parameter.search_tags + if not search_tags: + return None + + if operator == FilterOperator.TAGS_EQUALS: + tags_str = ",".join(sorted(search_tags)) + return column == tags_str + if operator == FilterOperator.TAGS_NOT_EQUALS: + tags_str = ",".join(sorted(search_tags)) + return column != tags_str + if operator == FilterOperator.TAGS_CONTAINS: + conditions = [] + for tag in search_tags: + conditions.append(column.contains(tag)) + return and_(*conditions) + if operator == FilterOperator.TAGS_NOT_CONTAINS: + conditions = [] + for tag in search_tags: + conditions.append(~column.contains(tag)) + return or_(*conditions) + + return None + + +def apply_tags_single_filter( + column: Any, operator: FilterOperator, parameter: Any +) -> Any: + search_tags = parameter.search_tags + if not search_tags: + return None + + if operator == FilterOperator.TAGS_SINGLE_EQUALS: + if len(search_tags) == 1: + return column == search_tags[0] + if operator == FilterOperator.TAGS_SINGLE_NOT_EQUALS: + if len(search_tags) == 1: + return column != search_tags[0] + if operator == FilterOperator.TAGS_SINGLE_CONTAINS: + conditions = [] + for tag in search_tags: + conditions.append(column == tag) + return or_(*conditions) + if operator == FilterOperator.TAGS_SINGLE_NOT_CONTAINS: + conditions = [] + for tag in search_tags: + conditions.append(column != tag) + return and_(*conditions) + + return None + + +def apply_null_filter( + column: Any, operator: FilterOperator, parameter: Any +) -> Any: + if operator == FilterOperator.IS_NULL: + return column.is_(None) + if operator == FilterOperator.IS_NOT_NULL: + return column.isnot(None) + return None + + +def apply_filtering( + query: Select[Any], + filtering: list[FilterInput] | None, + model_class: type[Base], + property_field_types: dict[str, str] | None = None, +) -> Select[Any]: + if not filtering: + return query + + filter_conditions = [] + property_field_types = property_field_types or {} + + for filter_input in filtering: + condition = None + + if filter_input.column_type == ColumnType.DIRECT_ATTRIBUTE: + try: + column = getattr(model_class, filter_input.column) + except AttributeError: + continue + + operator = filter_input.operator + parameter = filter_input.parameter + + if operator in [ + FilterOperator.TEXT_EQUALS, + FilterOperator.TEXT_NOT_EQUALS, + FilterOperator.TEXT_NOT_WHITESPACE, + FilterOperator.TEXT_CONTAINS, + FilterOperator.TEXT_NOT_CONTAINS, + FilterOperator.TEXT_STARTS_WITH, + FilterOperator.TEXT_ENDS_WITH, + ]: + condition = apply_text_filter(column, operator, parameter) + + elif operator in [ + FilterOperator.NUMBER_EQUALS, + FilterOperator.NUMBER_NOT_EQUALS, + FilterOperator.NUMBER_GREATER_THAN, + FilterOperator.NUMBER_GREATER_THAN_OR_EQUAL, + FilterOperator.NUMBER_LESS_THAN, + FilterOperator.NUMBER_LESS_THAN_OR_EQUAL, + FilterOperator.NUMBER_BETWEEN, + FilterOperator.NUMBER_NOT_BETWEEN, + ]: + condition = apply_number_filter(column, operator, parameter) + + elif operator in [ + FilterOperator.DATE_EQUALS, + FilterOperator.DATE_NOT_EQUALS, + FilterOperator.DATE_GREATER_THAN, + FilterOperator.DATE_GREATER_THAN_OR_EQUAL, + FilterOperator.DATE_LESS_THAN, + FilterOperator.DATE_LESS_THAN_OR_EQUAL, + FilterOperator.DATE_BETWEEN, + FilterOperator.DATE_NOT_BETWEEN, + ]: + condition = apply_date_filter(column, operator, parameter) + + elif operator in [ + FilterOperator.DATETIME_EQUALS, + FilterOperator.DATETIME_NOT_EQUALS, + FilterOperator.DATETIME_GREATER_THAN, + FilterOperator.DATETIME_GREATER_THAN_OR_EQUAL, + FilterOperator.DATETIME_LESS_THAN, + FilterOperator.DATETIME_LESS_THAN_OR_EQUAL, + FilterOperator.DATETIME_BETWEEN, + FilterOperator.DATETIME_NOT_BETWEEN, + ]: + condition = apply_datetime_filter(column, operator, parameter) + + elif operator in [ + FilterOperator.BOOLEAN_IS_TRUE, + FilterOperator.BOOLEAN_IS_FALSE, + ]: + condition = apply_boolean_filter(column, operator, parameter) + + elif operator in [ + FilterOperator.IS_NULL, + FilterOperator.IS_NOT_NULL, + ]: + condition = apply_null_filter(column, operator, parameter) + + elif filter_input.column_type == ColumnType.PROPERTY: + if not filter_input.property_definition_id: + continue + + field_type = property_field_types.get( + filter_input.property_definition_id, + "FIELD_TYPE_TEXT" + ) + query, property_alias, value_column = get_property_join_alias( + query, model_class, filter_input.property_definition_id, field_type + ) + + operator = filter_input.operator + parameter = filter_input.parameter + + if operator in [ + FilterOperator.TEXT_EQUALS, + FilterOperator.TEXT_NOT_EQUALS, + FilterOperator.TEXT_NOT_WHITESPACE, + FilterOperator.TEXT_CONTAINS, + FilterOperator.TEXT_NOT_CONTAINS, + FilterOperator.TEXT_STARTS_WITH, + FilterOperator.TEXT_ENDS_WITH, + ]: + condition = apply_text_filter( + value_column, operator, parameter + ) + + elif operator in [ + FilterOperator.NUMBER_EQUALS, + FilterOperator.NUMBER_NOT_EQUALS, + FilterOperator.NUMBER_GREATER_THAN, + FilterOperator.NUMBER_GREATER_THAN_OR_EQUAL, + FilterOperator.NUMBER_LESS_THAN, + FilterOperator.NUMBER_LESS_THAN_OR_EQUAL, + FilterOperator.NUMBER_BETWEEN, + FilterOperator.NUMBER_NOT_BETWEEN, + ]: + condition = apply_number_filter(value_column, operator, parameter) + + elif operator in [ + FilterOperator.DATE_EQUALS, + FilterOperator.DATE_NOT_EQUALS, + FilterOperator.DATE_GREATER_THAN, + FilterOperator.DATE_GREATER_THAN_OR_EQUAL, + FilterOperator.DATE_LESS_THAN, + FilterOperator.DATE_LESS_THAN_OR_EQUAL, + FilterOperator.DATE_BETWEEN, + FilterOperator.DATE_NOT_BETWEEN, + ]: + condition = apply_date_filter( + value_column, operator, parameter + ) + + elif operator in [ + FilterOperator.DATETIME_EQUALS, + FilterOperator.DATETIME_NOT_EQUALS, + FilterOperator.DATETIME_GREATER_THAN, + FilterOperator.DATETIME_GREATER_THAN_OR_EQUAL, + FilterOperator.DATETIME_LESS_THAN, + FilterOperator.DATETIME_LESS_THAN_OR_EQUAL, + FilterOperator.DATETIME_BETWEEN, + FilterOperator.DATETIME_NOT_BETWEEN, + ]: + condition = apply_datetime_filter(value_column, operator, parameter) + + elif operator in [ + FilterOperator.BOOLEAN_IS_TRUE, + FilterOperator.BOOLEAN_IS_FALSE, + ]: + condition = apply_boolean_filter( + value_column, operator, parameter + ) + + elif operator in [ + FilterOperator.TAGS_EQUALS, + FilterOperator.TAGS_NOT_EQUALS, + FilterOperator.TAGS_CONTAINS, + FilterOperator.TAGS_NOT_CONTAINS, + ]: + condition = apply_tags_filter( + value_column, operator, parameter + ) + + elif operator in [ + FilterOperator.TAGS_SINGLE_EQUALS, + FilterOperator.TAGS_SINGLE_NOT_EQUALS, + FilterOperator.TAGS_SINGLE_CONTAINS, + FilterOperator.TAGS_SINGLE_NOT_CONTAINS, + ]: + condition = apply_tags_single_filter(value_column, operator, parameter) + + elif operator in [ + FilterOperator.IS_NULL, + FilterOperator.IS_NOT_NULL, + ]: + condition = apply_null_filter( + value_column, operator, parameter + ) + + if condition is not None: + filter_conditions.append(condition) + + if filter_conditions: + query = query.where(and_(*filter_conditions)) + + return query + + +def filtered_and_sorted_query( + filtering_param: str = "filtering", + sorting_param: str = "sorting", + pagination_param: str = "pagination", +): + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + filtering: list[FilterInput] | None = kwargs.get(filtering_param) + sorting: list[SortInput] | None = kwargs.get(sorting_param) + pagination: PaginationInput | None = kwargs.get(pagination_param) + + result = await func(*args, **kwargs) + + if not isinstance(result, Select): + return result + + model_class = result.column_descriptions[0]["entity"] + if not model_class: + if isinstance(result, Select): + for arg in args: + if ( + hasattr(arg, "context") + and hasattr(arg.context, "db") + ): + db = arg.context.db + query_result = await db.execute(result) + return query_result.scalars().all() + else: + info = kwargs.get("info") + if ( + info + and hasattr(info, "context") + and hasattr(info.context, "db") + ): + db = info.context.db + query_result = await db.execute(result) + return query_result.scalars().all() + return result + + property_field_types: dict[str, str] = {} + + if filtering or sorting: + property_def_ids = set() + if filtering: + for f in filtering: + if ( + f.column_type == ColumnType.PROPERTY + and f.property_definition_id + ): + property_def_ids.add(f.property_definition_id) + if sorting: + for s in sorting: + if ( + s.column_type == ColumnType.PROPERTY + and s.property_definition_id + ): + property_def_ids.add(s.property_definition_id) + + if property_def_ids: + for arg in args: + if ( + hasattr(arg, "context") + and hasattr(arg.context, "db") + ): + db = arg.context.db + prop_defs_result = await db.execute( + select(models.PropertyDefinition).where( + models.PropertyDefinition.id.in_(property_def_ids) + ) + ) + prop_defs = prop_defs_result.scalars().all() + property_field_types = { + str(prop_def.id): prop_def.field_type + for prop_def in prop_defs + } + break + else: + info = kwargs.get("info") + if ( + info + and hasattr(info, "context") + and hasattr(info.context, "db") + ): + db = info.context.db + prop_defs_result = await db.execute( + select(models.PropertyDefinition).where( + models.PropertyDefinition.id.in_(property_def_ids) + ) + ) + prop_defs = prop_defs_result.scalars().all() + property_field_types = { + str(prop_def.id): prop_def.field_type + for prop_def in prop_defs + } + + if filtering: + result = apply_filtering( + result, filtering, model_class, property_field_types + ) + + if sorting: + result = apply_sorting( + result, sorting, model_class, property_field_types + ) + + if pagination and pagination is not strawberry.UNSET: + page_index = pagination.page_index + page_size = pagination.page_size + if page_size: + offset = page_index * page_size + result = apply_pagination(result, limit=page_size, offset=offset) + + if isinstance(result, Select): + for arg in args: + if ( + hasattr(arg, "context") + and hasattr(arg.context, "db") + ): + db = arg.context.db + query_result = await db.execute(result) + return query_result.scalars().all() + else: + info = kwargs.get("info") + if ( + info + and hasattr(info, "context") + and hasattr(info.context, "db") + ): + db = info.context.db + query_result = await db.execute(result) + return query_result.scalars().all() + + return result + + return wrapper + + return decorator diff --git a/backend/api/decorators/full_text_search.py b/backend/api/decorators/full_text_search.py new file mode 100644 index 00000000..b251996d --- /dev/null +++ b/backend/api/decorators/full_text_search.py @@ -0,0 +1,116 @@ +from functools import wraps +from typing import Any, Callable, TypeVar + +import strawberry +from api.inputs import FullTextSearchInput +from database import models +from database.models.base import Base +from sqlalchemy import Select, String, and_, inspect, or_ +from sqlalchemy.orm import aliased + +T = TypeVar("T") + + +def detect_entity_type(model_class: type[Base]) -> str | None: + if model_class == models.Patient: + return "patient" + if model_class == models.Task: + return "task" + return None + + +def get_text_columns_from_model(model_class: type[Base]) -> list[str]: + mapper = inspect(model_class) + text_columns = [] + for column in mapper.columns: + if isinstance(column.type, String): + text_columns.append(column.key) + return text_columns + + +def apply_full_text_search( + query: Select[Any], + search_input: FullTextSearchInput, + model_class: type[Base], +) -> Select[Any]: + if not search_input.search_text or not search_input.search_text.strip(): + return query + + search_text = search_input.search_text.strip() + search_pattern = f"%{search_text}%" + + search_conditions = [] + + columns_to_search = search_input.search_columns + if columns_to_search is None: + columns_to_search = get_text_columns_from_model(model_class) + + for column_name in columns_to_search: + try: + column = getattr(model_class, column_name) + search_conditions.append(column.ilike(search_pattern)) + except AttributeError: + continue + + if search_input.include_properties: + entity_type = detect_entity_type(model_class) + if entity_type: + property_alias = aliased(models.PropertyValue) + + if entity_type == "patient": + join_condition = property_alias.patient_id == model_class.id + else: + join_condition = property_alias.task_id == model_class.id + + if search_input.property_definition_ids: + property_filter = and_( + property_alias.text_value.ilike(search_pattern), + property_alias.definition_id.in_( + search_input.property_definition_ids + ), + ) + else: + property_filter = ( + property_alias.text_value.ilike(search_pattern) + ) + + query = query.outerjoin(property_alias, join_condition) + search_conditions.append(property_filter) + + if not search_conditions: + return query + + combined_condition = or_(*search_conditions) + query = query.where(combined_condition) + + if search_input.include_properties: + query = query.distinct() + + return query + + +def full_text_search_query(search_param: str = "search"): + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + search_input: FullTextSearchInput | None = kwargs.get(search_param) + + result = await func(*args, **kwargs) + + if not isinstance(result, Select): + return result + + if not search_input or search_input is strawberry.UNSET: + return result + + model_class = result.column_descriptions[0]["entity"] + if not model_class: + return result + + result = apply_full_text_search(result, search_input, model_class) + + return result + + return wrapper + + return decorator diff --git a/backend/api/inputs.py b/backend/api/inputs.py index d5f326a3..1962ae09 100644 --- a/backend/api/inputs.py +++ b/backend/api/inputs.py @@ -170,3 +170,117 @@ class UpdatePropertyDefinitionInput: @strawberry.input class UpdateProfilePictureInput: avatar_url: str + + +@strawberry.enum +class SortDirection(Enum): + ASC = "ASC" + DESC = "DESC" + + +@strawberry.enum +class FilterOperator(Enum): + TEXT_EQUALS = "TEXT_EQUALS" + TEXT_NOT_EQUALS = "TEXT_NOT_EQUALS" + TEXT_NOT_WHITESPACE = "TEXT_NOT_WHITESPACE" + TEXT_CONTAINS = "TEXT_CONTAINS" + TEXT_NOT_CONTAINS = "TEXT_NOT_CONTAINS" + TEXT_STARTS_WITH = "TEXT_STARTS_WITH" + TEXT_ENDS_WITH = "TEXT_ENDS_WITH" + NUMBER_EQUALS = "NUMBER_EQUALS" + NUMBER_NOT_EQUALS = "NUMBER_NOT_EQUALS" + NUMBER_GREATER_THAN = "NUMBER_GREATER_THAN" + NUMBER_GREATER_THAN_OR_EQUAL = "NUMBER_GREATER_THAN_OR_EQUAL" + NUMBER_LESS_THAN = "NUMBER_LESS_THAN" + NUMBER_LESS_THAN_OR_EQUAL = "NUMBER_LESS_THAN_OR_EQUAL" + NUMBER_BETWEEN = "NUMBER_BETWEEN" + NUMBER_NOT_BETWEEN = "NUMBER_NOT_BETWEEN" + DATE_EQUALS = "DATE_EQUALS" + DATE_NOT_EQUALS = "DATE_NOT_EQUALS" + DATE_GREATER_THAN = "DATE_GREATER_THAN" + DATE_GREATER_THAN_OR_EQUAL = "DATE_GREATER_THAN_OR_EQUAL" + DATE_LESS_THAN = "DATE_LESS_THAN" + DATE_LESS_THAN_OR_EQUAL = "DATE_LESS_THAN_OR_EQUAL" + DATE_BETWEEN = "DATE_BETWEEN" + DATE_NOT_BETWEEN = "DATE_NOT_BETWEEN" + DATETIME_EQUALS = "DATETIME_EQUALS" + DATETIME_NOT_EQUALS = "DATETIME_NOT_EQUALS" + DATETIME_GREATER_THAN = "DATETIME_GREATER_THAN" + DATETIME_GREATER_THAN_OR_EQUAL = "DATETIME_GREATER_THAN_OR_EQUAL" + DATETIME_LESS_THAN = "DATETIME_LESS_THAN" + DATETIME_LESS_THAN_OR_EQUAL = "DATETIME_LESS_THAN_OR_EQUAL" + DATETIME_BETWEEN = "DATETIME_BETWEEN" + DATETIME_NOT_BETWEEN = "DATETIME_NOT_BETWEEN" + BOOLEAN_IS_TRUE = "BOOLEAN_IS_TRUE" + BOOLEAN_IS_FALSE = "BOOLEAN_IS_FALSE" + TAGS_EQUALS = "TAGS_EQUALS" + TAGS_NOT_EQUALS = "TAGS_NOT_EQUALS" + TAGS_CONTAINS = "TAGS_CONTAINS" + TAGS_NOT_CONTAINS = "TAGS_NOT_CONTAINS" + TAGS_SINGLE_EQUALS = "TAGS_SINGLE_EQUALS" + TAGS_SINGLE_NOT_EQUALS = "TAGS_SINGLE_NOT_EQUALS" + TAGS_SINGLE_CONTAINS = "TAGS_SINGLE_CONTAINS" + TAGS_SINGLE_NOT_CONTAINS = "TAGS_SINGLE_NOT_CONTAINS" + IS_NULL = "IS_NULL" + IS_NOT_NULL = "IS_NOT_NULL" + + +@strawberry.enum +class ColumnType(Enum): + DIRECT_ATTRIBUTE = "DIRECT_ATTRIBUTE" + PROPERTY = "PROPERTY" + + +@strawberry.input +class FilterParameter: + search_text: str | None = None + is_case_sensitive: bool = False + compare_value: float | None = None + min: float | None = None + max: float | None = None + compare_date: date | None = None + min_date: date | None = None + max_date: date | None = None + compare_date_time: datetime | None = None + min_date_time: datetime | None = None + max_date_time: datetime | None = None + search_tags: list[str] | None = None + property_definition_id: str | None = None + + +@strawberry.input +class SortInput: + column: str + direction: SortDirection + column_type: ColumnType = ColumnType.DIRECT_ATTRIBUTE + property_definition_id: str | None = None + + +@strawberry.input +class FilterInput: + column: str + operator: FilterOperator + parameter: FilterParameter + column_type: ColumnType = ColumnType.DIRECT_ATTRIBUTE + property_definition_id: str | None = None + + +@strawberry.input +class PaginationInput: + page_index: int = 0 + page_size: int | None = None + + +@strawberry.input +class QueryOptionsInput: + sorting: list[SortInput] | None = None + filtering: list[FilterInput] | None = None + pagination: PaginationInput | None = None + + +@strawberry.input +class FullTextSearchInput: + search_text: str + search_columns: list[str] | None = None + include_properties: bool = False + property_definition_ids: list[str] | None = None diff --git a/backend/api/resolvers/patient.py b/backend/api/resolvers/patient.py index c8e9d404..cb018e20 100644 --- a/backend/api/resolvers/patient.py +++ b/backend/api/resolvers/patient.py @@ -3,7 +3,22 @@ import strawberry from api.audit import audit_log from api.context import Info -from api.decorators.pagination import apply_pagination +from api.decorators.filter_sort import ( + apply_filtering, + apply_sorting, + filtered_and_sorted_query, +) +from api.decorators.full_text_search import ( + apply_full_text_search, + full_text_search_query, +) +from api.inputs import ( + ColumnType, + FilterInput, + FullTextSearchInput, + PaginationInput, + SortInput, +) from api.inputs import CreatePatientInput, PatientState, UpdatePatientInput from api.resolvers.base import BaseMutationResolver, BaseSubscriptionResolver from api.services.authorization import AuthorizationService @@ -14,48 +29,20 @@ from api.types.patient import PatientType from database import models from graphql import GraphQLError -from sqlalchemy import desc, func, select +from sqlalchemy import func, select from sqlalchemy.orm import aliased, selectinload +from sqlalchemy.sql import Select @strawberry.type class PatientQuery: - @strawberry.field - async def patient( - self, - info: Info, - id: strawberry.ID, - ) -> PatientType | None: - result = await info.context.db.execute( - select(models.Patient) - .where(models.Patient.id == id) - .where(models.Patient.deleted.is_(False)) - .options( - selectinload(models.Patient.assigned_locations), - selectinload(models.Patient.tasks), - selectinload(models.Patient.teams), - ), - ) - patient = result.scalars().first() - if patient: - auth_service = AuthorizationService(info.context.db) - if not await auth_service.can_access_patient(info.context.user, patient, info.context): - raise GraphQLError( - "Insufficient permission. Please contact an administrator if you believe this is an error.", - extensions={"code": "FORBIDDEN"}, - ) - return patient - - @strawberry.field - async def patients( - self, + @staticmethod + async def _build_patients_base_query( info: Info, location_node_id: strawberry.ID | None = None, root_location_ids: list[strawberry.ID] | None = None, states: list[PatientState] | None = None, - limit: int | None = None, - offset: int | None = None, - ) -> list[PatientType]: + ) -> tuple[Select, list[strawberry.ID]]: query = select(models.Patient).options( selectinload(models.Patient.assigned_locations), selectinload(models.Patient.tasks), @@ -70,12 +57,14 @@ async def patients( models.Patient.state == PatientState.ADMITTED.value ) auth_service = AuthorizationService(info.context.db) - accessible_location_ids = await auth_service.get_user_accessible_location_ids( - info.context.user, info.context + accessible_location_ids = ( + await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) ) if not accessible_location_ids: - return [] + return query.where(False), [] query = auth_service.filter_patients_by_access( info.context.user, query, accessible_location_ids @@ -85,7 +74,10 @@ async def patients( if location_node_id: if location_node_id not in accessible_location_ids: raise GraphQLError( - "Insufficient permission. Please contact an administrator if you believe this is an error.", + ( + "Insufficient permission. Please contact an administrator " + "if you believe this is an error." + ), extensions={"code": "FORBIDDEN"}, ) filter_cte = ( @@ -99,9 +91,11 @@ async def patients( ) filter_cte = filter_cte.union_all(children) elif root_location_ids: - valid_root_location_ids = [lid for lid in root_location_ids if lid in accessible_location_ids] + valid_root_location_ids = [ + lid for lid in root_location_ids if lid in accessible_location_ids + ] if not valid_root_location_ids: - return [] + return query.where(False), [] root_location_ids = valid_root_location_ids filter_cte = ( select(models.LocationNode.id) @@ -130,32 +124,170 @@ async def patients( (models.Patient.clinic_id.in_(select(filter_cte.c.id))) | ( models.Patient.position_id.isnot(None) - & models.Patient.position_id.in_(select(filter_cte.c.id)) + & models.Patient.position_id.in_( + select(filter_cte.c.id) + ) ) | ( models.Patient.assigned_location_id.isnot(None) - & models.Patient.assigned_location_id.in_(select(filter_cte.c.id)) + & models.Patient.assigned_location_id.in_( + select(filter_cte.c.id) + ) + ) + | ( + patient_locations_filter.c.location_id.in_( + select(filter_cte.c.id) + ) + ) + | ( + patient_teams_filter.c.location_id.in_( + select(filter_cte.c.id) + ) ) - | (patient_locations_filter.c.location_id.in_(select(filter_cte.c.id))) - | (patient_teams_filter.c.location_id.in_(select(filter_cte.c.id))) ) .distinct() ) - query = apply_pagination(query, limit=limit, offset=offset) + return query, accessible_location_ids + + @strawberry.field + async def patient( + self, + info: Info, + id: strawberry.ID, + ) -> PatientType | None: + result = await info.context.db.execute( + select(models.Patient) + .where(models.Patient.id == id) + .where(models.Patient.deleted.is_(False)) + .options( + selectinload(models.Patient.assigned_locations), + selectinload(models.Patient.tasks), + selectinload(models.Patient.teams), + ), + ) + patient = result.scalars().first() + if patient: + auth_service = AuthorizationService(info.context.db) + if not await auth_service.can_access_patient( + info.context.user, patient, info.context + ): + raise GraphQLError( + ( + "Insufficient permission. Please contact an administrator " + "if you believe this is an error." + ), + extensions={"code": "FORBIDDEN"}, + ) + return patient + + @strawberry.field + @filtered_and_sorted_query() + @full_text_search_query() + async def patients( + self, + info: Info, + location_node_id: strawberry.ID | None = None, + root_location_ids: list[strawberry.ID] | None = None, + states: list[PatientState] | None = None, + filtering: list[FilterInput] | None = None, + sorting: list[SortInput] | None = None, + pagination: PaginationInput | None = None, + search: FullTextSearchInput | None = None, + ) -> list[PatientType]: + query, _ = await PatientQuery._build_patients_base_query( + info, location_node_id, root_location_ids, states + ) + return query + + @strawberry.field + async def patientsTotal( + self, + info: Info, + location_node_id: strawberry.ID | None = None, + root_location_ids: list[strawberry.ID] | None = None, + states: list[PatientState] | None = None, + filtering: list[FilterInput] | None = None, + sorting: list[SortInput] | None = None, + search: FullTextSearchInput | None = None, + ) -> int: + query, _ = await PatientQuery._build_patients_base_query( + info, location_node_id, root_location_ids, states + ) + + if search and search is not strawberry.UNSET: + query = apply_full_text_search(query, search, models.Patient) + + if filtering: + property_field_types: dict[str, str] = {} + property_def_ids = set() + for f in filtering: + if ( + f.column_type == ColumnType.PROPERTY + and f.property_definition_id + ): + property_def_ids.add(f.property_definition_id) + + if property_def_ids: + prop_defs_result = await info.context.db.execute( + select(models.PropertyDefinition).where( + models.PropertyDefinition.id.in_(property_def_ids) + ) + ) + prop_defs = prop_defs_result.scalars().all() + property_field_types = { + str(prop_def.id): prop_def.field_type + for prop_def in prop_defs + } + + query = apply_filtering( + query, filtering, models.Patient, property_field_types + ) + + if sorting: + property_field_types: dict[str, str] = {} + property_def_ids = set() + for s in sorting: + if ( + s.column_type == ColumnType.PROPERTY + and s.property_definition_id + ): + property_def_ids.add(s.property_definition_id) + + if property_def_ids: + prop_defs_result = await info.context.db.execute( + select(models.PropertyDefinition).where( + models.PropertyDefinition.id.in_(property_def_ids) + ) + ) + prop_defs = prop_defs_result.scalars().all() + property_field_types = { + str(prop_def.id): prop_def.field_type for prop_def in prop_defs + } - result = await info.context.db.execute(query) - return result.scalars().all() + query = apply_sorting(query, sorting, models.Patient, property_field_types) + + subquery = query.subquery() + count_query = select(func.count(func.distinct(subquery.c.id))) + result = await info.context.db.execute(count_query) + return result.scalar() or 0 @strawberry.field + @filtered_and_sorted_query() + @full_text_search_query() async def recent_patients( self, info: Info, - limit: int = 5, + filtering: list[FilterInput] | None = None, + sorting: list[SortInput] | None = None, + pagination: PaginationInput | None = None, + search: FullTextSearchInput | None = None, ) -> list[PatientType]: auth_service = AuthorizationService(info.context.db) - accessible_location_ids = await auth_service.get_user_accessible_location_ids( - info.context.user, info.context + accessible_location_ids = ( + await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) ) if not accessible_location_ids: @@ -171,7 +303,7 @@ async def recent_patients( ) query = ( - select(models.Patient, max_task_update_date.c.max_update_date) + select(models.Patient) .options( selectinload(models.Patient.assigned_locations), selectinload(models.Patient.tasks), @@ -182,14 +314,108 @@ async def recent_patients( models.Patient.id == max_task_update_date.c.patient_id, ) .where(models.Patient.deleted.is_(False)) - .order_by(desc(max_task_update_date.c.max_update_date), desc(models.Patient.id)) - .limit(limit) ) query = auth_service.filter_patients_by_access( info.context.user, query, accessible_location_ids ) - result = await info.context.db.execute(query) - return [row[0] for row in result.all()] + + return query + + @strawberry.field + async def recentPatientsTotal( + self, + info: Info, + filtering: list[FilterInput] | None = None, + sorting: list[SortInput] | None = None, + search: FullTextSearchInput | None = None, + ) -> int: + auth_service = AuthorizationService(info.context.db) + accessible_location_ids = ( + await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + ) + + if not accessible_location_ids: + return 0 + + max_task_update_date = ( + select( + func.max(models.Task.update_date).label("max_update_date"), + models.Task.patient_id.label("patient_id"), + ) + .group_by(models.Task.patient_id) + .subquery() + ) + + query = ( + select(models.Patient) + .outerjoin( + max_task_update_date, + models.Patient.id == max_task_update_date.c.patient_id, + ) + .where(models.Patient.deleted.is_(False)) + ) + query = auth_service.filter_patients_by_access( + info.context.user, query, accessible_location_ids + ) + + if search and search is not strawberry.UNSET: + query = apply_full_text_search(query, search, models.Patient) + + if filtering: + property_field_types: dict[str, str] = {} + property_def_ids = set() + for f in filtering: + if ( + f.column_type == ColumnType.PROPERTY + and f.property_definition_id + ): + property_def_ids.add(f.property_definition_id) + + if property_def_ids: + prop_defs_result = await info.context.db.execute( + select(models.PropertyDefinition).where( + models.PropertyDefinition.id.in_(property_def_ids) + ) + ) + prop_defs = prop_defs_result.scalars().all() + property_field_types = { + str(prop_def.id): prop_def.field_type + for prop_def in prop_defs + } + + query = apply_filtering( + query, filtering, models.Patient, property_field_types + ) + + if sorting: + property_field_types: dict[str, str] = {} + property_def_ids = set() + for s in sorting: + if ( + s.column_type == ColumnType.PROPERTY + and s.property_definition_id + ): + property_def_ids.add(s.property_definition_id) + + if property_def_ids: + prop_defs_result = await info.context.db.execute( + select(models.PropertyDefinition).where( + models.PropertyDefinition.id.in_(property_def_ids) + ) + ) + prop_defs = prop_defs_result.scalars().all() + property_field_types = { + str(prop_def.id): prop_def.field_type for prop_def in prop_defs + } + + query = apply_sorting(query, sorting, models.Patient, property_field_types) + + subquery = query.subquery() + count_query = select(func.count(func.distinct(subquery.c.id))) + result = await info.context.db.execute(count_query) + return result.scalar() or 0 @strawberry.type @@ -216,19 +442,27 @@ async def create_patient( ) auth_service = AuthorizationService(db) - accessible_location_ids = await auth_service.get_user_accessible_location_ids( - info.context.user, info.context + accessible_location_ids = ( + await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) ) if not accessible_location_ids: raise GraphQLError( - "Insufficient permission. Please contact an administrator if you believe this is an error.", + ( + "Insufficient permission. Please contact an " + "administrator if you believe this is an error." + ), extensions={"code": "FORBIDDEN"}, ) if data.clinic_id not in accessible_location_ids: raise GraphQLError( - "Insufficient permission. Please contact an administrator if you believe this is an error.", + ( + "Insufficient permission. Please contact an administrator " + "if you believe this is an error." + ), extensions={"code": "FORBIDDEN"}, ) @@ -247,7 +481,10 @@ async def create_patient( for team_id in data.team_ids: if team_id not in accessible_location_ids: raise GraphQLError( - "Insufficient permission. Please contact an administrator if you believe this is an error.", + ( + "Insufficient permission. Please contact an " + "administrator if you believe this is an error." + ), extensions={"code": "FORBIDDEN"}, ) teams = await location_service.validate_and_get_teams( @@ -281,9 +518,15 @@ async def create_patient( ) new_patient.assigned_locations = locations elif data.assigned_location_id: - if data.assigned_location_id not in accessible_location_ids: + if ( + data.assigned_location_id + not in accessible_location_ids + ): raise GraphQLError( - "Insufficient permission. Please contact an administrator if you believe this is an error.", + ( + "Insufficient permission. Please contact an " + "administrator if you believe this is an error." + ), extensions={"code": "FORBIDDEN"}, ) location = await location_service.get_location_by_id( @@ -327,7 +570,9 @@ async def update_patient( raise Exception("Patient not found") auth_service = AuthorizationService(db) - if not await auth_service.can_access_patient(info.context.user, patient, info.context): + if not await auth_service.can_access_patient( + info.context.user, patient, info.context + ): raise GraphQLError( "Forbidden: You do not have access to this patient", extensions={"code": "FORBIDDEN"}, @@ -348,8 +593,10 @@ async def update_patient( patient.description = data.description location_service = PatientMutation._get_location_service(db) - accessible_location_ids = await auth_service.get_user_accessible_location_ids( - info.context.user + accessible_location_ids = ( + await auth_service.get_user_accessible_location_ids( + info.context.user + ) ) if data.clinic_id is not None: @@ -367,7 +614,10 @@ async def update_patient( else: if data.position_id not in accessible_location_ids: raise GraphQLError( - "Insufficient permission. Please contact an administrator if you believe this is an error.", + ( + "Insufficient permission. Please contact an " + "administrator if you believe this is an error." + ), extensions={"code": "FORBIDDEN"}, ) await location_service.validate_and_get_position( @@ -401,9 +651,15 @@ async def update_patient( ) patient.assigned_locations = locations elif data.assigned_location_id is not None: - if data.assigned_location_id not in accessible_location_ids: + if ( + data.assigned_location_id + not in accessible_location_ids + ): raise GraphQLError( - "Insufficient permission. Please contact an administrator if you believe this is an error.", + ( + "Insufficient permission. Please contact an " + "administrator if you believe this is an error." + ), extensions={"code": "FORBIDDEN"}, ) location = await location_service.get_location_by_id( @@ -441,7 +697,9 @@ async def delete_patient(self, info: Info, id: strawberry.ID) -> bool: return False auth_service = AuthorizationService(db) - if not await auth_service.can_access_patient(info.context.user, patient, info.context): + if not await auth_service.can_access_patient( + info.context.user, patient, info.context + ): raise GraphQLError( "Forbidden: You do not have access to this patient", extensions={"code": "FORBIDDEN"}, @@ -475,7 +733,9 @@ async def _update_patient_state( raise Exception("Patient not found") auth_service = AuthorizationService(db) - if not await auth_service.can_access_patient(info.context.user, patient, info.context): + if not await auth_service.can_access_patient( + info.context.user, patient, info.context + ): raise GraphQLError( "Forbidden: You do not have access to this patient", extensions={"code": "FORBIDDEN"}, diff --git a/backend/api/resolvers/task.py b/backend/api/resolvers/task.py index 03189a48..b2bb58f1 100644 --- a/backend/api/resolvers/task.py +++ b/backend/api/resolvers/task.py @@ -3,7 +3,22 @@ import strawberry from api.audit import audit_log from api.context import Info -from api.decorators.pagination import apply_pagination +from api.decorators.filter_sort import ( + apply_filtering, + apply_sorting, + filtered_and_sorted_query, +) +from api.decorators.full_text_search import ( + apply_full_text_search, + full_text_search_query, +) +from api.inputs import ( + ColumnType, + FilterInput, + FullTextSearchInput, + PaginationInput, + SortInput, +) from api.inputs import CreateTaskInput, UpdateTaskInput from api.resolvers.base import BaseMutationResolver, BaseSubscriptionResolver from api.services.authorization import AuthorizationService @@ -13,7 +28,7 @@ from api.types.task import TaskType from database import models from graphql import GraphQLError -from sqlalchemy import desc, select +from sqlalchemy import desc, func, select from sqlalchemy.orm import aliased, selectinload @@ -37,6 +52,8 @@ async def task(self, info: Info, id: strawberry.ID) -> TaskType | None: return task @strawberry.field + @filtered_and_sorted_query() + @full_text_search_query() async def tasks( self, info: Info, @@ -44,8 +61,10 @@ async def tasks( assignee_id: strawberry.ID | None = None, assignee_team_id: strawberry.ID | None = None, root_location_ids: list[strawberry.ID] | None = None, - limit: int | None = None, - offset: int | None = None, + filtering: list[FilterInput] | None = None, + sorting: list[SortInput] | None = None, + pagination: PaginationInput | None = None, + search: FullTextSearchInput | None = None, ) -> list[TaskType]: auth_service = AuthorizationService(info.context.db) @@ -65,10 +84,7 @@ async def tasks( if assignee_team_id: query = query.where(models.Task.assignee_team_id == assignee_team_id) - query = apply_pagination(query, limit=limit, offset=offset) - - result = await info.context.db.execute(query) - return result.scalars().all() + return query accessible_location_ids = await auth_service.get_user_accessible_location_ids( info.context.user, info.context @@ -164,16 +180,185 @@ async def tasks( models.Task.assignee_team_id.in_(select(team_location_cte.c.id)) ) - query = apply_pagination(query, limit=limit, offset=offset) + return query + + @strawberry.field + async def tasksTotal( + self, + info: Info, + patient_id: strawberry.ID | None = None, + assignee_id: strawberry.ID | None = None, + assignee_team_id: strawberry.ID | None = None, + root_location_ids: list[strawberry.ID] | None = None, + filtering: list[FilterInput] | None = None, + sorting: list[SortInput] | None = None, + search: FullTextSearchInput | None = None, + ) -> int: + auth_service = AuthorizationService(info.context.db) - result = await info.context.db.execute(query) - return result.scalars().all() + if patient_id: + if not await auth_service.can_access_patient_id(info.context.user, patient_id, info.context): + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + query = select(models.Task).where(models.Task.patient_id == patient_id) + + if assignee_id: + query = query.where(models.Task.assignee_id == assignee_id) + if assignee_team_id: + query = query.where(models.Task.assignee_team_id == assignee_team_id) + else: + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + return 0 + + patient_locations = aliased(models.patient_locations) + patient_teams = aliased(models.patient_teams) + + cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(accessible_location_ids)) + .cte(name="accessible_locations", recursive=True) + ) + + children = select(models.LocationNode.id).join( + cte, models.LocationNode.parent_id == cte.c.id + ) + cte = cte.union_all(children) + + if root_location_ids: + invalid_ids = [lid for lid in root_location_ids if lid not in accessible_location_ids] + if invalid_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + root_cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(root_location_ids)) + .cte(name="root_location_descendants", recursive=True) + ) + root_children = select(models.LocationNode.id).join( + root_cte, models.LocationNode.parent_id == root_cte.c.id + ) + root_cte = root_cte.union_all(root_children) + else: + root_cte = cte + + team_location_cte = None + if assignee_team_id: + if assignee_team_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + team_location_cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id == assignee_team_id) + .cte(name="team_location_descendants", recursive=True) + ) + team_children = select(models.LocationNode.id).join( + team_location_cte, models.LocationNode.parent_id == team_location_cte.c.id + ) + team_location_cte = team_location_cte.union_all(team_children) + + query = ( + select(models.Task) + .join(models.Patient, models.Task.patient_id == models.Patient.id) + .outerjoin( + patient_locations, + models.Patient.id == patient_locations.c.patient_id, + ) + .outerjoin( + patient_teams, + models.Patient.id == patient_teams.c.patient_id, + ) + .where( + (models.Patient.clinic_id.in_(select(root_cte.c.id))) + | ( + models.Patient.position_id.isnot(None) + & models.Patient.position_id.in_(select(root_cte.c.id)) + ) + | ( + models.Patient.assigned_location_id.isnot(None) + & models.Patient.assigned_location_id.in_(select(root_cte.c.id)) + ) + | (patient_locations.c.location_id.in_(select(root_cte.c.id))) + | (patient_teams.c.location_id.in_(select(root_cte.c.id))) + ) + .distinct() + ) + + if assignee_id: + query = query.where(models.Task.assignee_id == assignee_id) + if assignee_team_id: + query = query.where( + models.Task.assignee_team_id.in_(select(team_location_cte.c.id)) + ) + + if search and search is not strawberry.UNSET: + query = apply_full_text_search(query, search, models.Task) + + if filtering: + property_field_types: dict[str, str] = {} + property_def_ids = set() + for f in filtering: + if f.column_type == ColumnType.PROPERTY and f.property_definition_id: + property_def_ids.add(f.property_definition_id) + + if property_def_ids: + prop_defs_result = await info.context.db.execute( + select(models.PropertyDefinition).where( + models.PropertyDefinition.id.in_(property_def_ids) + ) + ) + prop_defs = prop_defs_result.scalars().all() + property_field_types = { + str(prop_def.id): prop_def.field_type for prop_def in prop_defs + } + + query = apply_filtering(query, filtering, models.Task, property_field_types) + + if sorting: + property_field_types: dict[str, str] = {} + property_def_ids = set() + for s in sorting: + if s.column_type == ColumnType.PROPERTY and s.property_definition_id: + property_def_ids.add(s.property_definition_id) + + if property_def_ids: + prop_defs_result = await info.context.db.execute( + select(models.PropertyDefinition).where( + models.PropertyDefinition.id.in_(property_def_ids) + ) + ) + prop_defs = prop_defs_result.scalars().all() + property_field_types = { + str(prop_def.id): prop_def.field_type for prop_def in prop_defs + } + + query = apply_sorting(query, sorting, models.Task, property_field_types) + + subquery = query.subquery() + count_query = select(func.count(func.distinct(subquery.c.id))) + result = await info.context.db.execute(count_query) + return result.scalar() or 0 @strawberry.field + @filtered_and_sorted_query() + @full_text_search_query() async def recent_tasks( self, info: Info, - limit: int = 10, + filtering: list[FilterInput] | None = None, + sorting: list[SortInput] | None = None, + pagination: PaginationInput | None = None, + search: FullTextSearchInput | None = None, ) -> list[TaskType]: auth_service = AuthorizationService(info.context.db) accessible_location_ids = await auth_service.get_user_accessible_location_ids( @@ -224,13 +409,119 @@ async def recent_tasks( | (patient_locations.c.location_id.in_(select(cte.c.id))) | (patient_teams.c.location_id.in_(select(cte.c.id))) ) - .order_by(desc(models.Task.update_date)) - .limit(limit) .distinct() ) - result = await info.context.db.execute(query) - return result.scalars().all() + default_sorting = sorting is None or len(sorting) == 0 + if default_sorting: + query = query.order_by(desc(models.Task.update_date)) + + return query + + @strawberry.field + async def recentTasksTotal( + self, + info: Info, + filtering: list[FilterInput] | None = None, + sorting: list[SortInput] | None = None, + search: FullTextSearchInput | None = None, + ) -> int: + auth_service = AuthorizationService(info.context.db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + return 0 + + patient_locations = aliased(models.patient_locations) + patient_teams = aliased(models.patient_teams) + + cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(accessible_location_ids)) + .cte(name="accessible_locations", recursive=True) + ) + + children = select(models.LocationNode.id).join( + cte, models.LocationNode.parent_id == cte.c.id + ) + cte = cte.union_all(children) + + query = ( + select(models.Task) + .join(models.Patient, models.Task.patient_id == models.Patient.id) + .outerjoin( + patient_locations, + models.Patient.id == patient_locations.c.patient_id, + ) + .outerjoin( + patient_teams, + models.Patient.id == patient_teams.c.patient_id, + ) + .where( + (models.Patient.clinic_id.in_(select(cte.c.id))) + | ( + models.Patient.position_id.isnot(None) + & models.Patient.position_id.in_(select(cte.c.id)) + ) + | ( + models.Patient.assigned_location_id.isnot(None) + & models.Patient.assigned_location_id.in_(select(cte.c.id)) + ) + | (patient_locations.c.location_id.in_(select(cte.c.id))) + | (patient_teams.c.location_id.in_(select(cte.c.id))) + ) + .distinct() + ) + + if search and search is not strawberry.UNSET: + query = apply_full_text_search(query, search, models.Task) + + if filtering: + property_field_types: dict[str, str] = {} + property_def_ids = set() + for f in filtering: + if f.column_type == ColumnType.PROPERTY and f.property_definition_id: + property_def_ids.add(f.property_definition_id) + + if property_def_ids: + prop_defs_result = await info.context.db.execute( + select(models.PropertyDefinition).where( + models.PropertyDefinition.id.in_(property_def_ids) + ) + ) + prop_defs = prop_defs_result.scalars().all() + property_field_types = { + str(prop_def.id): prop_def.field_type for prop_def in prop_defs + } + + query = apply_filtering(query, filtering, models.Task, property_field_types) + + if sorting: + property_field_types: dict[str, str] = {} + property_def_ids = set() + for s in sorting: + if s.column_type == ColumnType.PROPERTY and s.property_definition_id: + property_def_ids.add(s.property_definition_id) + + if property_def_ids: + prop_defs_result = await info.context.db.execute( + select(models.PropertyDefinition).where( + models.PropertyDefinition.id.in_(property_def_ids) + ) + ) + prop_defs = prop_defs_result.scalars().all() + property_field_types = { + str(prop_def.id): prop_def.field_type for prop_def in prop_defs + } + + query = apply_sorting(query, sorting, models.Task, property_field_types) + + subquery = query.subquery() + count_query = select(func.count(func.distinct(subquery.c.id))) + result = await info.context.db.execute(count_query) + return result.scalar() or 0 @strawberry.type @@ -339,7 +630,11 @@ async def update_task( if data.estimated_time is not strawberry.UNSET: task.estimated_time = data.estimated_time - if data.assignee_id is not None and data.assignee_team_id is not strawberry.UNSET and data.assignee_team_id is not None: + if ( + data.assignee_id is not None + and data.assignee_team_id is not strawberry.UNSET + and data.assignee_team_id is not None + ): raise GraphQLError( "Cannot assign both a user and a team. Please assign either a user or a team.", extensions={"code": "BAD_REQUEST"}, diff --git a/backend/api/resolvers/user.py b/backend/api/resolvers/user.py index e9879676..e4d97525 100644 --- a/backend/api/resolvers/user.py +++ b/backend/api/resolvers/user.py @@ -1,6 +1,14 @@ import strawberry from api.context import Info -from api.inputs import UpdateProfilePictureInput +from api.decorators.filter_sort import filtered_and_sorted_query +from api.decorators.full_text_search import full_text_search_query +from api.inputs import ( + FilterInput, + FullTextSearchInput, + PaginationInput, + SortInput, + UpdateProfilePictureInput, +) from api.resolvers.base import BaseMutationResolver from api.types.user import UserType from database import models @@ -18,9 +26,18 @@ async def user(self, info: Info, id: strawberry.ID) -> UserType | None: return result.scalars().first() @strawberry.field - async def users(self, info: Info) -> list[UserType]: - result = await info.context.db.execute(select(models.User)) - return result.scalars().all() + @filtered_and_sorted_query() + @full_text_search_query() + async def users( + self, + info: Info, + filtering: list[FilterInput] | None = None, + sorting: list[SortInput] | None = None, + pagination: PaginationInput | None = None, + search: FullTextSearchInput | None = None, + ) -> list[UserType]: + query = select(models.User) + return query @strawberry.field def me(self, info: Info) -> UserType | None: diff --git a/backend/api/types/pagination.py b/backend/api/types/pagination.py new file mode 100644 index 00000000..5acc30c4 --- /dev/null +++ b/backend/api/types/pagination.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING, Annotated + +import strawberry + +if TYPE_CHECKING: + from api.types.patient import PatientType + from api.types.task import TaskType + + +@strawberry.type +class PaginatedPatientResult: + items: list[Annotated["PatientType", strawberry.lazy("api.types.patient")]] + total_count: int + + +@strawberry.type +class PaginatedTaskResult: + items: list[Annotated["TaskType", strawberry.lazy("api.types.task")]] + total_count: int diff --git a/backend/main.py b/backend/main.py index d8890d92..6b755093 100644 --- a/backend/main.py +++ b/backend/main.py @@ -7,10 +7,12 @@ from api.router import AuthedGraphQLRouter from auth import UnauthenticatedRedirect, unauthenticated_redirect_handler from config import ALLOWED_ORIGINS, IS_DEV, LOGGER -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse from routers import auth from scaffold import load_scaffold_data +from starlette.requests import ClientDisconnect from strawberry import Schema logger = logging.getLogger(LOGGER) @@ -50,6 +52,14 @@ async def lifespan(app: FastAPI): unauthenticated_redirect_handler, ) + +@app.exception_handler(ClientDisconnect) +async def client_disconnect_handler(request: Request, exc: ClientDisconnect): + logger.debug(f"Client disconnected: {request.url}") + return JSONResponse( + status_code=499, content={"detail": "Client disconnected"} + ) + app.include_router(auth.router) app.include_router(graphql_app, prefix="/graphql") diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 00000000..b08734ba --- /dev/null +++ b/package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "tasks", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/web/api/gql/generated.ts b/web/api/gql/generated.ts index ce8c5933..5ae37a3e 100644 --- a/web/api/gql/generated.ts +++ b/web/api/gql/generated.ts @@ -27,6 +27,11 @@ export type AuditLogType = { userId?: Maybe; }; +export enum ColumnType { + DirectAttribute = 'DIRECT_ATTRIBUTE', + Property = 'PROPERTY' +} + export type CreateLocationNodeInput = { kind: LocationType; parentId?: InputMaybe; @@ -81,6 +86,83 @@ export enum FieldType { FieldTypeUnspecified = 'FIELD_TYPE_UNSPECIFIED' } +export type FilterInput = { + column: Scalars['String']['input']; + columnType?: ColumnType; + operator: FilterOperator; + parameter: FilterParameter; + propertyDefinitionId?: InputMaybe; +}; + +export enum FilterOperator { + BooleanIsFalse = 'BOOLEAN_IS_FALSE', + BooleanIsTrue = 'BOOLEAN_IS_TRUE', + DatetimeBetween = 'DATETIME_BETWEEN', + DatetimeEquals = 'DATETIME_EQUALS', + DatetimeGreaterThan = 'DATETIME_GREATER_THAN', + DatetimeGreaterThanOrEqual = 'DATETIME_GREATER_THAN_OR_EQUAL', + DatetimeLessThan = 'DATETIME_LESS_THAN', + DatetimeLessThanOrEqual = 'DATETIME_LESS_THAN_OR_EQUAL', + DatetimeNotBetween = 'DATETIME_NOT_BETWEEN', + DatetimeNotEquals = 'DATETIME_NOT_EQUALS', + DateBetween = 'DATE_BETWEEN', + DateEquals = 'DATE_EQUALS', + DateGreaterThan = 'DATE_GREATER_THAN', + DateGreaterThanOrEqual = 'DATE_GREATER_THAN_OR_EQUAL', + DateLessThan = 'DATE_LESS_THAN', + DateLessThanOrEqual = 'DATE_LESS_THAN_OR_EQUAL', + DateNotBetween = 'DATE_NOT_BETWEEN', + DateNotEquals = 'DATE_NOT_EQUALS', + IsNotNull = 'IS_NOT_NULL', + IsNull = 'IS_NULL', + NumberBetween = 'NUMBER_BETWEEN', + NumberEquals = 'NUMBER_EQUALS', + NumberGreaterThan = 'NUMBER_GREATER_THAN', + NumberGreaterThanOrEqual = 'NUMBER_GREATER_THAN_OR_EQUAL', + NumberLessThan = 'NUMBER_LESS_THAN', + NumberLessThanOrEqual = 'NUMBER_LESS_THAN_OR_EQUAL', + NumberNotBetween = 'NUMBER_NOT_BETWEEN', + NumberNotEquals = 'NUMBER_NOT_EQUALS', + TagsContains = 'TAGS_CONTAINS', + TagsEquals = 'TAGS_EQUALS', + TagsNotContains = 'TAGS_NOT_CONTAINS', + TagsNotEquals = 'TAGS_NOT_EQUALS', + TagsSingleContains = 'TAGS_SINGLE_CONTAINS', + TagsSingleEquals = 'TAGS_SINGLE_EQUALS', + TagsSingleNotContains = 'TAGS_SINGLE_NOT_CONTAINS', + TagsSingleNotEquals = 'TAGS_SINGLE_NOT_EQUALS', + TextContains = 'TEXT_CONTAINS', + TextEndsWith = 'TEXT_ENDS_WITH', + TextEquals = 'TEXT_EQUALS', + TextNotContains = 'TEXT_NOT_CONTAINS', + TextNotEquals = 'TEXT_NOT_EQUALS', + TextNotWhitespace = 'TEXT_NOT_WHITESPACE', + TextStartsWith = 'TEXT_STARTS_WITH' +} + +export type FilterParameter = { + compareDate?: InputMaybe; + compareDateTime?: InputMaybe; + compareValue?: InputMaybe; + isCaseSensitive?: Scalars['Boolean']['input']; + max?: InputMaybe; + maxDate?: InputMaybe; + maxDateTime?: InputMaybe; + min?: InputMaybe; + minDate?: InputMaybe; + minDateTime?: InputMaybe; + propertyDefinitionId?: InputMaybe; + searchTags?: InputMaybe>; + searchText?: InputMaybe; +}; + +export type FullTextSearchInput = { + includeProperties?: Scalars['Boolean']['input']; + propertyDefinitionIds?: InputMaybe>; + searchColumns?: InputMaybe>; + searchText: Scalars['String']['input']; +}; + export type LocationNodeType = { __typename?: 'LocationNodeType'; children: Array; @@ -252,6 +334,11 @@ export type MutationWaitPatientArgs = { id: Scalars['ID']['input']; }; +export type PaginationInput = { + pageIndex?: Scalars['Int']['input']; + pageSize?: InputMaybe; +}; + export enum PatientState { Admitted = 'ADMITTED', Dead = 'DEAD', @@ -336,11 +423,15 @@ export type Query = { me?: Maybe; patient?: Maybe; patients: Array; + patientsTotal: Scalars['Int']['output']; propertyDefinitions: Array; recentPatients: Array; + recentPatientsTotal: Scalars['Int']['output']; recentTasks: Array; + recentTasksTotal: Scalars['Int']['output']; task?: Maybe; tasks: Array; + tasksTotal: Scalars['Int']['output']; user?: Maybe; users: Array; }; @@ -375,21 +466,53 @@ export type QueryPatientArgs = { export type QueryPatientsArgs = { - limit?: InputMaybe; + filtering?: InputMaybe>; + locationNodeId?: InputMaybe; + pagination?: InputMaybe; + rootLocationIds?: InputMaybe>; + search?: InputMaybe; + sorting?: InputMaybe>; + states?: InputMaybe>; +}; + + +export type QueryPatientsTotalArgs = { + filtering?: InputMaybe>; locationNodeId?: InputMaybe; - offset?: InputMaybe; rootLocationIds?: InputMaybe>; + search?: InputMaybe; + sorting?: InputMaybe>; states?: InputMaybe>; }; export type QueryRecentPatientsArgs = { - limit?: Scalars['Int']['input']; + filtering?: InputMaybe>; + pagination?: InputMaybe; + search?: InputMaybe; + sorting?: InputMaybe>; +}; + + +export type QueryRecentPatientsTotalArgs = { + filtering?: InputMaybe>; + search?: InputMaybe; + sorting?: InputMaybe>; }; export type QueryRecentTasksArgs = { - limit?: Scalars['Int']['input']; + filtering?: InputMaybe>; + pagination?: InputMaybe; + search?: InputMaybe; + sorting?: InputMaybe>; +}; + + +export type QueryRecentTasksTotalArgs = { + filtering?: InputMaybe>; + search?: InputMaybe; + sorting?: InputMaybe>; }; @@ -401,10 +524,23 @@ export type QueryTaskArgs = { export type QueryTasksArgs = { assigneeId?: InputMaybe; assigneeTeamId?: InputMaybe; - limit?: InputMaybe; - offset?: InputMaybe; + filtering?: InputMaybe>; + pagination?: InputMaybe; patientId?: InputMaybe; rootLocationIds?: InputMaybe>; + search?: InputMaybe; + sorting?: InputMaybe>; +}; + + +export type QueryTasksTotalArgs = { + assigneeId?: InputMaybe; + assigneeTeamId?: InputMaybe; + filtering?: InputMaybe>; + patientId?: InputMaybe; + rootLocationIds?: InputMaybe>; + search?: InputMaybe; + sorting?: InputMaybe>; }; @@ -412,12 +548,32 @@ export type QueryUserArgs = { id: Scalars['ID']['input']; }; + +export type QueryUsersArgs = { + filtering?: InputMaybe>; + pagination?: InputMaybe; + search?: InputMaybe; + sorting?: InputMaybe>; +}; + export enum Sex { Female = 'FEMALE', Male = 'MALE', Unknown = 'UNKNOWN' } +export enum SortDirection { + Asc = 'ASC', + Desc = 'DESC' +} + +export type SortInput = { + column: Scalars['String']['input']; + columnType?: ColumnType; + direction: SortDirection; + propertyDefinitionId?: InputMaybe; +}; + export type Subscription = { __typename?: 'Subscription'; locationNodeCreated: Scalars['ID']['output']; @@ -601,10 +757,19 @@ export type GetMyTasksQueryVariables = Exact<{ [key: string]: never; }>; export type GetMyTasksQuery = { __typename?: 'Query', me?: { __typename?: 'UserType', id: string, tasks: Array<{ __typename?: 'TaskType', id: string, title: string, description?: string | null, done: boolean, dueDate?: any | null, priority?: string | null, estimatedTime?: number | null, creationDate: any, updateDate?: any | null, patient: { __typename?: 'PatientType', id: string, name: string, assignedLocation?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null, assignedLocations: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null }> }, assignee?: { __typename?: 'UserType', id: string, name: string, avatarUrl?: string | null, lastOnline?: any | null, isOnline: boolean } | null }> } | null }; -export type GetOverviewDataQueryVariables = Exact<{ [key: string]: never; }>; +export type GetOverviewDataQueryVariables = Exact<{ + recentPatientsFiltering?: InputMaybe | FilterInput>; + recentPatientsSorting?: InputMaybe | SortInput>; + recentPatientsPagination?: InputMaybe; + recentPatientsSearch?: InputMaybe; + recentTasksFiltering?: InputMaybe | FilterInput>; + recentTasksSorting?: InputMaybe | SortInput>; + recentTasksPagination?: InputMaybe; + recentTasksSearch?: InputMaybe; +}>; -export type GetOverviewDataQuery = { __typename?: 'Query', recentPatients: Array<{ __typename?: 'PatientType', id: string, name: string, sex: Sex, birthdate: any, position?: { __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null, tasks: Array<{ __typename?: 'TaskType', updateDate?: any | null }> }>, recentTasks: Array<{ __typename?: 'TaskType', id: string, title: string, description?: string | null, done: boolean, dueDate?: any | null, updateDate?: any | null, assignee?: { __typename?: 'UserType', id: string, name: string, avatarUrl?: string | null, lastOnline?: any | null, isOnline: boolean } | null, patient: { __typename?: 'PatientType', id: string, name: string, position?: { __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null } }> }; +export type GetOverviewDataQuery = { __typename?: 'Query', recentPatientsTotal: number, recentTasksTotal: number, recentPatients: Array<{ __typename?: 'PatientType', id: string, name: string, sex: Sex, birthdate: any, position?: { __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null, tasks: Array<{ __typename?: 'TaskType', updateDate?: any | null }>, properties: Array<{ __typename?: 'PropertyValueType', textValue?: string | null, numberValue?: number | null, booleanValue?: boolean | null, dateValue?: any | null, dateTimeValue?: any | null, selectValue?: string | null, multiSelectValues?: Array | null, definition: { __typename?: 'PropertyDefinitionType', id: string, name: string, description?: string | null, fieldType: FieldType, isActive: boolean, allowedEntities: Array, options: Array } }> }>, recentTasks: Array<{ __typename?: 'TaskType', id: string, title: string, description?: string | null, done: boolean, dueDate?: any | null, updateDate?: any | null, priority?: string | null, assignee?: { __typename?: 'UserType', id: string, name: string, avatarUrl?: string | null, lastOnline?: any | null, isOnline: boolean } | null, patient: { __typename?: 'PatientType', id: string, name: string, position?: { __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null }, properties: Array<{ __typename?: 'PropertyValueType', textValue?: string | null, numberValue?: number | null, booleanValue?: boolean | null, dateValue?: any | null, dateTimeValue?: any | null, selectValue?: string | null, multiSelectValues?: Array | null, definition: { __typename?: 'PropertyDefinitionType', id: string, name: string, description?: string | null, fieldType: FieldType, isActive: boolean, allowedEntities: Array, options: Array } }> }> }; export type GetPatientQueryVariables = Exact<{ id: Scalars['ID']['input']; @@ -617,12 +782,14 @@ export type GetPatientsQueryVariables = Exact<{ locationId?: InputMaybe; rootLocationIds?: InputMaybe | Scalars['ID']['input']>; states?: InputMaybe | PatientState>; - limit?: InputMaybe; - offset?: InputMaybe; + filtering?: InputMaybe | FilterInput>; + sorting?: InputMaybe | SortInput>; + pagination?: InputMaybe; + search?: InputMaybe; }>; -export type GetPatientsQuery = { __typename?: 'Query', patients: Array<{ __typename?: 'PatientType', id: string, name: string, firstname: string, lastname: string, birthdate: any, sex: Sex, state: PatientState, assignedLocation?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null, assignedLocations: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null } | null }>, clinic: { __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null } | null } | null }, position?: { __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null } | null } | null } | null, teams: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null } | null } | null }>, tasks: Array<{ __typename?: 'TaskType', id: string, title: string, description?: string | null, done: boolean, dueDate?: any | null, priority?: string | null, estimatedTime?: number | null, creationDate: any, updateDate?: any | null, assignee?: { __typename?: 'UserType', id: string, name: string, avatarUrl?: string | null, lastOnline?: any | null, isOnline: boolean } | null, assigneeTeam?: { __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType } | null }>, properties: Array<{ __typename?: 'PropertyValueType', textValue?: string | null, definition: { __typename?: 'PropertyDefinitionType', name: string } }> }> }; +export type GetPatientsQuery = { __typename?: 'Query', patientsTotal: number, patients: Array<{ __typename?: 'PatientType', id: string, name: string, firstname: string, lastname: string, birthdate: any, sex: Sex, state: PatientState, assignedLocation?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null, assignedLocations: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null } | null }>, clinic: { __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null } | null } | null }, position?: { __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null } | null } | null } | null, teams: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null } | null } | null }>, tasks: Array<{ __typename?: 'TaskType', id: string, title: string, description?: string | null, done: boolean, dueDate?: any | null, priority?: string | null, estimatedTime?: number | null, creationDate: any, updateDate?: any | null, assignee?: { __typename?: 'UserType', id: string, name: string, avatarUrl?: string | null, lastOnline?: any | null, isOnline: boolean } | null, assigneeTeam?: { __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType } | null }>, properties: Array<{ __typename?: 'PropertyValueType', textValue?: string | null, numberValue?: number | null, booleanValue?: boolean | null, dateValue?: any | null, dateTimeValue?: any | null, selectValue?: string | null, multiSelectValues?: Array | null, definition: { __typename?: 'PropertyDefinitionType', id: string, name: string, description?: string | null, fieldType: FieldType, isActive: boolean, allowedEntities: Array, options: Array } }> }> }; export type GetTaskQueryVariables = Exact<{ id: Scalars['ID']['input']; @@ -635,12 +802,14 @@ export type GetTasksQueryVariables = Exact<{ rootLocationIds?: InputMaybe | Scalars['ID']['input']>; assigneeId?: InputMaybe; assigneeTeamId?: InputMaybe; - limit?: InputMaybe; - offset?: InputMaybe; + filtering?: InputMaybe | FilterInput>; + sorting?: InputMaybe | SortInput>; + pagination?: InputMaybe; + search?: InputMaybe; }>; -export type GetTasksQuery = { __typename?: 'Query', tasks: Array<{ __typename?: 'TaskType', id: string, title: string, description?: string | null, done: boolean, dueDate?: any | null, priority?: string | null, estimatedTime?: number | null, creationDate: any, updateDate?: any | null, patient: { __typename?: 'PatientType', id: string, name: string, assignedLocation?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null, assignedLocations: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null }> }, assignee?: { __typename?: 'UserType', id: string, name: string, avatarUrl?: string | null, lastOnline?: any | null, isOnline: boolean } | null, assigneeTeam?: { __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType } | null }> }; +export type GetTasksQuery = { __typename?: 'Query', tasksTotal: number, tasks: Array<{ __typename?: 'TaskType', id: string, title: string, description?: string | null, done: boolean, dueDate?: any | null, priority?: string | null, estimatedTime?: number | null, creationDate: any, updateDate?: any | null, patient: { __typename?: 'PatientType', id: string, name: string, assignedLocation?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null, assignedLocations: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null }> }, assignee?: { __typename?: 'UserType', id: string, name: string, avatarUrl?: string | null, lastOnline?: any | null, isOnline: boolean } | null, assigneeTeam?: { __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType } | null, properties: Array<{ __typename?: 'PropertyValueType', textValue?: string | null, numberValue?: number | null, booleanValue?: boolean | null, dateValue?: any | null, dateTimeValue?: any | null, selectValue?: string | null, multiSelectValues?: Array | null, definition: { __typename?: 'PropertyDefinitionType', id: string, name: string, description?: string | null, fieldType: FieldType, isActive: boolean, allowedEntities: Array, options: Array } }> }> }; export type GetUserQueryVariables = Exact<{ id: Scalars['ID']['input']; @@ -1057,8 +1226,13 @@ export const useGetMyTasksQuery = < )}; export const GetOverviewDataDocument = ` - query GetOverviewData { - recentPatients(limit: 5) { + query GetOverviewData($recentPatientsFiltering: [FilterInput!], $recentPatientsSorting: [SortInput!], $recentPatientsPagination: PaginationInput, $recentPatientsSearch: FullTextSearchInput, $recentTasksFiltering: [FilterInput!], $recentTasksSorting: [SortInput!], $recentTasksPagination: PaginationInput, $recentTasksSearch: FullTextSearchInput) { + recentPatients( + filtering: $recentPatientsFiltering + sorting: $recentPatientsSorting + pagination: $recentPatientsPagination + search: $recentPatientsSearch + ) { id name sex @@ -1075,14 +1249,43 @@ export const GetOverviewDataDocument = ` tasks { updateDate } + properties { + definition { + id + name + description + fieldType + isActive + allowedEntities + options + } + textValue + numberValue + booleanValue + dateValue + dateTimeValue + selectValue + multiSelectValues + } } - recentTasks(limit: 10) { + recentPatientsTotal( + filtering: $recentPatientsFiltering + sorting: $recentPatientsSorting + search: $recentPatientsSearch + ) + recentTasks( + filtering: $recentTasksFiltering + sorting: $recentTasksSorting + pagination: $recentTasksPagination + search: $recentTasksSearch + ) { id title description done dueDate updateDate + priority assignee { id name @@ -1103,7 +1306,30 @@ export const GetOverviewDataDocument = ` } } } + properties { + definition { + id + name + description + fieldType + isActive + allowedEntities + options + } + textValue + numberValue + booleanValue + dateValue + dateTimeValue + selectValue + multiSelectValues + } } + recentTasksTotal( + filtering: $recentTasksFiltering + sorting: $recentTasksSorting + search: $recentTasksSearch + ) } `; @@ -1266,13 +1492,15 @@ export const useGetPatientQuery = < )}; export const GetPatientsDocument = ` - query GetPatients($locationId: ID, $rootLocationIds: [ID!], $states: [PatientState!], $limit: Int, $offset: Int) { + query GetPatients($locationId: ID, $rootLocationIds: [ID!], $states: [PatientState!], $filtering: [FilterInput!], $sorting: [SortInput!], $pagination: PaginationInput, $search: FullTextSearchInput) { patients( locationNodeId: $locationId rootLocationIds: $rootLocationIds states: $states - limit: $limit - offset: $offset + filtering: $filtering + sorting: $sorting + pagination: $pagination + search: $search ) { id name @@ -1394,11 +1622,31 @@ export const GetPatientsDocument = ` } properties { definition { + id name + description + fieldType + isActive + allowedEntities + options } textValue + numberValue + booleanValue + dateValue + dateTimeValue + selectValue + multiSelectValues } } + patientsTotal( + locationNodeId: $locationId + rootLocationIds: $rootLocationIds + states: $states + filtering: $filtering + sorting: $sorting + search: $search + ) } `; @@ -1484,13 +1732,15 @@ export const useGetTaskQuery = < )}; export const GetTasksDocument = ` - query GetTasks($rootLocationIds: [ID!], $assigneeId: ID, $assigneeTeamId: ID, $limit: Int, $offset: Int) { + query GetTasks($rootLocationIds: [ID!], $assigneeId: ID, $assigneeTeamId: ID, $filtering: [FilterInput!], $sorting: [SortInput!], $pagination: PaginationInput, $search: FullTextSearchInput) { tasks( rootLocationIds: $rootLocationIds assigneeId: $assigneeId assigneeTeamId: $assigneeTeamId - limit: $limit - offset: $offset + filtering: $filtering + sorting: $sorting + pagination: $pagination + search: $search ) { id title @@ -1538,7 +1788,33 @@ export const GetTasksDocument = ` title kind } + properties { + definition { + id + name + description + fieldType + isActive + allowedEntities + options + } + textValue + numberValue + booleanValue + dateValue + dateTimeValue + selectValue + multiSelectValues + } } + tasksTotal( + rootLocationIds: $rootLocationIds + assigneeId: $assigneeId + assigneeTeamId: $assigneeTeamId + filtering: $filtering + sorting: $sorting + search: $search + ) } `; diff --git a/web/api/graphql/GetOverviewData.graphql b/web/api/graphql/GetOverviewData.graphql index 77b8c5e0..1680b377 100644 --- a/web/api/graphql/GetOverviewData.graphql +++ b/web/api/graphql/GetOverviewData.graphql @@ -1,5 +1,5 @@ -query GetOverviewData { - recentPatients(limit: 5) { +query GetOverviewData($recentPatientsFiltering: [FilterInput!], $recentPatientsSorting: [SortInput!], $recentPatientsPagination: PaginationInput, $recentPatientsSearch: FullTextSearchInput, $recentTasksFiltering: [FilterInput!], $recentTasksSorting: [SortInput!], $recentTasksPagination: PaginationInput, $recentTasksSearch: FullTextSearchInput) { + recentPatients(filtering: $recentPatientsFiltering, sorting: $recentPatientsSorting, pagination: $recentPatientsPagination, search: $recentPatientsSearch) { id name sex @@ -16,14 +16,34 @@ query GetOverviewData { tasks { updateDate } + properties { + definition { + id + name + description + fieldType + isActive + allowedEntities + options + } + textValue + numberValue + booleanValue + dateValue + dateTimeValue + selectValue + multiSelectValues + } } - recentTasks(limit: 10) { + recentPatientsTotal(filtering: $recentPatientsFiltering, sorting: $recentPatientsSorting, search: $recentPatientsSearch) + recentTasks(filtering: $recentTasksFiltering, sorting: $recentTasksSorting, pagination: $recentTasksPagination, search: $recentTasksSearch) { id title description done dueDate updateDate + priority assignee { id name @@ -44,5 +64,24 @@ query GetOverviewData { } } } + properties { + definition { + id + name + description + fieldType + isActive + allowedEntities + options + } + textValue + numberValue + booleanValue + dateValue + dateTimeValue + selectValue + multiSelectValues + } } + recentTasksTotal(filtering: $recentTasksFiltering, sorting: $recentTasksSorting, search: $recentTasksSearch) } diff --git a/web/api/graphql/GetPatients.graphql b/web/api/graphql/GetPatients.graphql index 68435959..7bae64ec 100644 --- a/web/api/graphql/GetPatients.graphql +++ b/web/api/graphql/GetPatients.graphql @@ -1,5 +1,5 @@ -query GetPatients($locationId: ID, $rootLocationIds: [ID!], $states: [PatientState!], $limit: Int, $offset: Int) { - patients(locationNodeId: $locationId, rootLocationIds: $rootLocationIds, states: $states, limit: $limit, offset: $offset) { +query GetPatients($locationId: ID, $rootLocationIds: [ID!], $states: [PatientState!], $filtering: [FilterInput!], $sorting: [SortInput!], $pagination: PaginationInput, $search: FullTextSearchInput) { + patients(locationNodeId: $locationId, rootLocationIds: $rootLocationIds, states: $states, filtering: $filtering, sorting: $sorting, pagination: $pagination, search: $search) { id name firstname @@ -120,9 +120,22 @@ query GetPatients($locationId: ID, $rootLocationIds: [ID!], $states: [PatientSta } properties { definition { + id name + description + fieldType + isActive + allowedEntities + options } textValue + numberValue + booleanValue + dateValue + dateTimeValue + selectValue + multiSelectValues } } + patientsTotal(locationNodeId: $locationId, rootLocationIds: $rootLocationIds, states: $states, filtering: $filtering, sorting: $sorting, search: $search) } diff --git a/web/api/graphql/GetTasks.graphql b/web/api/graphql/GetTasks.graphql index 50c6a0ab..b431d06e 100644 --- a/web/api/graphql/GetTasks.graphql +++ b/web/api/graphql/GetTasks.graphql @@ -1,5 +1,5 @@ -query GetTasks($rootLocationIds: [ID!], $assigneeId: ID, $assigneeTeamId: ID, $limit: Int, $offset: Int) { - tasks(rootLocationIds: $rootLocationIds, assigneeId: $assigneeId, assigneeTeamId: $assigneeTeamId, limit: $limit, offset: $offset) { +query GetTasks($rootLocationIds: [ID!], $assigneeId: ID, $assigneeTeamId: ID, $filtering: [FilterInput!], $sorting: [SortInput!], $pagination: PaginationInput, $search: FullTextSearchInput) { + tasks(rootLocationIds: $rootLocationIds, assigneeId: $assigneeId, assigneeTeamId: $assigneeTeamId, filtering: $filtering, sorting: $sorting, pagination: $pagination, search: $search) { id title description @@ -46,6 +46,25 @@ query GetTasks($rootLocationIds: [ID!], $assigneeId: ID, $assigneeTeamId: ID, $l title kind } + properties { + definition { + id + name + description + fieldType + isActive + allowedEntities + options + } + textValue + numberValue + booleanValue + dateValue + dateTimeValue + selectValue + multiSelectValues + } } + tasksTotal(rootLocationIds: $rootLocationIds, assigneeId: $assigneeId, assigneeTeamId: $assigneeTeamId, filtering: $filtering, sorting: $sorting, search: $search) } diff --git a/web/api/optimistic-updates/GetPatient.ts b/web/api/optimistic-updates/GetPatient.ts new file mode 100644 index 00000000..b4549d11 --- /dev/null +++ b/web/api/optimistic-updates/GetPatient.ts @@ -0,0 +1,784 @@ +import { useSafeMutation } from '@/hooks/useSafeMutation' +import { fetcher } from '@/api/gql/fetcher' +import { CompleteTaskDocument, ReopenTaskDocument, CreatePatientDocument, AdmitPatientDocument, DischargePatientDocument, DeletePatientDocument, WaitPatientDocument, MarkPatientDeadDocument, UpdatePatientDocument, type CompleteTaskMutation, type CompleteTaskMutationVariables, type ReopenTaskMutation, type ReopenTaskMutationVariables, type CreatePatientMutation, type CreatePatientMutationVariables, type AdmitPatientMutation, type DischargePatientMutation, type DeletePatientMutation, type DeletePatientMutationVariables, type WaitPatientMutation, type MarkPatientDeadMutation, type UpdatePatientMutation, type UpdatePatientMutationVariables, type UpdatePatientInput, PatientState, type FieldType } from '@/api/gql/generated' +import type { GetPatientQuery, GetPatientsQuery, GetGlobalDataQuery } from '@/api/gql/generated' +import { useTasksContext } from '@/hooks/useTasksContext' +import { useQueryClient } from '@tanstack/react-query' + +interface UseOptimisticCompleteTaskMutationParams { + id: string, + onSuccess?: (data: CompleteTaskMutation, variables: CompleteTaskMutationVariables) => void, + onError?: (error: Error, variables: CompleteTaskMutationVariables) => void, +} + +export function useOptimisticCompleteTaskMutation({ + id, + onSuccess, + onError, +}: UseOptimisticCompleteTaskMutationParams) { + const { selectedRootLocationIds } = useTasksContext() + const selectedRootLocationIdsForQuery = selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds : undefined + return useSafeMutation({ + mutationFn: async (variables) => { + return fetcher(CompleteTaskDocument, variables)() + }, + optimisticUpdate: (variables) => [ + { + queryKey: ['GetPatient', { id }], + updateFn: (oldData: unknown) => { + const data = oldData as GetPatientQuery | undefined + if (!data?.patient) return oldData + return { + ...data, + patient: { + ...data.patient, + tasks: data.patient.tasks?.map(task => ( + task.id === variables.id ? { ...task, done: true } : task + )) || [] + } + } + } + }, + { + queryKey: ['GetPatients'], + updateFn: (oldData: unknown) => { + const data = oldData as GetPatientsQuery | undefined + if (!data?.patients) return oldData + return { + ...data, + patients: data.patients.map(patient => { + if (patient.id === id && patient.tasks) { + return { + ...patient, + tasks: patient.tasks.map(task => task.id === variables.id ? { ...task, done: true } : task) + } + } + return patient + }) + } + } + }, + { + queryKey: ['GetGlobalData', { rootLocationIds: selectedRootLocationIdsForQuery }], + updateFn: (oldData: unknown) => { + const data = oldData as GetGlobalDataQuery | undefined + if (!data?.me?.tasks) return oldData + return { + ...data, + me: data.me ? { + ...data.me, + tasks: data.me.tasks.map(task => task.id === variables.id ? { ...task, done: true } : task) + } : null + } + } + }, + ], + affectedQueryKeys: [['GetPatient', { id }], ['GetTasks'], ['GetPatients'], ['GetOverviewData'], ['GetGlobalData']], + onSuccess, + onError, + }) +} + +interface UseOptimisticReopenTaskMutationParams { + id: string, + onSuccess?: (data: ReopenTaskMutation, variables: ReopenTaskMutationVariables) => void, + onError?: (error: Error, variables: ReopenTaskMutationVariables) => void, +} + +export function useOptimisticReopenTaskMutation({ + id, + onSuccess, + onError, +}: UseOptimisticReopenTaskMutationParams) { + const { selectedRootLocationIds } = useTasksContext() + const selectedRootLocationIdsForQuery = selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds : undefined + return useSafeMutation({ + mutationFn: async (variables) => { + return fetcher(ReopenTaskDocument, variables)() + }, + optimisticUpdate: (variables) => [ + { + queryKey: ['GetPatient', { id }], + updateFn: (oldData: unknown) => { + const data = oldData as GetPatientQuery | undefined + if (!data?.patient) return oldData + return { + ...data, + patient: { + ...data.patient, + tasks: data.patient.tasks?.map(task => ( + task.id === variables.id ? { ...task, done: false } : task + )) || [] + } + } + } + }, + { + queryKey: ['GetPatients'], + updateFn: (oldData: unknown) => { + const data = oldData as GetPatientsQuery | undefined + if (!data?.patients) return oldData + return { + ...data, + patients: data.patients.map(patient => { + if (patient.id === id && patient.tasks) { + return { + ...patient, + tasks: patient.tasks.map(task => task.id === variables.id ? { ...task, done: false } : task) + } + } + return patient + }) + } + } + }, + { + queryKey: ['GetGlobalData', { rootLocationIds: selectedRootLocationIdsForQuery }], + updateFn: (oldData: unknown) => { + const data = oldData as GetGlobalDataQuery | undefined + if (!data?.me?.tasks) return oldData + return { + ...data, + me: data.me ? { + ...data.me, + tasks: data.me.tasks.map(task => task.id === variables.id ? { ...task, done: false } : task) + } : null + } + } + }, + ], + affectedQueryKeys: [['GetPatient', { id }], ['GetTasks'], ['GetPatients'], ['GetOverviewData'], ['GetGlobalData']], + onSuccess, + onError, + }) +} + +interface UseOptimisticCreatePatientMutationParams { + onSuccess?: (data: CreatePatientMutation, variables: CreatePatientMutationVariables) => void, + onError?: (error: Error, variables: CreatePatientMutationVariables) => void, + onMutate?: () => void, + onSettled?: () => void, +} + +export function useOptimisticCreatePatientMutation({ + onMutate, + onSettled, + onSuccess, + onError, +}: UseOptimisticCreatePatientMutationParams) { + const { selectedRootLocationIds } = useTasksContext() + const selectedRootLocationIdsForQuery = selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds : undefined + return useSafeMutation({ + mutationFn: async (variables) => { + return fetcher(CreatePatientDocument, variables)() + }, + optimisticUpdate: (variables) => [ + { + queryKey: ['GetGlobalData', { rootLocationIds: selectedRootLocationIdsForQuery }], + updateFn: (oldData: unknown) => { + const data = oldData as GetGlobalDataQuery | undefined + if (!data) return oldData + const newPatient = { + __typename: 'PatientType' as const, + id: `temp-${Date.now()}`, + name: `${variables.data.firstname} ${variables.data.lastname}`.trim(), + firstname: variables.data.firstname, + lastname: variables.data.lastname, + birthdate: variables.data.birthdate, + sex: variables.data.sex, + state: variables.data.state || PatientState.Admitted, + assignedLocation: null, + assignedLocations: [], + clinic: null, + position: null, + teams: [], + properties: [], + tasks: [], + } + return { + ...data, + patients: [...(data.patients || []), newPatient], + waitingPatients: variables.data.state === PatientState.Wait + ? [...(data.waitingPatients || []), newPatient] + : data.waitingPatients || [], + } + } + } + ], + affectedQueryKeys: [['GetGlobalData'], ['GetPatients'], ['GetOverviewData']], + onSuccess, + onError, + onMutate, + onSettled, + }) +} + +interface UseOptimisticAdmitPatientMutationParams { + id: string, + onSuccess?: (data: AdmitPatientMutation, variables: { id: string }) => void, + onError?: (error: Error, variables: { id: string }) => void, +} + +export function useOptimisticAdmitPatientMutation({ + id, + onSuccess, + onError, +}: UseOptimisticAdmitPatientMutationParams) { + const { selectedRootLocationIds } = useTasksContext() + const selectedRootLocationIdsForQuery = selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds : undefined + return useSafeMutation({ + mutationFn: async (variables) => { + return fetcher(AdmitPatientDocument, variables)() + }, + optimisticUpdate: () => [ + { + queryKey: ['GetPatient', { id }], + updateFn: (oldData: unknown) => { + const data = oldData as GetPatientQuery | undefined + if (!data?.patient) return oldData + return { + ...data, + patient: { + ...data.patient, + state: PatientState.Admitted + } + } + } + }, + { + queryKey: ['GetPatients'], + updateFn: (oldData: unknown) => { + const data = oldData as GetPatientsQuery | undefined + if (!data?.patients) return oldData + return { + ...data, + patients: data.patients.map(p => + p.id === id ? { ...p, state: PatientState.Admitted } : p) + } + } + }, + { + queryKey: ['GetGlobalData', { rootLocationIds: selectedRootLocationIdsForQuery }], + updateFn: (oldData: unknown) => { + const data = oldData as GetGlobalDataQuery | undefined + if (!data) return oldData + const existingPatient = data.patients.find(p => p.id === id) + const updatedPatient = existingPatient + ? { ...existingPatient, state: PatientState.Admitted } + : { __typename: 'PatientType' as const, id, state: PatientState.Admitted, assignedLocation: null } + return { + ...data, + patients: existingPatient + ? data.patients.map(p => p.id === id ? updatedPatient : p) + : [...data.patients, updatedPatient], + waitingPatients: data.waitingPatients.filter(p => p.id !== id) + } + } + } + ], + affectedQueryKeys: [['GetPatients'], ['GetGlobalData']], + onSuccess, + onError, + }) +} + +interface UseOptimisticDischargePatientMutationParams { + id: string, + onSuccess?: (data: DischargePatientMutation, variables: { id: string }) => void, + onError?: (error: Error, variables: { id: string }) => void, +} + +export function useOptimisticDischargePatientMutation({ + id, + onSuccess, + onError, +}: UseOptimisticDischargePatientMutationParams) { + const { selectedRootLocationIds } = useTasksContext() + const selectedRootLocationIdsForQuery = selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds : undefined + return useSafeMutation({ + mutationFn: async (variables) => { + return fetcher(DischargePatientDocument, variables)() + }, + optimisticUpdate: () => [ + { + queryKey: ['GetPatient', { id }], + updateFn: (oldData: unknown) => { + const data = oldData as GetPatientQuery | undefined + if (!data?.patient) return oldData + return { + ...data, + patient: { + ...data.patient, + state: PatientState.Discharged + } + } + } + }, + { + queryKey: ['GetPatients'], + updateFn: (oldData: unknown) => { + const data = oldData as GetPatientsQuery | undefined + if (!data?.patients) return oldData + return { + ...data, + patients: data.patients.map(p => + p.id === id ? { ...p, state: PatientState.Discharged } : p) + } + } + }, + { + queryKey: ['GetGlobalData', { rootLocationIds: selectedRootLocationIdsForQuery }], + updateFn: (oldData: unknown) => { + const data = oldData as GetGlobalDataQuery | undefined + if (!data) return oldData + const existingPatient = data.patients.find(p => p.id === id) + const updatedPatient = existingPatient + ? { ...existingPatient, state: PatientState.Discharged } + : { __typename: 'PatientType' as const, id, state: PatientState.Discharged, assignedLocation: null } + return { + ...data, + patients: existingPatient + ? data.patients.map(p => p.id === id ? updatedPatient : p) + : [...data.patients, updatedPatient], + waitingPatients: data.waitingPatients.filter(p => p.id !== id) + } + } + } + ], + affectedQueryKeys: [['GetPatients'], ['GetGlobalData']], + onSuccess, + onError, + }) +} + +interface UseOptimisticDeletePatientMutationParams { + onSuccess?: (data: DeletePatientMutation, variables: DeletePatientMutationVariables) => void, + onError?: (error: Error, variables: DeletePatientMutationVariables) => void, +} + +export function useOptimisticDeletePatientMutation({ + onSuccess, + onError, +}: UseOptimisticDeletePatientMutationParams) { + const { selectedRootLocationIds } = useTasksContext() + const selectedRootLocationIdsForQuery = selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds : undefined + return useSafeMutation({ + mutationFn: async (variables) => { + return fetcher(DeletePatientDocument, variables)() + }, + optimisticUpdate: (variables) => [ + { + queryKey: ['GetGlobalData', { rootLocationIds: selectedRootLocationIdsForQuery }], + updateFn: (oldData: unknown) => { + const data = oldData as GetGlobalDataQuery | undefined + if (!data) return oldData + return { + ...data, + patients: (data.patients || []).filter(p => p.id !== variables.id), + waitingPatients: (data.waitingPatients || []).filter(p => p.id !== variables.id), + } + } + }, + { + queryKey: ['GetPatients'], + updateFn: (oldData: unknown) => { + const data = oldData as GetPatientsQuery | undefined + if (!data?.patients) return oldData + return { + ...data, + patients: data.patients.filter(p => p.id !== variables.id), + } + } + }, + { + queryKey: ['GetPatient', { id: variables.id }], + updateFn: () => undefined, + } + ], + affectedQueryKeys: [['GetGlobalData'], ['GetPatients'], ['GetOverviewData']], + onSuccess, + onError, + }) +} + +interface UseOptimisticWaitPatientMutationParams { + id: string, + onSuccess?: (data: WaitPatientMutation, variables: { id: string }) => void, + onError?: (error: Error, variables: { id: string }) => void, +} + +export function useOptimisticWaitPatientMutation({ + id, + onSuccess, + onError, +}: UseOptimisticWaitPatientMutationParams) { + const { selectedRootLocationIds } = useTasksContext() + const selectedRootLocationIdsForQuery = selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds : undefined + return useSafeMutation({ + mutationFn: async (variables) => { + return fetcher(WaitPatientDocument, variables)() + }, + optimisticUpdate: () => [ + { + queryKey: ['GetPatient', { id }], + updateFn: (oldData: unknown) => { + const data = oldData as GetPatientQuery | undefined + if (!data?.patient) return oldData + return { + ...data, + patient: { + ...data.patient, + state: PatientState.Wait + } + } + } + }, + { + queryKey: ['GetPatients'], + updateFn: (oldData: unknown) => { + const data = oldData as GetPatientsQuery | undefined + if (!data?.patients) return oldData + return { + ...data, + patients: data.patients.map(p => + p.id === id ? { ...p, state: PatientState.Wait } : p) + } + } + }, + { + queryKey: ['GetGlobalData', { rootLocationIds: selectedRootLocationIdsForQuery }], + updateFn: (oldData: unknown) => { + const data = oldData as GetGlobalDataQuery | undefined + if (!data) return oldData + const existingPatient = data.patients.find(p => p.id === id) + const isAlreadyWaiting = data.waitingPatients.some(p => p.id === id) + const updatedPatient = existingPatient + ? { ...existingPatient, state: PatientState.Wait } + : { __typename: 'PatientType' as const, id, state: PatientState.Wait, assignedLocation: null } + return { + ...data, + patients: existingPatient + ? data.patients.map(p => p.id === id ? updatedPatient : p) + : [...data.patients, updatedPatient], + waitingPatients: isAlreadyWaiting + ? data.waitingPatients + : [...data.waitingPatients, updatedPatient] + } + } + } + ], + affectedQueryKeys: [['GetPatients'], ['GetGlobalData']], + onSuccess, + onError, + }) +} + +interface UseOptimisticMarkPatientDeadMutationParams { + id: string, + onSuccess?: (data: MarkPatientDeadMutation, variables: { id: string }) => void, + onError?: (error: Error, variables: { id: string }) => void, +} + +export function useOptimisticMarkPatientDeadMutation({ + id, + onSuccess, + onError, +}: UseOptimisticMarkPatientDeadMutationParams) { + const { selectedRootLocationIds } = useTasksContext() + const selectedRootLocationIdsForQuery = selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds : undefined + return useSafeMutation({ + mutationFn: async (variables) => { + return fetcher(MarkPatientDeadDocument, variables)() + }, + optimisticUpdate: () => [ + { + queryKey: ['GetPatient', { id }], + updateFn: (oldData: unknown) => { + const data = oldData as GetPatientQuery | undefined + if (!data?.patient) return oldData + return { + ...data, + patient: { + ...data.patient, + state: PatientState.Dead + } + } + } + }, + { + queryKey: ['GetPatients'], + updateFn: (oldData: unknown) => { + const data = oldData as GetPatientsQuery | undefined + if (!data?.patients) return oldData + return { + ...data, + patients: data.patients.map(p => + p.id === id ? { ...p, state: PatientState.Dead } : p) + } + } + }, + { + queryKey: ['GetGlobalData', { rootLocationIds: selectedRootLocationIdsForQuery }], + updateFn: (oldData: unknown) => { + const data = oldData as GetGlobalDataQuery | undefined + if (!data) return oldData + const existingPatient = data.patients.find(p => p.id === id) + const updatedPatient = existingPatient + ? { ...existingPatient, state: PatientState.Dead } + : { __typename: 'PatientType' as const, id, state: PatientState.Dead, assignedLocation: null } + return { + ...data, + patients: existingPatient + ? data.patients.map(p => p.id === id ? updatedPatient : p) + : [...data.patients, updatedPatient], + waitingPatients: data.waitingPatients.filter(p => p.id !== id) + } + } + } + ], + affectedQueryKeys: [['GetPatients'], ['GetGlobalData']], + onSuccess, + onError, + }) +} + +interface UseOptimisticUpdatePatientMutationParams { + id: string, + onSuccess?: (data: UpdatePatientMutation, variables: UpdatePatientMutationVariables) => void, + onError?: (error: Error, variables: UpdatePatientMutationVariables) => void, +} + +export function useOptimisticUpdatePatientMutation({ + id, + onSuccess, + onError, +}: UseOptimisticUpdatePatientMutationParams) { + const { selectedRootLocationIds } = useTasksContext() + const selectedRootLocationIdsForQuery = selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds : undefined + const queryClient = useQueryClient() + + return useSafeMutation({ + mutationFn: async (variables) => { + return fetcher(UpdatePatientDocument, variables)() + }, + optimisticUpdate: (variables) => { + const updateData = variables.data || {} + const locationsData = queryClient.getQueryData(['GetLocations']) as { locationNodes?: Array<{ id: string, title: string, kind: string, parentId?: string | null }> } | undefined + type PatientType = NonNullable>>['patient']> + + const updatePatientInQuery = (patient: PatientType, updateData: Partial) => { + if (!patient) return patient + + const updated: typeof patient = { ...patient } + + if (updateData.firstname !== undefined) { + updated.firstname = updateData.firstname || '' + } + if (updateData.lastname !== undefined) { + updated.lastname = updateData.lastname || '' + } + if (updateData.sex !== undefined && updateData.sex !== null) { + updated.sex = updateData.sex + } + if (updateData.birthdate !== undefined) { + updated.birthdate = updateData.birthdate || null + } + if (updateData.description !== undefined) { + updated.description = updateData.description + } + if (updateData.clinicId !== undefined) { + if (updateData.clinicId === null || updateData.clinicId === undefined) { + updated.clinic = null as unknown as typeof patient.clinic + } else { + const clinicLocation = locationsData?.locationNodes?.find(loc => loc.id === updateData.clinicId) + if (clinicLocation) { + updated.clinic = { + ...clinicLocation, + __typename: 'LocationNodeType' as const, + } as typeof patient.clinic + } + } + } + if (updateData.positionId !== undefined) { + if (updateData.positionId === null) { + updated.position = null as typeof patient.position + } else { + const positionLocation = locationsData?.locationNodes?.find(loc => loc.id === updateData.positionId) + if (positionLocation) { + updated.position = { + ...positionLocation, + __typename: 'LocationNodeType' as const, + } as typeof patient.position + } + } + } + if (updateData.teamIds !== undefined) { + const teamLocations = locationsData?.locationNodes?.filter(loc => updateData.teamIds?.includes(loc.id)) || [] + updated.teams = teamLocations.map(team => ({ + ...team, + __typename: 'LocationNodeType' as const, + })) as typeof patient.teams + } + if (updateData.properties !== undefined && updateData.properties !== null) { + const propertyMap = new Map(updateData.properties.map(p => [p.definitionId, p])) + const existingPropertyIds = new Set( + patient.properties?.map(p => p.definition?.id).filter(Boolean) || [] + ) + const newPropertyIds = new Set(updateData.properties.map(p => p.definitionId)) + + const existingProperties = patient.properties + ? patient.properties + .filter(p => newPropertyIds.has(p.definition?.id)) + .map(p => { + const newProp = propertyMap.get(p.definition?.id) + if (!newProp) return p + return { + ...p, + textValue: newProp.textValue ?? p.textValue, + numberValue: newProp.numberValue ?? p.numberValue, + booleanValue: newProp.booleanValue ?? p.booleanValue, + dateValue: newProp.dateValue ?? p.dateValue, + dateTimeValue: newProp.dateTimeValue ?? p.dateTimeValue, + selectValue: newProp.selectValue ?? p.selectValue, + multiSelectValues: newProp.multiSelectValues ?? p.multiSelectValues, + } + }) + : [] + const newProperties = updateData.properties + .filter(p => !existingPropertyIds.has(p.definitionId)) + .map(p => { + const existingProperty = patient?.properties?.find(ep => ep.definition?.id === p.definitionId) + return { + __typename: 'PropertyValueType' as const, + definition: existingProperty?.definition || { + __typename: 'PropertyDefinitionType' as const, + id: p.definitionId, + name: '', + description: null, + fieldType: 'TEXT' as FieldType, + isActive: true, + allowedEntities: [], + options: [], + }, + textValue: p.textValue, + numberValue: p.numberValue, + booleanValue: p.booleanValue, + dateValue: p.dateValue, + dateTimeValue: p.dateTimeValue, + selectValue: p.selectValue, + multiSelectValues: p.multiSelectValues, + } + }) + updated.properties = [...existingProperties, ...newProperties] + } + + return updated + } + + const updates: Array<{ queryKey: unknown[], updateFn: (oldData: unknown) => unknown }> = [] + + updates.push({ + queryKey: ['GetPatient', { id }], + updateFn: (oldData: unknown) => { + const data = oldData as GetPatientQuery | undefined + if (!data?.patient) return oldData + const updatedPatient = updatePatientInQuery(data.patient, updateData) + return { + ...data, + patient: updatedPatient + } + } + }) + + const allGetPatientsQueries = queryClient.getQueryCache().getAll() + .filter(query => { + const key = query.queryKey + return Array.isArray(key) && key[0] === 'GetPatients' + }) + + for (const query of allGetPatientsQueries) { + updates.push({ + queryKey: [...query.queryKey] as unknown[], + updateFn: (oldData: unknown) => { + const data = oldData as GetPatientsQuery | undefined + if (!data?.patients) return oldData + const patientIndex = data.patients.findIndex(p => p.id === id) + if (patientIndex === -1) return oldData + const patient = data.patients[patientIndex] + if (!patient) return oldData + const updatedPatient = updatePatientInQuery(patient as unknown as PatientType, updateData) + if (!updatedPatient) return oldData + const updatedName = updatedPatient.firstname && updatedPatient.lastname + ? `${updatedPatient.firstname} ${updatedPatient.lastname}`.trim() + : updatedPatient.firstname || updatedPatient.lastname || patient.name || '' + const updatedPatientForList: typeof data.patients[0] = { + ...patient, + firstname: updateData.firstname !== undefined ? (updateData.firstname || '') : patient.firstname, + lastname: updateData.lastname !== undefined ? (updateData.lastname || '') : patient.lastname, + name: updatedName, + sex: updateData.sex !== undefined && updateData.sex !== null ? updateData.sex : patient.sex, + birthdate: updateData.birthdate !== undefined ? (updateData.birthdate || null) : patient.birthdate, + ...('description' in patient && { description: updateData.description !== undefined ? updateData.description : (patient as unknown as PatientType & { description?: string | null }).description }), + clinic: updateData.clinicId !== undefined + ? (updateData.clinicId + ? (locationsData?.locationNodes?.find(loc => loc.id === updateData.clinicId) as typeof patient.clinic || patient.clinic) + : (null as unknown as typeof patient.clinic)) + : patient.clinic, + position: updateData.positionId !== undefined + ? (updateData.positionId + ? (locationsData?.locationNodes?.find(loc => loc.id === updateData.positionId) as typeof patient.position || patient.position) + : (null as unknown as typeof patient.position)) + : patient.position, + teams: updateData.teamIds !== undefined + ? (locationsData?.locationNodes?.filter(loc => updateData.teamIds?.includes(loc.id)).map(team => team as typeof patient.teams[0]) || patient.teams) + : patient.teams, + properties: updateData.properties !== undefined && updateData.properties !== null + ? (updatedPatient.properties || patient.properties) + : patient.properties, + } + return { + ...data, + patients: [ + ...data.patients.slice(0, patientIndex), + updatedPatientForList, + ...data.patients.slice(patientIndex + 1) + ] + } + } + }) + } + + updates.push({ + queryKey: ['GetGlobalData', { rootLocationIds: selectedRootLocationIdsForQuery }], + updateFn: (oldData: unknown) => { + const data = oldData as GetGlobalDataQuery | undefined + if (!data) return oldData + const existingPatient = data.patients.find(p => p.id === id) + if (!existingPatient) return oldData + const updatedPatient = updatePatientInQuery(existingPatient as unknown as PatientType, updateData) + return { + ...data, + patients: data.patients.map(p => p.id === id ? updatedPatient as typeof existingPatient : p) + } + } + }) + + updates.push({ + queryKey: ['GetOverviewData'], + updateFn: (oldData: unknown) => { + return oldData + } + }) + + return updates + }, + affectedQueryKeys: [ + ['GetPatient', { id }], + ['GetPatients'], + ['GetOverviewData'], + ['GetGlobalData'] + ], + onSuccess, + onError, + }) +} diff --git a/web/api/optimistic-updates/GetTask.ts b/web/api/optimistic-updates/GetTask.ts new file mode 100644 index 00000000..21739b71 --- /dev/null +++ b/web/api/optimistic-updates/GetTask.ts @@ -0,0 +1,254 @@ +import { useSafeMutation } from '@/hooks/useSafeMutation' +import { fetcher } from '@/api/gql/fetcher' +import { UpdateTaskDocument, type UpdateTaskMutation, type UpdateTaskMutationVariables, type UpdateTaskInput, type FieldType } from '@/api/gql/generated' +import type { GetTaskQuery, GetTasksQuery, GetGlobalDataQuery } from '@/api/gql/generated' +import { useTasksContext } from '@/hooks/useTasksContext' +import { useQueryClient } from '@tanstack/react-query' + +interface UseOptimisticUpdateTaskMutationParams { + id: string, + onSuccess?: (data: UpdateTaskMutation, variables: UpdateTaskMutationVariables) => void, + onError?: (error: Error, variables: UpdateTaskMutationVariables) => void, +} + +export function useOptimisticUpdateTaskMutation({ + id, + onSuccess, + onError, +}: UseOptimisticUpdateTaskMutationParams) { + const { selectedRootLocationIds } = useTasksContext() + const selectedRootLocationIdsForQuery = selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds : undefined + const queryClient = useQueryClient() + + return useSafeMutation({ + mutationFn: async (variables) => { + return fetcher(UpdateTaskDocument, variables)() + }, + optimisticUpdate: (variables) => { + const updateData = variables.data || {} + const locationsData = queryClient.getQueryData(['GetLocations']) as { locationNodes?: Array<{ id: string, title: string, kind: string, parentId?: string | null }> } | undefined + const usersData = queryClient.getQueryData(['GetUsers']) as { users?: Array<{ id: string, name: string, avatarUrl?: string | null, lastOnline?: unknown, isOnline?: boolean }> } | undefined + type TaskType = NonNullable>>['task']> + + const updateTaskInQuery = (task: TaskType, updateData: Partial) => { + if (!task) return task + + const updated: typeof task = { ...task } + + if (updateData.title !== undefined) { + updated.title = updateData.title || '' + } + if (updateData.description !== undefined) { + updated.description = updateData.description + } + if (updateData.done !== undefined) { + updated.done = updateData.done ?? false + } + if (updateData.dueDate !== undefined) { + updated.dueDate = updateData.dueDate || null + } + if (updateData.priority !== undefined) { + updated.priority = updateData.priority + } + if (updateData.estimatedTime !== undefined) { + updated.estimatedTime = updateData.estimatedTime + } + if (updateData.assigneeId !== undefined) { + if (updateData.assigneeId === null || updateData.assigneeId === undefined) { + updated.assignee = null as typeof task.assignee + } else { + const user = usersData?.users?.find(u => u.id === updateData.assigneeId) + if (user) { + updated.assignee = { + __typename: 'UserType' as const, + id: user.id, + name: user.name, + avatarUrl: user.avatarUrl, + lastOnline: user.lastOnline, + isOnline: user.isOnline ?? false, + } as typeof task.assignee + } + } + } + if (updateData.assigneeTeamId !== undefined) { + if (updateData.assigneeTeamId === null || updateData.assigneeTeamId === undefined) { + updated.assigneeTeam = null as typeof task.assigneeTeam + } else { + const teamLocation = locationsData?.locationNodes?.find(loc => loc.id === updateData.assigneeTeamId) + if (teamLocation) { + updated.assigneeTeam = { + ...teamLocation, + __typename: 'LocationNodeType' as const, + } as typeof task.assigneeTeam + } + } + } + if (updateData.properties !== undefined && updateData.properties !== null) { + const propertyMap = new Map(updateData.properties.map(p => [p.definitionId, p])) + const existingPropertyIds = new Set( + task.properties?.map(p => p.definition?.id).filter(Boolean) || [] + ) + const newPropertyIds = new Set(updateData.properties.map(p => p.definitionId)) + + const existingProperties = task.properties + ? task.properties + .filter(p => newPropertyIds.has(p.definition?.id)) + .map(p => { + const newProp = propertyMap.get(p.definition?.id) + if (!newProp) return p + return { + ...p, + textValue: newProp.textValue ?? p.textValue, + numberValue: newProp.numberValue ?? p.numberValue, + booleanValue: newProp.booleanValue ?? p.booleanValue, + dateValue: newProp.dateValue ?? p.dateValue, + dateTimeValue: newProp.dateTimeValue ?? p.dateTimeValue, + selectValue: newProp.selectValue ?? p.selectValue, + multiSelectValues: newProp.multiSelectValues ?? p.multiSelectValues, + } + }) + : [] + const newProperties = updateData.properties + .filter(p => !existingPropertyIds.has(p.definitionId)) + .map(p => { + const existingProperty = task?.properties?.find(ep => ep.definition?.id === p.definitionId) + return { + __typename: 'PropertyValueType' as const, + definition: existingProperty?.definition || { + __typename: 'PropertyDefinitionType' as const, + id: p.definitionId, + name: '', + description: null, + fieldType: 'TEXT' as FieldType, + isActive: true, + allowedEntities: [], + options: [], + }, + textValue: p.textValue, + numberValue: p.numberValue, + booleanValue: p.booleanValue, + dateValue: p.dateValue, + dateTimeValue: p.dateTimeValue, + selectValue: p.selectValue, + multiSelectValues: p.multiSelectValues, + } + }) + updated.properties = [...existingProperties, ...newProperties] + } + + return updated + } + + const updates: Array<{ queryKey: unknown[], updateFn: (oldData: unknown) => unknown }> = [] + + updates.push({ + queryKey: ['GetTask', { id }], + updateFn: (oldData: unknown) => { + const data = oldData as GetTaskQuery | undefined + if (!data?.task) return oldData + const updatedTask = updateTaskInQuery(data.task, updateData) + return { + ...data, + task: updatedTask + } + } + }) + + const allGetTasksQueries = queryClient.getQueryCache().getAll() + .filter(query => { + const key = query.queryKey + return Array.isArray(key) && key[0] === 'GetTasks' + }) + + for (const query of allGetTasksQueries) { + updates.push({ + queryKey: [...query.queryKey] as unknown[], + updateFn: (oldData: unknown) => { + const data = oldData as GetTasksQuery | undefined + if (!data?.tasks) return oldData + const taskIndex = data.tasks.findIndex(t => t.id === id) + if (taskIndex === -1) return oldData + const task = data.tasks[taskIndex] + if (!task) return oldData + const updatedTask = updateTaskInQuery(task as unknown as TaskType, updateData) + if (!updatedTask) return oldData + const updatedTaskForList: typeof data.tasks[0] = { + ...task, + title: updateData.title !== undefined ? (updateData.title || '') : task.title, + description: updateData.description !== undefined ? updateData.description : task.description, + done: updateData.done !== undefined ? (updateData.done ?? false) : task.done, + dueDate: updateData.dueDate !== undefined ? (updateData.dueDate || null) : task.dueDate, + priority: updateData.priority !== undefined ? updateData.priority : task.priority, + estimatedTime: updateData.estimatedTime !== undefined ? updateData.estimatedTime : task.estimatedTime, + assignee: updateData.assigneeId !== undefined + ? (updateData.assigneeId + ? (usersData?.users?.find(u => u.id === updateData.assigneeId) ? { + __typename: 'UserType' as const, + id: updateData.assigneeId, + name: usersData.users.find(u => u.id === updateData.assigneeId)!.name, + avatarUrl: usersData.users.find(u => u.id === updateData.assigneeId)!.avatarUrl, + lastOnline: usersData.users.find(u => u.id === updateData.assigneeId)!.lastOnline, + isOnline: usersData.users.find(u => u.id === updateData.assigneeId)!.isOnline ?? false, + } as typeof task.assignee : task.assignee) + : (null as typeof task.assignee)) + : task.assignee, + assigneeTeam: updateData.assigneeTeamId !== undefined + ? (updateData.assigneeTeamId + ? (locationsData?.locationNodes?.find(loc => loc.id === updateData.assigneeTeamId) ? { + __typename: 'LocationNodeType' as const, + id: updateData.assigneeTeamId, + title: locationsData.locationNodes!.find(loc => loc.id === updateData.assigneeTeamId)!.title, + kind: locationsData.locationNodes!.find(loc => loc.id === updateData.assigneeTeamId)!.kind, + } as typeof task.assigneeTeam : task.assigneeTeam) + : (null as typeof task.assigneeTeam)) + : task.assigneeTeam, + } + return { + ...data, + tasks: [ + ...data.tasks.slice(0, taskIndex), + updatedTaskForList, + ...data.tasks.slice(taskIndex + 1) + ] + } + } + }) + } + + updates.push({ + queryKey: ['GetGlobalData', { rootLocationIds: selectedRootLocationIdsForQuery }], + updateFn: (oldData: unknown) => { + const data = oldData as GetGlobalDataQuery | undefined + if (!data) return oldData + const existingTask = data.me?.tasks?.find(t => t.id === id) + if (!existingTask) return oldData + const updatedTask = updateTaskInQuery(existingTask as unknown as TaskType, updateData) + return { + ...data, + me: data.me ? { + ...data.me, + tasks: data.me.tasks?.map(t => t.id === id ? updatedTask as typeof existingTask : t) || [] + } : null + } + } + }) + + updates.push({ + queryKey: ['GetOverviewData'], + updateFn: (oldData: unknown) => { + return oldData + } + }) + + return updates + }, + affectedQueryKeys: [ + ['GetTask', { id }], + ['GetTasks'], + ['GetOverviewData'], + ['GetGlobalData'] + ], + onSuccess, + onError, + }) +} diff --git a/web/components/AuditLogTimeline.tsx b/web/components/AuditLogTimeline.tsx index 61aab981..98850a3a 100644 --- a/web/components/AuditLogTimeline.tsx +++ b/web/components/AuditLogTimeline.tsx @@ -173,7 +173,7 @@ export const AuditLogTimeline: React.FC = ({ caseId, clas onClick={(e) => handleCardClick(index, e)} className={clsx( 'p-4 rounded-lg border-2 transition-all', - 'bg-[rgba(255,255,255,1)] dark:bg-[rgba(55,65,81,1)]', + 'bg-surface-variant text-on-surface', 'border-gray-300 dark:border-gray-600', 'hover:border-primary hover:shadow-md', hasDetails && 'cursor-pointer' @@ -189,7 +189,6 @@ export const AuditLogTimeline: React.FC = ({ caseId, clas <> = ({ className, ...avatarProps }) => { - const size = avatarProps.size || 'md' - const dotSizeClasses = { - sm: 'w-3 h-3', - md: 'w-3.5 h-3.5', - lg: 'w-4 h-4', - xl: 'w-5 h-5', + const size = avatarProps.size || 'sm' + const dotSizeClasses: Record, string> = { + xs: 'w-3 h-3', + sm: 'w-3.5 h-3.5', + md: 'w-4 h-4', + lg: 'w-5 h-5', } - const dotPositionClasses = { + const dotPositionClasses: Record, string> = { + xs: 'bottom-0 right-0', sm: 'bottom-0 right-0', md: 'bottom-0 right-0', lg: 'bottom-0 right-0', - xl: 'bottom-0 right-0', } - const dotBorderClasses = { - sm: 'border-[1.5px]', + const dotBorderClasses: Record, string> = { + xs: 'border-[1.5px]', + sm: 'border-2', md: 'border-2', lg: 'border-2', - xl: 'border-2', } const showOnline = isOnline === true diff --git a/web/components/AvatarStatusComponent.tsx b/web/components/AvatarStatusComponent.tsx index e94ba135..a39e670a 100644 --- a/web/components/AvatarStatusComponent.tsx +++ b/web/components/AvatarStatusComponent.tsx @@ -1,4 +1,5 @@ import React from 'react' +import type { AvatarSize } from '@helpwave/hightide' import { Avatar, type AvatarProps } from '@helpwave/hightide' import clsx from 'clsx' @@ -11,26 +12,26 @@ export const AvatarStatusComponent: React.FC = ({ className, ...avatarProps }) => { - const size = avatarProps.size || 'md' - const dotSizeClasses = { - sm: 'w-3 h-3', - md: 'w-3.5 h-3.5', - lg: 'w-4 h-4', - xl: 'w-5 h-5', + const size = avatarProps.size || 'sm' + const dotSizeClasses: Record, string> = { + xs: 'w-3 h-3', + sm: 'w-3.5 h-3.5', + md: 'w-4 h-4', + lg: 'w-5 h-5', } - const dotPositionClasses = { + const dotPositionClasses: Record, string> = { + xs: 'bottom-0 right-0', sm: 'bottom-0 right-0', md: 'bottom-0 right-0', lg: 'bottom-0 right-0', - xl: 'bottom-0 right-0', } - const dotBorderClasses = { - sm: 'border-[1.5px]', + const dotBorderClasses: Record, string> = { + xs: 'border-[1.5px]', + sm: 'border-2', md: 'border-2', lg: 'border-2', - xl: 'border-2', } const showOnline = isOnline === true @@ -40,7 +41,7 @@ export const AvatarStatusComponent: React.FC = ({
({ - field, - localValue: localData[field], - serverValue: serverData[field], - })).filter(f => JSON.stringify(f.localValue) !== JSON.stringify(f.serverValue)) + field, + localValue: localData[field], + serverValue: serverData[field], + })).filter(f => JSON.stringify(f.localValue) !== JSON.stringify(f.serverValue)) : [] return ( diff --git a/web/components/FeedbackDialog.tsx b/web/components/FeedbackDialog.tsx index b48566d5..a2c97679 100644 --- a/web/components/FeedbackDialog.tsx +++ b/web/components/FeedbackDialog.tsx @@ -1,9 +1,8 @@ -import { useState, useEffect, useRef } from 'react' -import { Dialog, Button, Textarea, FormElementWrapper, Checkbox } from '@helpwave/hightide' +import { useState, useEffect, useRef, useMemo } from 'react' +import { Dialog, Button, Textarea, FormField, FormProvider, Checkbox, useCreateForm, useTranslatedValidators, useFormObserverKey } from '@helpwave/hightide' import { useTasksTranslation, useLocale } from '@/i18n/useTasksTranslation' import { useTasksContext } from '@/hooks/useTasksContext' import { Mic, Pause } from 'lucide-react' -import clsx from 'clsx' interface FeedbackDialogProps { isOpen: boolean, @@ -11,6 +10,12 @@ interface FeedbackDialogProps { hideUrl?: boolean, } +type FeedbackFormValues = { + url?: string, + feedback: string, + isAnonymous: boolean, +} + interface SpeechRecognitionEvent { resultIndex: number, results: Array>, @@ -35,14 +40,78 @@ export const FeedbackDialog = ({ isOpen, onClose, hideUrl = false }: FeedbackDia const translation = useTasksTranslation() const { locale } = useLocale() const { user } = useTasksContext() - const [feedback, setFeedback] = useState('') const [isRecording, setIsRecording] = useState(false) const [isSupported, setIsSupported] = useState(false) - const [isAnonymous, setIsAnonymous] = useState(false) const recognitionRef = useRef(null) const isRecordingRef = useRef(false) const finalTranscriptRef = useRef('') const lastFinalLengthRef = useRef(0) + const validators = useTranslatedValidators() + + + + const form = useCreateForm({ + initialValues: { + url: typeof window !== 'undefined' ? window.location.href : '', + feedback: '', + isAnonymous: false, + }, + validators:{ + feedback: validators.notEmpty + }, + onFormSubmit: async (values) => { + if (!values.feedback.trim()) return + + const feedbackData: { + url?: string, + feedback: string, + timestamp: string, + username: string, + userId?: string, + } = { + feedback: values.feedback.trim(), + timestamp: new Date().toISOString(), + username: values.isAnonymous ? 'Anonymous' : (user?.name || 'Unknown User'), + userId: values.isAnonymous ? undefined : user?.id, + } + + if (!hideUrl) { + feedbackData.url = values.url || (typeof window !== 'undefined' ? window.location.href : '') + } + + try { + const response = await fetch('/api/feedback', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(feedbackData), + }) + + if (response.ok) { + form.update(prev => ({ + ...prev, + feedback: '', + isAnonymous: false, + })) + onClose() + } + } catch { + void 0 + } + }, + }) + + const { update: updateForm } = form + + const isAnonymous = useFormObserverKey({ formStore: form.store, formKey: 'isAnonymous' })?.value ?? false + + const submissionName = useMemo(() => { + if(isAnonymous) { + return translation('anonymous') + } + return user?.name || 'Unknown User' + }, [isAnonymous, translation, user?.name]) useEffect(() => { if (typeof window !== 'undefined') { @@ -91,7 +160,7 @@ export const FeedbackDialog = ({ isOpen, onClose, hideUrl = false }: FeedbackDia ? finalTranscriptRef.current.trim() + '\n\n' + interimTranscript : (finalTranscriptRef.current + interimTranscript).trim() - setFeedback(displayText) + updateForm(prev => ({ ...prev, feedback: displayText })) } recognition.onerror = (event: SpeechRecognitionErrorEvent) => { @@ -117,13 +186,13 @@ export const FeedbackDialog = ({ isOpen, onClose, hideUrl = false }: FeedbackDia setIsRecording(false) } lastFinalLengthRef.current = finalTranscriptRef.current.length - setFeedback(finalTranscriptRef.current.trim()) + updateForm(prev => ({ ...prev, feedback: finalTranscriptRef.current.trim() })) } recognitionRef.current = recognition } } - }, [locale]) + }, [locale, updateForm]) const handleToggleRecording = () => { if (!recognitionRef.current) return @@ -143,8 +212,10 @@ export const FeedbackDialog = ({ isOpen, onClose, hideUrl = false }: FeedbackDia useEffect(() => { if (!isOpen) { - setFeedback('') - setIsAnonymous(false) + updateForm(prev => ({ + ...prev, + feedback: '', + })) finalTranscriptRef.current = '' lastFinalLengthRef.current = 0 if (recognitionRef.current && isRecording) { @@ -153,123 +224,111 @@ export const FeedbackDialog = ({ isOpen, onClose, hideUrl = false }: FeedbackDia setIsRecording(false) } } - }, [isOpen, isRecording]) - - const handleSubmit = async () => { - if (!feedback.trim()) return - - const feedbackData: { - url?: string, - feedback: string, - timestamp: string, - username: string, - userId?: string, - } = { - feedback: feedback.trim(), - timestamp: new Date().toISOString(), - username: isAnonymous ? 'Anonymous' : (user?.name || 'Unknown User'), - userId: user?.id, - } + }, [isOpen, isRecording, updateForm]) - if (!hideUrl) { - feedbackData.url = typeof window !== 'undefined' ? window.location.href : '' - } - - try { - const response = await fetch('/api/feedback', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify(feedbackData), - }) - - if (response.ok) { - setFeedback('') - setIsAnonymous(false) - onClose() - } - } catch { - void 0 + useEffect(() => { + if (isOpen && user) { + updateForm(prev => ({ + ...prev, + username: isAnonymous ? 'Anonymous' : (user.name || 'Unknown User'), + userId: isAnonymous ? undefined : user.id, + })) } - } + }, [isOpen, user, updateForm, isAnonymous]) return ( - -
- {!hideUrl && ( - - {() => ( -
- {typeof window !== 'undefined' ? window.location.href : ''} -
+ + +
{ event.preventDefault(); form.submit() }}> +
+ {!hideUrl && ( + + name="url" + label={translation('url')} + > + {({ dataProps }) => ( +
+ {dataProps.value || (typeof window !== 'undefined' ? window.location.href : '')} +
+ )} + )} - - )} - - - {() => ( -
- - - {translation('submitAnonymously') ?? 'Submit anonymously'} - -
- )} -
- - - {() => ( -
-