diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index 0064152d3..9548f3567 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -16,6 +16,10 @@ from PartSegCore.utils import BaseModel from PartSegImage import Channel +T = typing.TypeVar("T", bound="AlgorithmDescribeBase") + +TypeT = typing.Type[T] + class AlgorithmDescribeNotFound(Exception): """ @@ -116,7 +120,7 @@ def _partial_abstractmethod(funcobj): class AlgorithmDescribeBaseMeta(ABCMeta): - def __new__(cls, name, bases, attrs, **kwargs): + def __new__(cls, name, bases, attrs, method_from_fun=None, additional_parameters=None, **kwargs): cls2 = super().__new__(cls, name, bases, attrs, **kwargs) if ( not inspect.isabstract(cls2) @@ -125,8 +129,181 @@ def __new__(cls, name, bases, attrs, **kwargs): ): raise RuntimeError("class need to have __argument_class__ set or get_fields functions defined") cls2.__new_style__ = getattr(cls2.get_fields, "__is_partial_abstractmethod__", False) + cls2.__from_function__ = getattr(cls2, "__from_function__", False) + cls2.__abstract_getters__ = {} + cls2.__method_name__ = method_from_fun or getattr(cls2, "__method_name__", None) + cls2.__additional_parameters_name__ = additional_parameters or getattr( + cls2, "__additional_parameters_name__", None + ) + if cls2.__additional_parameters_name__ is None: + cls2.__additional_parameters_name__ = cls._get_calculation_method_params_name(cls2) + + cls2.__support_from_function__ = ( + cls2.__method_name__ is not None and cls2.__additional_parameters_name__ is not None + ) + + cls2.__abstract_getters__, cls2.__support_from_function__ = cls._get_abstract_getters( + cls2, cls2.__support_from_function__, method_from_fun + ) return cls2 + @staticmethod + def _get_abstract_getters( + cls2, support_from_function, calculation_method + ) -> typing.Tuple[typing.Dict[str, typing.Any], bool]: + abstract_getters = {} + if hasattr(cls2, "__abstractmethods__") and cls2.__abstractmethods__: + # get all abstract methods that starts with `get_` + for method_name in cls2.__abstractmethods__: + if method_name.startswith("get_"): + method = getattr(cls2, method_name) + if "return" not in method.__annotations__: + msg = f"Method {method_name} of {cls2.__qualname__} need to have return type defined" + try: + file_name = inspect.getsourcefile(method) + line = inspect.getsourcelines(method)[1] + msg += f" in {file_name}:{line}" + except TypeError: + pass + raise RuntimeError(msg) + + abstract_getters[method_name[4:]] = getattr(cls2, method_name).__annotations__["return"] + elif method_name != calculation_method: + support_from_function = False + return abstract_getters, support_from_function + + @staticmethod + def _get_calculation_method_params_name(cls2) -> typing.Optional[str]: + if cls2.__method_name__ is None: + return None + signature = inspect.signature(getattr(cls2, cls2.__method_name__)) + if "arguments" in signature.parameters: + return "arguments" + if "params" in signature.parameters: + return "params" + if "parameters" in signature.parameters: + return "parameters" + raise RuntimeError(f"Cannot determine arguments parameter name in {cls2.__method_name__}") + + @staticmethod + def _validate_if_all_abstract_getters_are_defined(abstract_getters, kwargs): + abstract_getters_set = set(abstract_getters) + kwargs_set = set(kwargs.keys()) + + if abstract_getters_set != kwargs_set: + # Provide a nice error message with information about what is missing and is obsolete + missing_text = ", ".join(sorted(abstract_getters_set - kwargs_set)) + if missing_text: + missing_text = f"Not all abstract methods are provided, missing: {missing_text}." + else: + missing_text = "" + extra_text = ", ".join(sorted(kwargs_set - abstract_getters_set)) + if extra_text: + extra_text = f"There are extra attributes in call: {extra_text}." + else: + extra_text = "" + + raise ValueError(f"{missing_text} {extra_text}") + + @staticmethod + def _validate_function_parameters(func, method, method_name) -> set: + """ + Validate if all parameters without default values are defined in self.__calculation_method__ + + :param func: function to validate + :return: set of parameters that should be dropped + """ + signature = inspect.signature(func) + base_method_signature = inspect.signature(method) + take_all = False + + for parameter in signature.parameters.values(): + if parameter.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.POSITIONAL_ONLY}: + raise ValueError(f"Function {func} should not have positional only parameters") + if ( + parameter.default is inspect.Parameter.empty + and parameter.name not in base_method_signature.parameters + and parameter.kind != inspect.Parameter.VAR_KEYWORD + ): + raise ValueError(f"Parameter {parameter.name} is not defined in {method_name} method") + + if parameter.kind == inspect.Parameter.VAR_KEYWORD: + take_all = True + + if take_all: + return set() + + return { + parameters.name + for parameters in base_method_signature.parameters.values() + if parameters.name not in signature.parameters + } + + @staticmethod + def _get_argument_class_from_signature(func, argument_name: str): + signature = inspect.signature(func) + if argument_name not in signature.parameters: + return BaseModel + return signature.parameters[argument_name].annotation + + @staticmethod + def _get_parameters_from_signature(func): + signature = inspect.signature(func) + return [parameters.name for parameters in signature.parameters.values()] + + def from_function(self, func=None, **kwargs): + """generate new class from function""" + + # Test if all abstract methods values are provided in kwargs + + if not self.__support_from_function__: + raise RuntimeError("This class does not support from_function method") + + self._validate_if_all_abstract_getters_are_defined(self.__abstract_getters__, kwargs) + + # check if all values have correct type + for key, value in kwargs.items(): + if not isinstance(value, self.__abstract_getters__[key]): + raise TypeError(f"Value for {key} should be {self.__abstract_getters__[key]}") + + def _getter_by_name(name): + def _func(): + return kwargs[name] + + return _func + + parameters_order = self._get_parameters_from_signature(getattr(self, self.__method_name__)) + + def _class_generator(func_): + drop_attr = self._validate_function_parameters( + func_, getattr(self, self.__method_name__), self.__method_name__ + ) + + @wraps(func_) + def _calculate_method(*args, **kwargs_): + for attr, name in zip(args, parameters_order): + if name in kwargs_: + raise ValueError(f"Parameter {name} is defined twice") + kwargs_[name] = attr + + for name in drop_attr: + kwargs_.pop(name, None) + return func_(**kwargs_) + + class_dkt = {f"get_{name}": _getter_by_name(name) for name in self.__abstract_getters__} + + class_dkt[self.__method_name__] = _calculate_method + class_dkt["__argument_class__"] = self._get_argument_class_from_signature( + func_, self.__additional_parameters_name__ + ) + class_dkt["__from_function__"] = True + + return type(func_.__name__.replace("_", " ").title().replace(" ", ""), (self,), class_dkt) + + if func is None: + return _class_generator + return _class_generator(func) + class AlgorithmDescribeBase(ABC, metaclass=AlgorithmDescribeBaseMeta): """ @@ -138,6 +315,11 @@ class AlgorithmDescribeBase(ABC, metaclass=AlgorithmDescribeBaseMeta): __argument_class__: typing.Optional[typing.Type[PydanticBaseModel]] = None __new_style__: bool + def __new__(cls, *args, **kwargs): + if cls.__from_function__: + return getattr(cls, cls.__method_name__)(*args, **kwargs) + return super().__new__(cls) + @classmethod def get_doc_from_fields(cls): resp = "{\n" @@ -150,6 +332,29 @@ def get_doc_from_fields(cls): resp += "}\n" return resp + @classmethod + @typing.overload + def from_function(cls: TypeT, func: typing.Callable[..., typing.Any], **kwargs) -> TypeT: + ... + + @classmethod + @typing.overload + def from_function(cls: TypeT, **kwargs) -> typing.Callable[[typing.Callable[..., typing.Any]], TypeT]: + ... + + @classmethod + def from_function( + cls: TypeT, func=None, **kwargs + ) -> typing.Union[TypeT, typing.Callable[[typing.Callable], TypeT]]: + def _from_function(func_) -> typing.Type["AlgorithmDescribeBase"]: + if "name" not in kwargs: + kwargs["name"] = func_.__name__.replace("_", " ").title() + return AlgorithmDescribeBaseMeta.from_function(cls, func_, **kwargs) + + if func is None: + return _from_function + return _from_function(func) + @classmethod @abstractmethod def get_name(cls) -> str: diff --git a/package/PartSegCore/io_utils.py b/package/PartSegCore/io_utils.py index f9870b93e..f3fce5496 100644 --- a/package/PartSegCore/io_utils.py +++ b/package/PartSegCore/io_utils.py @@ -68,7 +68,7 @@ class SaveBase(AlgorithmDescribeBase, ABC): @classmethod @abstractmethod - def get_short_name(cls): + def get_short_name(cls) -> str: raise NotImplementedError @classmethod @@ -102,10 +102,12 @@ def get_default_extension(cls): @classmethod def need_segmentation(cls): + """If method requires segmentation (ROI) to work, or could work with image only""" return True @classmethod def need_mask(cls): + """If `mask` is required for perform save""" return False @classmethod @@ -132,7 +134,7 @@ class LoadBase(AlgorithmDescribeBase, ABC): @classmethod @abstractmethod - def get_short_name(cls): + def get_short_name(cls) -> str: raise NotImplementedError @classmethod @@ -161,8 +163,7 @@ def get_name_with_suffix(cls): @classmethod def get_extensions(cls) -> typing.List[str]: - match = re.match(r".*\((.*)\)", cls.get_name()) - if match is None: + if (match := re.match(r".*\((.*)\)", cls.get_name())) is None: raise ValueError(f"No extensions found in {cls.get_name()}") extensions = match[1].split(" ") if not all(x.startswith("*.") for x in extensions): @@ -205,7 +206,7 @@ def load_metadata_base(data: typing.Union[str, Path]): try: decoded_data = json.loads(str(data), object_hook=partseg_object_hook) except Exception: # pragma: no cover - raise e # noqa: B904 + raise e from None return decoded_data diff --git a/package/PartSegCore/segmentation/algorithm_base.py b/package/PartSegCore/segmentation/algorithm_base.py index 09106c2e5..6566c0f67 100644 --- a/package/PartSegCore/segmentation/algorithm_base.py +++ b/package/PartSegCore/segmentation/algorithm_base.py @@ -204,7 +204,7 @@ def calculation_run(self, report_fun: Callable[[str, int], None]) -> ROIExtracti raise NotImplementedError @abstractmethod - def get_info_text(self): + def get_info_text(self) -> str: raise NotImplementedError def get_channel(self, channel_idx): diff --git a/package/PartSegCore/segmentation/border_smoothing.py b/package/PartSegCore/segmentation/border_smoothing.py index 77fa98eea..4de40e926 100644 --- a/package/PartSegCore/segmentation/border_smoothing.py +++ b/package/PartSegCore/segmentation/border_smoothing.py @@ -11,7 +11,7 @@ from PartSegCore.utils import BaseModel -class BaseSmoothing(AlgorithmDescribeBase, ABC): +class BaseSmoothing(AlgorithmDescribeBase, ABC, method_from_fun="smooth"): __argument_class__ = BaseModel @classmethod diff --git a/package/PartSegCore/segmentation/noise_filtering.py b/package/PartSegCore/segmentation/noise_filtering.py index 1e8e40ca0..8a1b6e062 100644 --- a/package/PartSegCore/segmentation/noise_filtering.py +++ b/package/PartSegCore/segmentation/noise_filtering.py @@ -1,6 +1,6 @@ import typing import warnings -from abc import ABC +from abc import ABC, abstractmethod from enum import Enum import numpy as np @@ -22,10 +22,11 @@ def __str__(self): return self.name.replace("_", " ") -class NoiseFilteringBase(AlgorithmDescribeBase, ABC): +class NoiseFilteringBase(AlgorithmDescribeBase, ABC, method_from_fun="noise_filter"): """Base class for noise filtering operations""" @classmethod + @abstractmethod def noise_filter(cls, channel: np.ndarray, spacing: typing.Iterable[float], arguments: dict) -> np.ndarray: """ This function need be overloaded in implementation diff --git a/package/PartSegCore/segmentation/threshold.py b/package/PartSegCore/segmentation/threshold.py index 490452587..5889510bb 100644 --- a/package/PartSegCore/segmentation/threshold.py +++ b/package/PartSegCore/segmentation/threshold.py @@ -2,6 +2,7 @@ import warnings from abc import ABC +import mahotas import numpy as np import SimpleITK as sitk from local_migrator import register_class, rename_key, update_argument @@ -36,7 +37,7 @@ class MultipleOtsuThresholdParams(BaseModel): bins: int = Field(128, title="Number of histogram bins", ge=8, le=2**16) -class BaseThreshold(AlgorithmDescribeBase, ABC): +class BaseThreshold(AlgorithmDescribeBase, ABC, method_from_fun="calculate_mask"): @classmethod def calculate_mask( cls, @@ -275,6 +276,53 @@ def get_name(cls) -> str: return "Multiple Otsu" +class MahotasThreshold(BaseModel): + ignore_zeros: bool = False + + +@BaseThreshold.from_function() +def riddler_calvard( + data: np.ndarray, mask: np.ndarray, arguments: MahotasThreshold, operator: typing.Callable[[object, object], bool] +): + """ + Riddler-Calvard thresholding algorithm from mahotas + + Parameters + ---------- + data : ndarray + Image data. + mask : ndarray + Mask data. + arguments : MahotasThreshold + method parameter + operator : callable + operator to use for thresholding + + Returns + ------- + mask : ndarray + Computed mask + References + ---------- + .. [1] C. A. Riddler, and G. S. Calvard, "Picture thresholding using an iterative selection method," + """ + + try: + threshold = mahotas.rc(data, ignore_zeros=arguments.ignore_zeros) + except TypeError as e: # pragma: no cover + if "This function only accepts integer types" in e.args[0]: + raise SegmentationLimitException(*e.args) from e + raise + if operator(1, 0): + res = (data >= threshold).astype(np.uint8) + else: + res = (data < threshold).astype(np.uint8) + + if mask is not None: + res = res * (mask > 0) + return res, threshold + + class ThresholdSelection(AlgorithmSelection, class_methods=["calculate_mask"], suggested_base_class=BaseThreshold): pass @@ -293,6 +341,7 @@ class ThresholdSelection(AlgorithmSelection, class_methods=["calculate_mask"], s ThresholdSelection.register(MomentsThreshold) ThresholdSelection.register(MaximumEntropyThreshold) ThresholdSelection.register(MultipleOtsuThreshold) +ThresholdSelection.register(riddler_calvard) class DoubleThresholdParams(BaseModel): diff --git a/package/PartSegCore/segmentation/watershed.py b/package/PartSegCore/segmentation/watershed.py index a3d41bc30..5bf035cb3 100644 --- a/package/PartSegCore/segmentation/watershed.py +++ b/package/PartSegCore/segmentation/watershed.py @@ -22,7 +22,7 @@ ) -class BaseWatershed(AlgorithmDescribeBase, ABC): +class BaseWatershed(AlgorithmDescribeBase, ABC, method_from_fun="sprawl"): """base class for all sprawl interface""" __argument_class__ = BaseModel diff --git a/package/PartSegCore/utils.py b/package/PartSegCore/utils.py index f2baaf7bf..7c8044d69 100644 --- a/package/PartSegCore/utils.py +++ b/package/PartSegCore/utils.py @@ -7,6 +7,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from contextlib import suppress +from functools import wraps from types import MethodType import numpy as np @@ -462,3 +463,84 @@ def iterate_names(base_name: str, data_dict, max_length=None) -> typing.Optional if res_name not in data_dict: return res_name return None + + +def _get_translation_dicts_from_signature(signature: inspect.Signature) -> typing.Tuple[dict, dict]: + kwargs_to_field = {} + par_to_class = {} + + for par in signature.parameters.values(): + if par.annotation is par.empty or not issubclass(par.annotation, BaseModel): + raise TypeError(f"This decorator can be used only with functions with BaseModel parameters. {par} is not.") + par_to_class[par.name] = par.annotation + for field in par.annotation.__fields__.values(): + if field.name in kwargs_to_field: + raise TypeError( + f"Two parameters have same field name {field.name}. " + f"{kwargs_to_field[field.name]} and {par.name}." + ) + kwargs_to_field[field.name] = par.name + + return kwargs_to_field, par_to_class + + +def _get_kwargs_from_old_args(old_args_order: typing.List[str], args, kwargs): + if len(args) > len(old_args_order): + raise TypeError("Too many positional arguments, please use keyword argument or update old_args_order") + kwargs = kwargs.copy() + for name, arg in zip(old_args_order, args): + kwargs[name] = arg + return kwargs + + +def _map_kwargs_to_model(kwargs_to_field, par_to_class, kwargs): + class_kwargs = {name: {} for name in par_to_class} + + for key, value in kwargs.items(): + if key in kwargs_to_field: + class_kwargs[kwargs_to_field[key]][key] = value + else: + raise TypeError(f"Unexpected keyword argument {key}") + + return {name: par_to_class[name](**class_kwargs[name]) for name in par_to_class} + + +def kwargs_to_model(func=None, old_args_order: typing.Optional[typing.List[str]] = None): + """ + Decorator for converting kwargs to base models before calling function. + + This decorator is to provide backward compatibility layer. + """ + + if old_args_order is None: + old_args_order = [] + + def decorator(func_): + # get functions parameters type annotation to checks its model fields. + + signature = inspect.signature(func_) + + kwargs_to_field, par_to_class = _get_translation_dicts_from_signature(signature) + + @wraps(func_) + def wrapper(*args, **kwargs): + try: + return func_(*args, **kwargs) + except TypeError as e: + if "unexpected keyword argument" not in str(e): + raise e + kwargs = _get_kwargs_from_old_args(old_args_order, args, kwargs) + kwargs = _map_kwargs_to_model(kwargs_to_field, par_to_class, kwargs) + warnings.warn( + "The function is called using old signature. There is try to fix it. Please update function call", + FutureWarning, + stacklevel=2, + ) + + return func_(**kwargs) + + return wrapper + + if func: + return decorator(func) + return decorator diff --git a/package/tests/test_PartSegCore/segmentation/test_threshold.py b/package/tests/test_PartSegCore/segmentation/test_threshold.py index 65155adcb..0e20167ce 100644 --- a/package/tests/test_PartSegCore/segmentation/test_threshold.py +++ b/package/tests/test_PartSegCore/segmentation/test_threshold.py @@ -12,11 +12,11 @@ ThresholdSelection, ) -square = np.zeros((21, 21)) +square = np.zeros((21, 21), dtype=np.uint32) for i, v in [(2, 1000), (4, 10000), (6, 20000), (7, 21000), (8, 22000), (9, 23000)]: square[i:-i, i:-i] = v -cube = np.zeros((21, 21, 21)) +cube = np.zeros((21, 21, 21), dtype=np.uint32) for i, v in [(2, 1000), (4, 10000), (6, 20000), (7, 21000), (8, 22000), (9, 23000)]: cube[i:-i, i:-i, i:-i] = v @@ -39,7 +39,7 @@ def test_threshold(method: BaseThreshold, data, op, masking): pytest.xfail("IntermodesThreshold sigma problem") raise assert isinstance(data, np.ndarray) - assert isinstance(thr_info, (int, float)) + assert isinstance(thr_info, (int, float, np.integer)), type(thr_info) @pytest.mark.parametrize("method", DoubleThresholdSelection.__register__.values()) @@ -52,5 +52,5 @@ def test_double_threshold(method: BaseThreshold, data, op, masking): data, thr_info = method.calculate_mask(data=data, mask=mask, arguments=method.get_default_values(), operator=op) assert isinstance(data, np.ndarray) - assert isinstance(thr_info[0], (int, float)) - assert isinstance(thr_info[1], (int, float)) + assert isinstance(thr_info[0], (int, float, np.integer)) + assert isinstance(thr_info[1], (int, float, np.integer)) diff --git a/package/tests/test_PartSegCore/test_algorithm_describe_base.py b/package/tests/test_PartSegCore/test_algorithm_describe_base.py index 408beb464..d0794bdd8 100644 --- a/package/tests/test_PartSegCore/test_algorithm_describe_base.py +++ b/package/tests/test_PartSegCore/test_algorithm_describe_base.py @@ -1,5 +1,6 @@ # pylint: disable=no-self-use import typing +from abc import ABC, abstractmethod from enum import Enum import pytest @@ -9,6 +10,7 @@ from PartSegCore.algorithm_describe_base import ( AlgorithmDescribeBase, + AlgorithmDescribeBaseMeta, AlgorithmProperty, AlgorithmSelection, ROIExtractionProfile, @@ -354,6 +356,148 @@ def get_fields(cls) -> typing.List[typing.Union[AlgorithmProperty, str]]: with pytest.warns(FutureWarning, match=r"Class has __argument_class__ defined"): assert SampleSubAlgorithm.get_default_values() == {"name": 1, "name2": 3.0} + def test_generate_class_from_function_lack_of_methods(self): + def sample_function(params: dict) -> dict: + """For test purpose""" + + with pytest.raises(ValueError, match="missing: alpha, info"): + ClassForTestFromFunc.from_function(sample_function) + + with pytest.raises(ValueError, match="missing: info"): + ClassForTestFromFunc.from_function(sample_function, alpha=1.0) + + with pytest.raises(ValueError, match="missing: alpha"): + ClassForTestFromFunc.from_function(sample_function, info="sample") + + with pytest.raises(ValueError, match="missing: alpha, info.*call: info2"): + ClassForTestFromFunc.from_function(sample_function, info2="sample") + + with pytest.raises(ValueError, match="call: additions"): + ClassForTestFromFunc.from_function(sample_function, info="sample", alpha=1.0, additions="sample3") + + def test_missing_return_annotation(self): + with pytest.raises(RuntimeError, match="Method get_sample of .*SampleClass need to have return type defined.*"): + + class SampleClass(AlgorithmDescribeBase): # pylint: disable=unused-variable + @classmethod + @abstractmethod + def get_sample(cls): + raise NotImplementedError + + def test_not_supported_from_function(self): + def sample_function(params: dict) -> dict: + """For test purpose""" + + class SampleClass(AlgorithmDescribeBase): + @classmethod + @abstractmethod + def sample(cls) -> dict: + raise NotImplementedError + + with pytest.raises(RuntimeError, match="This class does not support from_function method"): + SampleClass.from_function(sample_function) + + def test_wrong_type(self): + def func(params: dict) -> dict: + """For test purpose""" + + with pytest.raises(TypeError, match="Value for info should be "): + ClassForTestFromFunc.from_function(func, info=1, name="sample", alpha=1.0) + + def test_generate_class_from_function(self): + def sample_function(params: dict) -> dict: + params["a"] = 1 + return params + + new_cls = ClassForTestFromFunc.from_function(sample_function, name="sample1", info="sample2", alpha=2.0) + assert issubclass(new_cls, ClassForTestFromFunc) + assert new_cls.get_name() == "sample1" + assert new_cls.get_info() == "sample2" + assert new_cls.get_alpha() == 2.0 + assert new_cls.calculate(params={"b": 2}, scalar=1) == {"b": 2, "a": 1} + assert new_cls.calculate(params={"b": 2}) == {"b": 2, "a": 1} + with pytest.raises(ValueError, match="Parameter params is defined twice"): + new_cls.calculate({"a": 1}, params={}) + assert new_cls.__argument_class__ == dict + assert new_cls.__name__ == "SampleFunction" + assert new_cls(params={"b": 2}, scalar=1) == {"b": 2, "a": 1} # pylint: disable=not-callable + assert new_cls({"b": 2}) == {"b": 2, "a": 1} # pylint: disable=not-callable + + def test_generate_class_from_function_without_params(self): + @ClassForTestFromFunc.from_function(info="sample2", alpha=2.0) + def sample_function(scalar: int) -> dict: + return {"a": scalar} + + assert issubclass(sample_function, ClassForTestFromFunc) + assert sample_function.get_name() == "Sample Function" + assert sample_function.__name__ == "SampleFunction" + assert sample_function.calculate(scalar=1, params={"b": 2}) == {"a": 1} + assert sample_function.__argument_class__.__name__ == "BaseModel" + + def test_additional_function_parameter_error(self): + def sample_function(params: dict, beta: float) -> dict: + """for test purpose only""" + + with pytest.raises(ValueError, match="Parameter beta is not defined"): + ClassForTestFromFunc.from_function(sample_function, info="sample", alpha=1.0) + + def test_positional_only_argument(self): + def sample_function(params: dict, /) -> dict: + """for test purpose only""" + + with pytest.raises(ValueError, match="Function .*sample_function.* should not have positional only parameters"): + ClassForTestFromFunc.from_function(sample_function, info="sample", alpha=1.0) + + def test_fom_function_as_decorator(self): + class SampleClass(ABC, metaclass=AlgorithmDescribeBaseMeta): + @classmethod + @abstractmethod + def get_sample(cls) -> str: + raise NotImplementedError + + @classmethod + def get_fields(cls): + raise NotImplementedError + + class SampleClass2(SampleClass, method_from_fun="calculate"): + @classmethod + @abstractmethod + def calculate(cls, a: int, arguments: dict) -> str: + raise NotImplementedError + + @SampleClass2.from_function(sample="aaa") + def calc(a: int) -> str: + return f"aaa {a}" + + assert calc.calculate(a=1, arguments={}) == "aaa 1" + + def test_class_without_user_provided_attributes(self): + class SampleClass(AlgorithmDescribeBase, method_from_fun="calculate", additional_parameters="parameters"): + @classmethod + @abstractmethod + def calculate(cls, a: int, b: int) -> int: + raise NotImplementedError + + @SampleClass.from_function() + def calc(a: int, b: int) -> int: + return a + b + + assert calc.calculate(a=1, b=2) == 3 + + def test_functions_with_kwargs(self): + @ClassForTestFromFunc.from_function(info="sample2", alpha=2.0) + def sample_function(params: dict, **kwargs) -> dict: + params["scalar"] = kwargs["scalar"] + return params + + assert sample_function.calculate(params={"b": 2}, scalar=1) == {"b": 2, "scalar": 1} + + +def test_roi_extraction_profile(): + ROIExtractionProfile(name="aaa", algorithm="aaa", values={}) + with pytest.warns(FutureWarning): + ROIExtractionProfile("aaa", "aaa", {}) + class TestROIExtractionProfile: def test_roi_extraction_profile(self): @@ -370,3 +514,22 @@ def test_pretty_print(self): values=LowerThresholdAlgorithm.get_default_values(), ) assert prof2.pretty_print(AnalysisAlgorithmSelection).count("\n") == 7 + + +class ClassForTestFromFuncBase(AlgorithmDescribeBase): + @classmethod + @abstractmethod + def get_alpha(cls) -> float: + raise NotImplementedError + + +class ClassForTestFromFunc(ClassForTestFromFuncBase, method_from_fun="calculate"): + @classmethod + @abstractmethod + def get_info(cls) -> str: + raise NotImplementedError + + @classmethod + @abstractmethod + def calculate(cls, params: BaseModel, scalar: float) -> dict: + raise NotImplementedError diff --git a/package/tests/test_PartSegCore/test_utils.py b/package/tests/test_PartSegCore/test_utils.py index efb91a3be..50c3d577d 100644 --- a/package/tests/test_PartSegCore/test_utils.py +++ b/package/tests/test_PartSegCore/test_utils.py @@ -13,6 +13,7 @@ ProfileDict, get_callback, iterate_names, + kwargs_to_model, recursive_update_dict, ) @@ -361,3 +362,34 @@ class SampleModel(BaseModel): assert ob["c"] == "3" with pytest.raises(KeyError): ob["d"] # pylint: disable=pointless-statement + + +def test_kwargs_to_model(): + class SampleModel(BaseModel): + a: int + b: float + c: str + + class SampleModel2(BaseModel): + d: int + + @kwargs_to_model(old_args_order=["a", "b"]) + def func(par1: SampleModel, par2: SampleModel2): + return f"{par1.a} {par1.b} {par1.c} {par2.d}" + + assert func(SampleModel(a=1, b=2.0, c="a"), SampleModel2(d=3)) == "1 2.0 a 3" + + with pytest.warns(FutureWarning, match="called using old signature"): + assert func(a=1, b=2.0, c="a", d=4) == "1 2.0 a 4" + + with pytest.warns(FutureWarning, match="called using old signature"): + assert func(1, b=2.0, c="a", d=5) == "1 2.0 a 5" + + with pytest.warns(FutureWarning, match="called using old signature"): + assert func(1, 2.0, c="a", d=6) == "1 2.0 a 6" + + with pytest.raises(TypeError, match="Too many positional arguments"): + func(1, 2.0, "a", d=4) + + with pytest.raises(TypeError, match="Unexpected keyword argument e"): + assert func(a=1, b=2.0, c="a", d=4, e=1) == "1 2.0 a 4" diff --git a/setup.cfg b/setup.cfg index 21a86175a..a819c2909 100644 --- a/setup.cfg +++ b/setup.cfg @@ -177,6 +177,7 @@ exclude_lines = raise NotImplementedError if typing.TYPE_CHECKING raise RuntimeError() + \.\.\. [flake8] max-line-length = 120