From dd0bee53651885ace6140eb86744b94fb3c503e0 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Sun, 1 Jan 2023 19:09:37 +0100 Subject: [PATCH 01/25] initial implementation --- .../PartSegCore/algorithm_describe_base.py | 42 +++++++++++++++++++ .../segmentation/noise_filtering.py | 5 ++- .../test_algorithm_describe_base.py | 35 ++++++++++++++++ 3 files changed, 80 insertions(+), 2 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index d0a4fa957..f32045422 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -119,6 +119,7 @@ def _partial_abstractmethod(funcobj): class AlgorithmDescribeBaseMeta(ABCMeta): def __new__(cls, name, bases, attrs, **kwargs): + calculation_method = kwargs.pop("calculation_method", None) cls2 = super().__new__(cls, name, bases, attrs, **kwargs) if ( not inspect.isabstract(cls2) @@ -127,8 +128,49 @@ def __new__(cls, name, bases, attrs, **kwargs): ): raise RuntimeError("class need to have __argument_class__ set or get_fields functions defined") cls2.__new_style__ = getattr(cls2.get_fields, "__is_partial_abstractmethod__", False) + cls2.__abstract_getters__ = [] + cls2.__calculation_method__ = calculation_method + if hasattr(cls2, "__abstractmethods__") and cls2.__abstractmethods__: + # get all abstract methods that starts with `get_` + abstract_names = [ + method[4:] + for method in cls2.__abstractmethods__ + if method.startswith("get_") and not method.endswith("_fields") + ] + cls2.__abstract_getters__ = abstract_names + return cls2 + def from_function(self, func=None, **kwargs): + """generate new class from function""" + + # Test if all abstract methods values are provided in kwargs + + if set(self.__abstract_getters__) != set(kwargs.keys()): + raise ValueError("Not all abstract methods values are provided") + + def _getter_by_name(name): + def _func(): + return kwargs[name] + + return _func + + def _class_generator(func): + class _Class(self): + def __call__(self, *args, **kwargs): + return func(*args, **kwargs) + + for name in self.__abstract_getters__: + setattr(_Class, f"get_{name}", _getter_by_name(name)) + + setattr(_Class, self.__calculation_method__, func) + + return _Class + + if func is None: + return _class_generator + return _class_generator(func) + class AlgorithmDescribeBase(ABC, metaclass=AlgorithmDescribeBaseMeta): """ diff --git a/package/PartSegCore/segmentation/noise_filtering.py b/package/PartSegCore/segmentation/noise_filtering.py index a060f9941..618f10d6f 100644 --- a/package/PartSegCore/segmentation/noise_filtering.py +++ b/package/PartSegCore/segmentation/noise_filtering.py @@ -1,6 +1,6 @@ import typing import warnings -from abc import ABC +from abc import ABC, abstractmethod from enum import Enum import numpy as np @@ -22,10 +22,11 @@ def __str__(self): return self.name.replace("_", " ") -class NoiseFilteringBase(AlgorithmDescribeBase, ABC): +class NoiseFilteringBase(AlgorithmDescribeBase, ABC, calculation_method="noise_filter"): """Base class for noise filtering operations""" @classmethod + @abstractmethod def noise_filter(cls, channel: np.ndarray, spacing: typing.Iterable[float], arguments: dict) -> np.ndarray: """ This function need be overloaded in implementation diff --git a/package/tests/test_PartSegCore/test_algorithm_describe_base.py b/package/tests/test_PartSegCore/test_algorithm_describe_base.py index c4e71240a..4f5610437 100644 --- a/package/tests/test_PartSegCore/test_algorithm_describe_base.py +++ b/package/tests/test_PartSegCore/test_algorithm_describe_base.py @@ -1,5 +1,6 @@ # pylint: disable=R0201 import typing +from abc import abstractmethod from enum import Enum import pytest @@ -341,8 +342,42 @@ def get_fields(cls) -> typing.List[typing.Union[AlgorithmProperty, str]]: with pytest.warns(FutureWarning, match=r"Class has __argument_class__ defined"): assert SampleSubAlgorithm.get_default_values() == {"name": 1, "name2": 3.0} + def test_generate_class_from_function_lack_of_methods(self): + def sample_function(params: dict) -> dict: + return params + + with pytest.raises(ValueError, match="Not all abstract methods values are provided"): + ClassForTestFromFunc.from_function(sample_function) + + with pytest.raises(ValueError, match="Not all abstract methods values are provided"): + ClassForTestFromFunc.from_function(sample_function, name="sample") + + with pytest.raises(ValueError, match="Not all abstract methods values are provided"): + ClassForTestFromFunc.from_function(sample_function, info="sample") + + def test_generate_class_from_function(self): + def sample_function(params: dict) -> dict: + return params + + new_cls = ClassForTestFromFunc.from_function(sample_function, name="sample1", info="sample2") + assert issubclass(new_cls, ClassForTestFromFunc) + assert new_cls.get_name() == "sample1" + assert new_cls.get_info() == "sample2" + def test_roi_extraction_profile(): ROIExtractionProfile(name="aaa", algorithm="aaa", values={}) with pytest.warns(FutureWarning): ROIExtractionProfile("aaa", "aaa", {}) + + +class ClassForTestFromFunc(AlgorithmDescribeBase, calculation_method="calculate"): + @classmethod + @abstractmethod + def get_info(cls) -> bool: + raise NotImplementedError() + + @classmethod + @abstractmethod + def calculate(cls, params: dict) -> dict: + raise NotImplementedError() From 423288f1fee4a5a47d2accb9541858d79d29b99b Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Mon, 2 Jan 2023 12:14:40 +0100 Subject: [PATCH 02/25] add checks for proper methods and types --- .../PartSegCore/algorithm_describe_base.py | 52 ++++++++++++++----- .../segmentation/algorithm_base.py | 6 +-- .../test_algorithm_describe_base.py | 43 +++++++++++++-- 3 files changed, 81 insertions(+), 20 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index f32045422..160932cae 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -128,16 +128,19 @@ def __new__(cls, name, bases, attrs, **kwargs): ): raise RuntimeError("class need to have __argument_class__ set or get_fields functions defined") cls2.__new_style__ = getattr(cls2.get_fields, "__is_partial_abstractmethod__", False) - cls2.__abstract_getters__ = [] + cls2.__abstract_getters__ = {} cls2.__calculation_method__ = calculation_method + cls2.__support_from_function__ = True if hasattr(cls2, "__abstractmethods__") and cls2.__abstractmethods__: # get all abstract methods that starts with `get_` - abstract_names = [ - method[4:] - for method in cls2.__abstractmethods__ - if method.startswith("get_") and not method.endswith("_fields") - ] - cls2.__abstract_getters__ = abstract_names + for method_name in cls2.__abstractmethods__: + if method_name.startswith("get_"): + if "return" not in getattr(cls2, method_name).__annotations__: + raise RuntimeError(f"Method {method_name} should have return annotation") + + cls2.__abstract_getters__[method_name[4:]] = getattr(cls2, method_name).__annotations__["return"] + elif method_name != calculation_method: + cls2.__support_from_function__ = False return cls2 @@ -146,8 +149,31 @@ def from_function(self, func=None, **kwargs): # Test if all abstract methods values are provided in kwargs - if set(self.__abstract_getters__) != set(kwargs.keys()): - raise ValueError("Not all abstract methods values are provided") + if not self.__support_from_function__: + raise RuntimeError("This class does not support from_function method") + + abstract_getters_set = set(self.__abstract_getters__) + kwargs_set = set(kwargs.keys()) + + if abstract_getters_set != kwargs_set: + # Provide a nice error message with information about what is missing and is obsolete + missing_text = ", ".join(sorted(abstract_getters_set - kwargs_set)) + if missing_text: + missing_text = f"Not all abstract methods are provided, missing: {missing_text}." + else: + missing_text = "" + extra_text = ", ".join(sorted(kwargs_set - abstract_getters_set)) + if extra_text: + extra_text = f"There are extra attributes in call: {extra_text}." + else: + extra_text = "" + + raise ValueError(f"{missing_text} {extra_text}") + + # check if all values have correct type + for key, value in kwargs.items(): + if not isinstance(value, self.__abstract_getters__[key]): + raise TypeError(f"Value for {key} should be {self.__abstract_getters__[key]}") def _getter_by_name(name): def _func(): @@ -155,15 +181,15 @@ def _func(): return _func - def _class_generator(func): + def _class_generator(func_): class _Class(self): - def __call__(self, *args, **kwargs): - return func(*args, **kwargs) + def __call__(self, *args, **kwargs_): + return func_(*args, **kwargs_) for name in self.__abstract_getters__: setattr(_Class, f"get_{name}", _getter_by_name(name)) - setattr(_Class, self.__calculation_method__, func) + setattr(_Class, self.__calculation_method__, func_) return _Class diff --git a/package/PartSegCore/segmentation/algorithm_base.py b/package/PartSegCore/segmentation/algorithm_base.py index 677d5bfe6..dbf372dc2 100644 --- a/package/PartSegCore/segmentation/algorithm_base.py +++ b/package/PartSegCore/segmentation/algorithm_base.py @@ -175,12 +175,12 @@ def mask(self, val: Optional[np.ndarray]): @classmethod @abstractmethod - def support_time(cls): + def support_time(cls) -> bool: raise NotImplementedError() @classmethod @abstractmethod - def support_z(cls): + def support_z(cls) -> bool: raise NotImplementedError() def set_mask(self, mask): @@ -204,7 +204,7 @@ def calculation_run(self, report_fun: Callable[[str, int], None]) -> ROIExtracti raise NotImplementedError() @abstractmethod - def get_info_text(self): + def get_info_text(self) -> str: raise NotImplementedError() def get_channel(self, channel_idx): diff --git a/package/tests/test_PartSegCore/test_algorithm_describe_base.py b/package/tests/test_PartSegCore/test_algorithm_describe_base.py index 4f5610437..adec91fdb 100644 --- a/package/tests/test_PartSegCore/test_algorithm_describe_base.py +++ b/package/tests/test_PartSegCore/test_algorithm_describe_base.py @@ -346,15 +346,50 @@ def test_generate_class_from_function_lack_of_methods(self): def sample_function(params: dict) -> dict: return params - with pytest.raises(ValueError, match="Not all abstract methods values are provided"): + with pytest.raises(ValueError, match="missing: info, name"): ClassForTestFromFunc.from_function(sample_function) - with pytest.raises(ValueError, match="Not all abstract methods values are provided"): + with pytest.raises(ValueError, match="missing: info"): ClassForTestFromFunc.from_function(sample_function, name="sample") - with pytest.raises(ValueError, match="Not all abstract methods values are provided"): + with pytest.raises(ValueError, match="missing: name"): ClassForTestFromFunc.from_function(sample_function, info="sample") + with pytest.raises(ValueError, match="missing: info, name.*call: info2"): + ClassForTestFromFunc.from_function(sample_function, info2="sample") + + with pytest.raises(ValueError, match="call: additions"): + ClassForTestFromFunc.from_function(sample_function, info="sample", name="sample2", additions="sample3") + + def test_missing_return_annotation(self): + with pytest.raises(RuntimeError, match="Method get_sample should have return annotation"): + + class SampleClass(AlgorithmDescribeBase): + @classmethod + @abstractmethod + def get_sample(cls): + raise NotImplementedError() + + def test_not_supported_from_function(self): + def sample_function(params: dict) -> dict: + return params + + class SampleClass(AlgorithmDescribeBase): + @classmethod + @abstractmethod + def sample(cls) -> dict: + raise NotImplementedError() + + with pytest.raises(RuntimeError, match="This class does not support from_function method"): + SampleClass.from_function(sample_function) + + def test_wrong_type(self): + def func(params: dict) -> dict: + return params + + with pytest.raises(TypeError, match="Value for info should be "): + ClassForTestFromFunc.from_function(func, info=1, name="sample") + def test_generate_class_from_function(self): def sample_function(params: dict) -> dict: return params @@ -374,7 +409,7 @@ def test_roi_extraction_profile(): class ClassForTestFromFunc(AlgorithmDescribeBase, calculation_method="calculate"): @classmethod @abstractmethod - def get_info(cls) -> bool: + def get_info(cls) -> str: raise NotImplementedError() @classmethod From 0799fef2f4fe3cc5991ad6698ff8b895e1502c30 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Mon, 2 Jan 2023 12:34:04 +0100 Subject: [PATCH 03/25] add name detection base of function name --- .../PartSegCore/algorithm_describe_base.py | 6 +++++ .../test_algorithm_describe_base.py | 24 ++++++++++++------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index 160932cae..bcdf689aa 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -220,6 +220,12 @@ def get_doc_from_fields(cls): resp += "}\n" return resp + @classmethod + def from_function(cls, func=None, **kwargs): + if "name" not in kwargs: + kwargs["name"] = func.__name__.replace("_", " ").capitalize() + return AlgorithmDescribeBaseMeta.from_function(cls, func, **kwargs) + @classmethod @abstractmethod def get_name(cls) -> str: diff --git a/package/tests/test_PartSegCore/test_algorithm_describe_base.py b/package/tests/test_PartSegCore/test_algorithm_describe_base.py index adec91fdb..fc866ceba 100644 --- a/package/tests/test_PartSegCore/test_algorithm_describe_base.py +++ b/package/tests/test_PartSegCore/test_algorithm_describe_base.py @@ -346,20 +346,20 @@ def test_generate_class_from_function_lack_of_methods(self): def sample_function(params: dict) -> dict: return params - with pytest.raises(ValueError, match="missing: info, name"): + with pytest.raises(ValueError, match="missing: alpha, info"): ClassForTestFromFunc.from_function(sample_function) with pytest.raises(ValueError, match="missing: info"): - ClassForTestFromFunc.from_function(sample_function, name="sample") + ClassForTestFromFunc.from_function(sample_function, alpha=1.0) - with pytest.raises(ValueError, match="missing: name"): + with pytest.raises(ValueError, match="missing: alpha"): ClassForTestFromFunc.from_function(sample_function, info="sample") - with pytest.raises(ValueError, match="missing: info, name.*call: info2"): + with pytest.raises(ValueError, match="missing: alpha, info.*call: info2"): ClassForTestFromFunc.from_function(sample_function, info2="sample") with pytest.raises(ValueError, match="call: additions"): - ClassForTestFromFunc.from_function(sample_function, info="sample", name="sample2", additions="sample3") + ClassForTestFromFunc.from_function(sample_function, info="sample", alpha=1.0, additions="sample3") def test_missing_return_annotation(self): with pytest.raises(RuntimeError, match="Method get_sample should have return annotation"): @@ -388,16 +388,17 @@ def func(params: dict) -> dict: return params with pytest.raises(TypeError, match="Value for info should be "): - ClassForTestFromFunc.from_function(func, info=1, name="sample") + ClassForTestFromFunc.from_function(func, info=1, name="sample", alpha=1.0) def test_generate_class_from_function(self): def sample_function(params: dict) -> dict: return params - new_cls = ClassForTestFromFunc.from_function(sample_function, name="sample1", info="sample2") + new_cls = ClassForTestFromFunc.from_function(sample_function, name="sample1", info="sample2", alpha=2.0) assert issubclass(new_cls, ClassForTestFromFunc) assert new_cls.get_name() == "sample1" assert new_cls.get_info() == "sample2" + assert new_cls.get_alpha() == 2.0 def test_roi_extraction_profile(): @@ -406,7 +407,14 @@ def test_roi_extraction_profile(): ROIExtractionProfile("aaa", "aaa", {}) -class ClassForTestFromFunc(AlgorithmDescribeBase, calculation_method="calculate"): +class ClassForTestFromFuncBase(AlgorithmDescribeBase): + @classmethod + @abstractmethod + def get_alpha(cls) -> float: + raise NotImplementedError() + + +class ClassForTestFromFunc(ClassForTestFromFuncBase, calculation_method="calculate"): @classmethod @abstractmethod def get_info(cls) -> str: From 350a363b6c0376c2f7fa130c5f09927f8a5872d3 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Mon, 2 Jan 2023 21:08:07 +0100 Subject: [PATCH 04/25] Iprove validation of function parameter --- .../PartSegCore/algorithm_describe_base.py | 56 ++++++++++++++++--- .../test_algorithm_describe_base.py | 16 +++++- 2 files changed, 62 insertions(+), 10 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index bcdf689aa..d8e1a573d 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -144,14 +144,7 @@ def __new__(cls, name, bases, attrs, **kwargs): return cls2 - def from_function(self, func=None, **kwargs): - """generate new class from function""" - - # Test if all abstract methods values are provided in kwargs - - if not self.__support_from_function__: - raise RuntimeError("This class does not support from_function method") - + def _validate_if_all_abstract_getters_are_defined(self, kwargs): abstract_getters_set = set(self.__abstract_getters__) kwargs_set = set(kwargs.keys()) @@ -170,6 +163,41 @@ def from_function(self, func=None, **kwargs): raise ValueError(f"{missing_text} {extra_text}") + def _validate_function_parameters(self, func) -> set: + """ + Validate if all parameters without default values are defined in self.__calculation_method__ + + :param func: function to validate + :return: set of parameters that should be dropped + """ + signature = inspect.signature(func) + base_method_signature = inspect.signature(getattr(self, self.__calculation_method__)) + + for parameters in signature.parameters.values(): + if parameters.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.POSITIONAL_ONLY}: + raise ValueError(f"Function {func} should not have positional only parameters") + if ( + parameters.default is inspect.Parameter.empty + and parameters.name not in base_method_signature.parameters + ): + raise ValueError(f"Parameter {parameters.name} is not defined in {self.__calculation_method__} method") + + return { + parameters.name + for parameters in base_method_signature.parameters.values() + if parameters.name not in signature.parameters + } + + def from_function(self, func=None, **kwargs): + """generate new class from function""" + + # Test if all abstract methods values are provided in kwargs + + if not self.__support_from_function__: + raise RuntimeError("This class does not support from_function method") + + self._validate_if_all_abstract_getters_are_defined(kwargs) + # check if all values have correct type for key, value in kwargs.items(): if not isinstance(value, self.__abstract_getters__[key]): @@ -182,14 +210,24 @@ def _func(): return _func def _class_generator(func_): + + drop_attr = self._validate_function_parameters(func_) + + @wraps(func_) + def _calculate_method(self, **kwargs_): + for name in drop_attr: + kwargs_.pop(name, None) + return func_(self, **kwargs_) + class _Class(self): + @wraps(func_) def __call__(self, *args, **kwargs_): return func_(*args, **kwargs_) for name in self.__abstract_getters__: setattr(_Class, f"get_{name}", _getter_by_name(name)) - setattr(_Class, self.__calculation_method__, func_) + setattr(_Class, self.__calculation_method__, _calculate_method) return _Class diff --git a/package/tests/test_PartSegCore/test_algorithm_describe_base.py b/package/tests/test_PartSegCore/test_algorithm_describe_base.py index fc866ceba..65c6bb6fb 100644 --- a/package/tests/test_PartSegCore/test_algorithm_describe_base.py +++ b/package/tests/test_PartSegCore/test_algorithm_describe_base.py @@ -400,6 +400,20 @@ def sample_function(params: dict) -> dict: assert new_cls.get_info() == "sample2" assert new_cls.get_alpha() == 2.0 + def test_additional_function_parameter_error(self): + def sample_function(params: dict, beta: float) -> dict: + """for test purpose only""" + + with pytest.raises(ValueError, match="Parameter beta is not defined"): + ClassForTestFromFunc.from_function(sample_function, info="sample", alpha=1.0) + + def test_positional_only_argument(self): + def sample_function(params: dict, /) -> dict: + """for test purpose only""" + + with pytest.raises(ValueError, match="Function .*sample_function.* should not have positional only parameters"): + ClassForTestFromFunc.from_function(sample_function, info="sample", alpha=1.0) + def test_roi_extraction_profile(): ROIExtractionProfile(name="aaa", algorithm="aaa", values={}) @@ -422,5 +436,5 @@ def get_info(cls) -> str: @classmethod @abstractmethod - def calculate(cls, params: dict) -> dict: + def calculate(cls, params: BaseModel, scalar: float) -> dict: raise NotImplementedError() From 526d8f1d355958f544b879a322d47b28ae1a768e Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Tue, 3 Jan 2023 12:00:48 +0100 Subject: [PATCH 05/25] add __argument_class__ creation --- .../PartSegCore/algorithm_describe_base.py | 43 +++++++++++++------ .../test_algorithm_describe_base.py | 3 ++ 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index d8e1a573d..7f0945acf 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -120,6 +120,7 @@ def _partial_abstractmethod(funcobj): class AlgorithmDescribeBaseMeta(ABCMeta): def __new__(cls, name, bases, attrs, **kwargs): calculation_method = kwargs.pop("calculation_method", None) + calculation_method_params_name = kwargs.pop("calculation_method__params_name", None) cls2 = super().__new__(cls, name, bases, attrs, **kwargs) if ( not inspect.isabstract(cls2) @@ -130,7 +131,13 @@ def __new__(cls, name, bases, attrs, **kwargs): cls2.__new_style__ = getattr(cls2.get_fields, "__is_partial_abstractmethod__", False) cls2.__abstract_getters__ = {} cls2.__calculation_method__ = calculation_method - cls2.__support_from_function__ = True + cls2.__calculation_method_params_name__ = calculation_method_params_name + if cls2.__calculation_method_params_name__ is None: + cls2.__calculation_method_params_name__ = cls._get_calculation_method_params_name(cls2) + + cls2.__support_from_function__ = ( + cls2.__calculation_method__ is not None and cls2.__calculation_method_params_name__ is not None + ) if hasattr(cls2, "__abstractmethods__") and cls2.__abstractmethods__: # get all abstract methods that starts with `get_` for method_name in cls2.__abstractmethods__: @@ -144,6 +151,17 @@ def __new__(cls, name, bases, attrs, **kwargs): return cls2 + @staticmethod + def _get_calculation_method_params_name(cls2) -> typing.Optional[str]: + if cls2.__calculation_method__ is None: + return None + signature = inspect.signature(getattr(cls2, cls2.__calculation_method__)) + if "arguments" in signature.parameters: + return "arguments" + if "params" in signature.parameters: + return "params" + raise RuntimeError(f"Cannot determine arguments parameter name in {cls2.__calculation_method__}") + def _validate_if_all_abstract_getters_are_defined(self, kwargs): abstract_getters_set = set(self.__abstract_getters__) kwargs_set = set(kwargs.keys()) @@ -188,6 +206,12 @@ def _validate_function_parameters(self, func) -> set: if parameters.name not in signature.parameters } + def _get_argument_class_from_signature(self, func): + signature = inspect.signature(func) + if self.__calculation_method_params_name__ not in signature.parameters: + return BaseModel + return signature.parameters[self.__calculation_method_params_name__].annotation + def from_function(self, func=None, **kwargs): """generate new class from function""" @@ -214,22 +238,17 @@ def _class_generator(func_): drop_attr = self._validate_function_parameters(func_) @wraps(func_) - def _calculate_method(self, **kwargs_): + def _calculate_method(**kwargs_): for name in drop_attr: kwargs_.pop(name, None) - return func_(self, **kwargs_) - - class _Class(self): - @wraps(func_) - def __call__(self, *args, **kwargs_): - return func_(*args, **kwargs_) + return func_(**kwargs_) - for name in self.__abstract_getters__: - setattr(_Class, f"get_{name}", _getter_by_name(name)) + class_dkt = {f"get_{name}": _getter_by_name(name) for name in self.__abstract_getters__} - setattr(_Class, self.__calculation_method__, _calculate_method) + class_dkt[self.__calculation_method__] = _calculate_method + class_dkt["__argument_class__"] = self._get_argument_class_from_signature(func_) - return _Class + return type("aaa", (self,), class_dkt) if func is None: return _class_generator diff --git a/package/tests/test_PartSegCore/test_algorithm_describe_base.py b/package/tests/test_PartSegCore/test_algorithm_describe_base.py index 65c6bb6fb..bd7d7353d 100644 --- a/package/tests/test_PartSegCore/test_algorithm_describe_base.py +++ b/package/tests/test_PartSegCore/test_algorithm_describe_base.py @@ -392,6 +392,7 @@ def func(params: dict) -> dict: def test_generate_class_from_function(self): def sample_function(params: dict) -> dict: + params["a"] = 1 return params new_cls = ClassForTestFromFunc.from_function(sample_function, name="sample1", info="sample2", alpha=2.0) @@ -399,6 +400,8 @@ def sample_function(params: dict) -> dict: assert new_cls.get_name() == "sample1" assert new_cls.get_info() == "sample2" assert new_cls.get_alpha() == 2.0 + assert new_cls.calculate(params={"b": 2}, scalar=1) == {"b": 2, "a": 1} + assert new_cls.__argument_class__ == dict def test_additional_function_parameter_error(self): def sample_function(params: dict, beta: float) -> dict: From 42f7739e14723a50e1b62bf3ba8f130ea7c3a118 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Tue, 3 Jan 2023 13:37:44 +0100 Subject: [PATCH 06/25] add calling function directly. --- .../PartSegCore/algorithm_describe_base.py | 44 ++++++++++++++++--- .../test_algorithm_describe_base.py | 2 + 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index 7f0945acf..a963cd232 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -129,9 +129,12 @@ def __new__(cls, name, bases, attrs, **kwargs): ): raise RuntimeError("class need to have __argument_class__ set or get_fields functions defined") cls2.__new_style__ = getattr(cls2.get_fields, "__is_partial_abstractmethod__", False) + cls2.__from_function__ = getattr(cls2, "__from_function__", False) cls2.__abstract_getters__ = {} - cls2.__calculation_method__ = calculation_method - cls2.__calculation_method_params_name__ = calculation_method_params_name + cls2.__calculation_method__ = calculation_method or getattr(cls2, "__calculation_method__", None) + cls2.__calculation_method_params_name__ = calculation_method_params_name or getattr( + cls2, "__calculation_method_params_name__", None + ) if cls2.__calculation_method_params_name__ is None: cls2.__calculation_method_params_name__ = cls._get_calculation_method_params_name(cls2) @@ -247,8 +250,9 @@ def _calculate_method(**kwargs_): class_dkt[self.__calculation_method__] = _calculate_method class_dkt["__argument_class__"] = self._get_argument_class_from_signature(func_) + class_dkt["__from_function__"] = True - return type("aaa", (self,), class_dkt) + return type(func_.__name__.replace("_", " ").title().replace(" ", ""), (self,), class_dkt) if func is None: return _class_generator @@ -265,6 +269,11 @@ class AlgorithmDescribeBase(ABC, metaclass=AlgorithmDescribeBaseMeta): __argument_class__: typing.Optional[typing.Type[PydanticBaseModel]] = None __new_style__: bool + def __new__(cls, *args, **kwargs): + if cls.__from_function__: + return getattr(cls, cls.__calculation_method__)(*args, **kwargs) + return super().__new__(cls) + @classmethod def get_doc_from_fields(cls): resp = "{\n" @@ -278,10 +287,31 @@ def get_doc_from_fields(cls): return resp @classmethod - def from_function(cls, func=None, **kwargs): - if "name" not in kwargs: - kwargs["name"] = func.__name__.replace("_", " ").capitalize() - return AlgorithmDescribeBaseMeta.from_function(cls, func, **kwargs) + @typing.overload + def from_function(cls, func: typing.Callable[..., typing.Any], **kwargs) -> typing.Type["AlgorithmDescribeBase"]: + ... + + @classmethod + @typing.overload + def from_function( + cls, **kwargs + ) -> typing.Callable[[typing.Callable[..., typing.Any]], typing.Type["AlgorithmDescribeBase"]]: + ... + + @classmethod + def from_function( + cls, func=None, **kwargs + ) -> typing.Union[ + typing.Type["AlgorithmDescribeBase"], typing.Callable[[typing.Callable], typing.Type["AlgorithmDescribeBase"]] + ]: + def _from_function(func_) -> typing.Type["AlgorithmDescribeBase"]: + if "name" not in kwargs: + kwargs["name"] = func_.__name__.replace("_", " ").title() + return AlgorithmDescribeBaseMeta.from_function(cls, func_, **kwargs) + + if func is None: + return _from_function + return _from_function(func) @classmethod @abstractmethod diff --git a/package/tests/test_PartSegCore/test_algorithm_describe_base.py b/package/tests/test_PartSegCore/test_algorithm_describe_base.py index bd7d7353d..54df311de 100644 --- a/package/tests/test_PartSegCore/test_algorithm_describe_base.py +++ b/package/tests/test_PartSegCore/test_algorithm_describe_base.py @@ -402,6 +402,8 @@ def sample_function(params: dict) -> dict: assert new_cls.get_alpha() == 2.0 assert new_cls.calculate(params={"b": 2}, scalar=1) == {"b": 2, "a": 1} assert new_cls.__argument_class__ == dict + assert new_cls.__name__ == "SampleFunction" + assert new_cls(params={"b": 2}, scalar=1) == {"b": 2, "a": 1} def test_additional_function_parameter_error(self): def sample_function(params: dict, beta: float) -> dict: From 47b4d4ff43f35793f1be2eef1ccc3de801a09c08 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Tue, 3 Jan 2023 13:38:22 +0100 Subject: [PATCH 07/25] add sample usage of functional api --- package/PartSegCore/segmentation/threshold.py | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/package/PartSegCore/segmentation/threshold.py b/package/PartSegCore/segmentation/threshold.py index 1ac5c7e19..f4a7d0ff0 100644 --- a/package/PartSegCore/segmentation/threshold.py +++ b/package/PartSegCore/segmentation/threshold.py @@ -2,6 +2,7 @@ import warnings from abc import ABC +import mahotas import numpy as np import SimpleITK as sitk from nme import register_class, rename_key, update_argument @@ -35,7 +36,7 @@ class MultipleOtsuThresholdParams(BaseModel): bins: int = Field(128, title="Number of histogram bins", ge=8, le=2**16) -class BaseThreshold(AlgorithmDescribeBase, ABC): +class BaseThreshold(AlgorithmDescribeBase, ABC, calculation_method="calculate_mask"): @classmethod def calculate_mask( cls, @@ -274,6 +275,35 @@ def get_name(cls) -> str: return "Multiple Otsu" +class MahotasThreshold(BaseModel): + ignore_zeros: bool = False + + +@BaseThreshold.from_function() +def riddler_calvard(data: np.ndarray, arguments: MahotasThreshold): + """ + Riddler-Calvard thresholding algorithm from mahotas + + Parameters + ---------- + data : ndarray + Image data. + arguments : MahotasThreshold + method parameter + + Returns + ------- + mask : ndarray + Computed mask + References + ---------- + .. [1] C. A. Riddler, and G. S. Calvard, "Picture thresholding using an iterative selection method," + """ + + threshold = mahotas.rc(data, ignore_zeros=arguments.ignore_zeros) + return (data > threshold).astype(np.uint8), threshold + + class ThresholdSelection(AlgorithmSelection, class_methods=["calculate_mask"], suggested_base_class=BaseThreshold): pass @@ -292,6 +322,7 @@ class ThresholdSelection(AlgorithmSelection, class_methods=["calculate_mask"], s ThresholdSelection.register(MomentsThreshold) ThresholdSelection.register(MaximumEntropyThreshold) ThresholdSelection.register(MultipleOtsuThreshold) +ThresholdSelection.register(riddler_calvard) class DoubleThresholdParams(BaseModel): From 2f8c4708518f5750090f3867f4f86fbf1746ac41 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Tue, 3 Jan 2023 14:29:00 +0100 Subject: [PATCH 08/25] fix tests --- .../PartSegCore/algorithm_describe_base.py | 12 +++++++++- package/PartSegCore/segmentation/threshold.py | 24 ++++++++++++++++--- .../segmentation/test_threshold.py | 10 ++++---- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index a963cd232..4ce2bdc83 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -215,6 +215,10 @@ def _get_argument_class_from_signature(self, func): return BaseModel return signature.parameters[self.__calculation_method_params_name__].annotation + def _get_parameters_from_signature(self): + signature = inspect.signature(getattr(self, self.__calculation_method__)) + return [parameters.name for parameters in signature.parameters.values()] + def from_function(self, func=None, **kwargs): """generate new class from function""" @@ -239,9 +243,15 @@ def _func(): def _class_generator(func_): drop_attr = self._validate_function_parameters(func_) + parameters_order = self._get_parameters_from_signature() @wraps(func_) - def _calculate_method(**kwargs_): + def _calculate_method(*args, **kwargs_): + for attr, name in zip(args, parameters_order): + if name in kwargs_: + raise ValueError(f"Parameter {name} is defined twice") + kwargs_[name] = attr + for name in drop_attr: kwargs_.pop(name, None) return func_(**kwargs_) diff --git a/package/PartSegCore/segmentation/threshold.py b/package/PartSegCore/segmentation/threshold.py index f4a7d0ff0..e717194b0 100644 --- a/package/PartSegCore/segmentation/threshold.py +++ b/package/PartSegCore/segmentation/threshold.py @@ -280,7 +280,9 @@ class MahotasThreshold(BaseModel): @BaseThreshold.from_function() -def riddler_calvard(data: np.ndarray, arguments: MahotasThreshold): +def riddler_calvard( + data: np.ndarray, mask: np.ndarray, arguments: MahotasThreshold, operator: typing.Callable[[object, object], bool] +): """ Riddler-Calvard thresholding algorithm from mahotas @@ -288,8 +290,12 @@ def riddler_calvard(data: np.ndarray, arguments: MahotasThreshold): ---------- data : ndarray Image data. + mask : ndarray + Mask data. arguments : MahotasThreshold method parameter + operator : callable + operator to use for thresholding Returns ------- @@ -300,8 +306,20 @@ def riddler_calvard(data: np.ndarray, arguments: MahotasThreshold): .. [1] C. A. Riddler, and G. S. Calvard, "Picture thresholding using an iterative selection method," """ - threshold = mahotas.rc(data, ignore_zeros=arguments.ignore_zeros) - return (data > threshold).astype(np.uint8), threshold + try: + threshold = mahotas.rc(data, ignore_zeros=arguments.ignore_zeros) + except TypeError as e: + if "This function only accepts integer types" in e.args[0]: + raise SegmentationLimitException(*e.args) from e + raise + if operator(1, 0): + res = (data >= threshold).astype(np.uint8) + else: + res = (data < threshold).astype(np.uint8) + + if mask is not None: + res = res * (mask > 0) + return res, threshold class ThresholdSelection(AlgorithmSelection, class_methods=["calculate_mask"], suggested_base_class=BaseThreshold): diff --git a/package/tests/test_PartSegCore/segmentation/test_threshold.py b/package/tests/test_PartSegCore/segmentation/test_threshold.py index 4bb23e454..e923de7f7 100644 --- a/package/tests/test_PartSegCore/segmentation/test_threshold.py +++ b/package/tests/test_PartSegCore/segmentation/test_threshold.py @@ -12,11 +12,11 @@ ThresholdSelection, ) -square = np.zeros((21, 21)) +square = np.zeros((21, 21), dtype=np.uint32) for i, v in [(2, 1000), (4, 10000), (6, 20000), (7, 21000), (8, 22000), (9, 23000)]: square[i:-i, i:-i] = v -cube = np.zeros((21, 21, 21)) +cube = np.zeros((21, 21, 21), dtype=np.uint32) for i, v in [(2, 1000), (4, 10000), (6, 20000), (7, 21000), (8, 22000), (9, 23000)]: cube[i:-i, i:-i, i:-i] = v @@ -37,7 +37,7 @@ def test_threshold(method: BaseThreshold, data, op, masking): pytest.xfail("IntermodesThreshold sigma problem") raise assert isinstance(data, np.ndarray) - assert isinstance(thr_info, (int, float)) + assert isinstance(thr_info, (int, float, np.integer)), type(thr_info) @pytest.mark.parametrize("method", DoubleThresholdSelection.__register__.values()) @@ -50,5 +50,5 @@ def test_double_threshold(method: BaseThreshold, data, op, masking): data, thr_info = method.calculate_mask(data=data, mask=mask, arguments=method.get_default_values(), operator=op) assert isinstance(data, np.ndarray) - assert isinstance(thr_info[0], (int, float)) - assert isinstance(thr_info[1], (int, float)) + assert isinstance(thr_info[0], (int, float, np.integer)) + assert isinstance(thr_info[1], (int, float, np.integer)) From a858a74d83ac6d6d310dc0ea51c73e55688f2909 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Tue, 3 Jan 2023 20:44:14 +0100 Subject: [PATCH 09/25] improve tests --- .../test_algorithm_describe_base.py | 23 +++++++++++++++---- setup.cfg | 1 + 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/package/tests/test_PartSegCore/test_algorithm_describe_base.py b/package/tests/test_PartSegCore/test_algorithm_describe_base.py index 54df311de..7a3079149 100644 --- a/package/tests/test_PartSegCore/test_algorithm_describe_base.py +++ b/package/tests/test_PartSegCore/test_algorithm_describe_base.py @@ -344,7 +344,7 @@ def get_fields(cls) -> typing.List[typing.Union[AlgorithmProperty, str]]: def test_generate_class_from_function_lack_of_methods(self): def sample_function(params: dict) -> dict: - return params + """For test purpose""" with pytest.raises(ValueError, match="missing: alpha, info"): ClassForTestFromFunc.from_function(sample_function) @@ -364,7 +364,7 @@ def sample_function(params: dict) -> dict: def test_missing_return_annotation(self): with pytest.raises(RuntimeError, match="Method get_sample should have return annotation"): - class SampleClass(AlgorithmDescribeBase): + class SampleClass(AlgorithmDescribeBase): # pylint: disable=unused-variable @classmethod @abstractmethod def get_sample(cls): @@ -372,7 +372,7 @@ def get_sample(cls): def test_not_supported_from_function(self): def sample_function(params: dict) -> dict: - return params + """For test purpose""" class SampleClass(AlgorithmDescribeBase): @classmethod @@ -385,7 +385,7 @@ def sample(cls) -> dict: def test_wrong_type(self): def func(params: dict) -> dict: - return params + """For test purpose""" with pytest.raises(TypeError, match="Value for info should be "): ClassForTestFromFunc.from_function(func, info=1, name="sample", alpha=1.0) @@ -401,9 +401,24 @@ def sample_function(params: dict) -> dict: assert new_cls.get_info() == "sample2" assert new_cls.get_alpha() == 2.0 assert new_cls.calculate(params={"b": 2}, scalar=1) == {"b": 2, "a": 1} + assert new_cls.calculate(params={"b": 2}) == {"b": 2, "a": 1} + with pytest.raises(ValueError, match="Parameter params is defined twice"): + new_cls.calculate({"a": 1}, params={}) assert new_cls.__argument_class__ == dict assert new_cls.__name__ == "SampleFunction" assert new_cls(params={"b": 2}, scalar=1) == {"b": 2, "a": 1} + assert new_cls({"b": 2}) == {"b": 2, "a": 1} + + def test_generate_class_from_function_without_params(self): + @ClassForTestFromFunc.from_function(info="sample2", alpha=2.0) + def sample_function(scalar: int) -> dict: + return {"a": scalar} + + assert issubclass(sample_function, ClassForTestFromFunc) + assert sample_function.get_name() == "Sample Function" + assert sample_function.__name__ == "SampleFunction" + assert sample_function.calculate(scalar=1, params={"b": 2}) == {"a": 1} + assert sample_function.__argument_class__.__name__ == "BaseModel" def test_additional_function_parameter_error(self): def sample_function(params: dict, beta: float) -> dict: diff --git a/setup.cfg b/setup.cfg index acd61d3e2..8f5e0a931 100644 --- a/setup.cfg +++ b/setup.cfg @@ -177,6 +177,7 @@ exclude_lines = raise NotImplementedError if typing.TYPE_CHECKING raise RuntimeError() + ... [flake8] max-line-length = 120 From 3f0f9dccf92b5f75f32511f99bdac6c9805e39dd Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Tue, 3 Jan 2023 21:24:10 +0100 Subject: [PATCH 10/25] simplify code --- package/PartSegCore/algorithm_describe_base.py | 18 ++++++++++++++---- package/PartSegCore/segmentation/threshold.py | 2 +- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index 4ce2bdc83..9fe0172b3 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -141,6 +141,17 @@ def __new__(cls, name, bases, attrs, **kwargs): cls2.__support_from_function__ = ( cls2.__calculation_method__ is not None and cls2.__calculation_method_params_name__ is not None ) + + cls2.__abstract_getters__, cls2.__support_from_function__ = cls._get_abstract_getters( + cls2, cls2.__support_from_function__, calculation_method + ) + return cls2 + + @staticmethod + def _get_abstract_getters( + cls2, support_from_function, calculation_method + ) -> typing.Tuple[typing.Dict[str, typing.Any], bool]: + abstract_getters = {} if hasattr(cls2, "__abstractmethods__") and cls2.__abstractmethods__: # get all abstract methods that starts with `get_` for method_name in cls2.__abstractmethods__: @@ -148,11 +159,10 @@ def __new__(cls, name, bases, attrs, **kwargs): if "return" not in getattr(cls2, method_name).__annotations__: raise RuntimeError(f"Method {method_name} should have return annotation") - cls2.__abstract_getters__[method_name[4:]] = getattr(cls2, method_name).__annotations__["return"] + abstract_getters[method_name[4:]] = getattr(cls2, method_name).__annotations__["return"] elif method_name != calculation_method: - cls2.__support_from_function__ = False - - return cls2 + support_from_function = False + return abstract_getters, support_from_function @staticmethod def _get_calculation_method_params_name(cls2) -> typing.Optional[str]: diff --git a/package/PartSegCore/segmentation/threshold.py b/package/PartSegCore/segmentation/threshold.py index e717194b0..09adbbeef 100644 --- a/package/PartSegCore/segmentation/threshold.py +++ b/package/PartSegCore/segmentation/threshold.py @@ -308,7 +308,7 @@ def riddler_calvard( try: threshold = mahotas.rc(data, ignore_zeros=arguments.ignore_zeros) - except TypeError as e: + except TypeError as e: # pragma: no cover if "This function only accepts integer types" in e.args[0]: raise SegmentationLimitException(*e.args) from e raise From 2525ce453747fe814ed1788024f4b1ce6af2b55e Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Tue, 3 Jan 2023 21:52:06 +0100 Subject: [PATCH 11/25] try upload coverage --- .github/workflows/base_test_workflow.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/base_test_workflow.yml b/.github/workflows/base_test_workflow.yml index 03af46c46..c2d5362fb 100644 --- a/.github/workflows/base_test_workflow.yml +++ b/.github/workflows/base_test_workflow.yml @@ -82,3 +82,10 @@ jobs: with: file: ./coverage.xml fail_ci_if_error: true + + - name: Upload coverage + uses: actions/upload-artifact@v3 + with: + name: coverage + path: ./coverage.xml + retention-days: 5 From 94bfd7974eaec4c9fe8e5a35ba6eeddaacb9cceb Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Tue, 3 Jan 2023 21:32:34 +0000 Subject: [PATCH 12/25] fix coverage --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 8f5e0a931..948625ad5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -177,7 +177,7 @@ exclude_lines = raise NotImplementedError if typing.TYPE_CHECKING raise RuntimeError() - ... + \.\.\. [flake8] max-line-length = 120 From 02e776302eda18894c2902c66025a3a0f5fc13e5 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Wed, 4 Jan 2023 11:50:10 +0100 Subject: [PATCH 13/25] a little cleanup --- package/PartSegCore/io_utils.py | 11 ++++++----- package/PartSegCore/segmentation/border_smoothing.py | 2 +- package/PartSegCore/segmentation/watershed.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/package/PartSegCore/io_utils.py b/package/PartSegCore/io_utils.py index 6cc5d5527..51f1c0397 100644 --- a/package/PartSegCore/io_utils.py +++ b/package/PartSegCore/io_utils.py @@ -56,7 +56,7 @@ def get_tarinfo(name, buffer: typing.Union[BytesIO, StringIO]): return tar_info -class SaveBase(AlgorithmDescribeBase, ABC): +class SaveBase(AlgorithmDescribeBase, ABC, calculation_method="save"): need_functions = [ "save", "get_short_name", @@ -100,10 +100,12 @@ def get_default_extension(cls): @classmethod def need_segmentation(cls): + """If method requires segmentation (ROI) to work, or could work with image only""" return True @classmethod def need_mask(cls): + """If `mask` is required for perform save""" return False @classmethod @@ -117,7 +119,7 @@ def get_extensions(cls) -> typing.List[str]: return [x[1:] for x in extensions] -class LoadBase(AlgorithmDescribeBase, ABC): +class LoadBase(AlgorithmDescribeBase, ABC, calculation_method="load"): need_functions = [ "load", "get_short_name", @@ -157,8 +159,7 @@ def get_name_with_suffix(cls): @classmethod def get_extensions(cls) -> typing.List[str]: - match = re.match(r".*\((.*)\)", cls.get_name()) - if match is None: + if match := re.match(r".*\((.*)\)", cls.get_name()) is None: raise ValueError(f"No extensions found in {cls.get_name()}") extensions = match[1].split(" ") if not all(x.startswith("*.") for x in extensions): @@ -201,7 +202,7 @@ def load_metadata_base(data: typing.Union[str, Path]): try: decoded_data = json.loads(str(data), object_hook=partseg_object_hook) except Exception: - raise e + raise e from None return decoded_data diff --git a/package/PartSegCore/segmentation/border_smoothing.py b/package/PartSegCore/segmentation/border_smoothing.py index 0cc2d0b5c..ef82f21c8 100644 --- a/package/PartSegCore/segmentation/border_smoothing.py +++ b/package/PartSegCore/segmentation/border_smoothing.py @@ -11,7 +11,7 @@ from PartSegCore.utils import BaseModel -class BaseSmoothing(AlgorithmDescribeBase, ABC): +class BaseSmoothing(AlgorithmDescribeBase, ABC, calculation_method="smooth"): __argument_class__ = BaseModel @classmethod diff --git a/package/PartSegCore/segmentation/watershed.py b/package/PartSegCore/segmentation/watershed.py index baff5ee59..4391aeb1d 100644 --- a/package/PartSegCore/segmentation/watershed.py +++ b/package/PartSegCore/segmentation/watershed.py @@ -22,7 +22,7 @@ ) -class BaseWatershed(AlgorithmDescribeBase, ABC): +class BaseWatershed(AlgorithmDescribeBase, ABC, calculation_method="sprawl"): """base class for all sprawl interface""" __argument_class__ = BaseModel From fb941dd2fbeb5962acbcf4df32f2d44d4f2f862c Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Wed, 4 Jan 2023 12:17:11 +0100 Subject: [PATCH 14/25] add test for decorator --- .../PartSegCore/algorithm_describe_base.py | 4 ++- .../test_algorithm_describe_base.py | 26 ++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index 9fe0172b3..82b31de3f 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -120,7 +120,7 @@ def _partial_abstractmethod(funcobj): class AlgorithmDescribeBaseMeta(ABCMeta): def __new__(cls, name, bases, attrs, **kwargs): calculation_method = kwargs.pop("calculation_method", None) - calculation_method_params_name = kwargs.pop("calculation_method__params_name", None) + calculation_method_params_name = kwargs.pop("calculation_method_params_name", None) cls2 = super().__new__(cls, name, bases, attrs, **kwargs) if ( not inspect.isabstract(cls2) @@ -173,6 +173,8 @@ def _get_calculation_method_params_name(cls2) -> typing.Optional[str]: return "arguments" if "params" in signature.parameters: return "params" + if "parameters" in signature.parameters: + return "parameters" raise RuntimeError(f"Cannot determine arguments parameter name in {cls2.__calculation_method__}") def _validate_if_all_abstract_getters_are_defined(self, kwargs): diff --git a/package/tests/test_PartSegCore/test_algorithm_describe_base.py b/package/tests/test_PartSegCore/test_algorithm_describe_base.py index 7a3079149..deb521fe6 100644 --- a/package/tests/test_PartSegCore/test_algorithm_describe_base.py +++ b/package/tests/test_PartSegCore/test_algorithm_describe_base.py @@ -1,6 +1,6 @@ # pylint: disable=R0201 import typing -from abc import abstractmethod +from abc import ABC, abstractmethod from enum import Enum import pytest @@ -10,6 +10,7 @@ from PartSegCore.algorithm_describe_base import ( AlgorithmDescribeBase, + AlgorithmDescribeBaseMeta, AlgorithmProperty, AlgorithmSelection, ROIExtractionProfile, @@ -434,6 +435,29 @@ def sample_function(params: dict, /) -> dict: with pytest.raises(ValueError, match="Function .*sample_function.* should not have positional only parameters"): ClassForTestFromFunc.from_function(sample_function, info="sample", alpha=1.0) + def test_fom_function_as_decorator(self): + class SampleClass(ABC, metaclass=AlgorithmDescribeBaseMeta): + @classmethod + @abstractmethod + def get_sample(cls) -> str: + raise NotImplementedError() + + @classmethod + def get_fields(cls): + raise NotImplementedError() + + class SampleClass2(SampleClass, calculation_method="calculate"): + @classmethod + @abstractmethod + def calculate(cls, a: int, arguments: dict) -> str: + raise NotImplementedError() + + @SampleClass2.from_function(sample="aaa") + def calc(a: int) -> str: + return f"aaa {a}" + + assert calc.calculate(a=1, arguments={}) == "aaa 1" + def test_roi_extraction_profile(): ROIExtractionProfile(name="aaa", algorithm="aaa", values={}) From 077f13a7a7c6ffa6a4ef4021ebeb98b46ee6b2cc Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Wed, 4 Jan 2023 13:49:10 +0100 Subject: [PATCH 15/25] add test for function without params --- package/PartSegCore/io_utils.py | 4 ++-- .../test_algorithm_describe_base.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/package/PartSegCore/io_utils.py b/package/PartSegCore/io_utils.py index 51f1c0397..a0510277e 100644 --- a/package/PartSegCore/io_utils.py +++ b/package/PartSegCore/io_utils.py @@ -119,7 +119,7 @@ def get_extensions(cls) -> typing.List[str]: return [x[1:] for x in extensions] -class LoadBase(AlgorithmDescribeBase, ABC, calculation_method="load"): +class LoadBase(AlgorithmDescribeBase, ABC, calculation_method="load", calculation_method_params_name="parameters"): need_functions = [ "load", "get_short_name", @@ -159,7 +159,7 @@ def get_name_with_suffix(cls): @classmethod def get_extensions(cls) -> typing.List[str]: - if match := re.match(r".*\((.*)\)", cls.get_name()) is None: + if (match := re.match(r".*\((.*)\)", cls.get_name())) is None: raise ValueError(f"No extensions found in {cls.get_name()}") extensions = match[1].split(" ") if not all(x.startswith("*.") for x in extensions): diff --git a/package/tests/test_PartSegCore/test_algorithm_describe_base.py b/package/tests/test_PartSegCore/test_algorithm_describe_base.py index deb521fe6..56d95e91e 100644 --- a/package/tests/test_PartSegCore/test_algorithm_describe_base.py +++ b/package/tests/test_PartSegCore/test_algorithm_describe_base.py @@ -458,6 +458,21 @@ def calc(a: int) -> str: assert calc.calculate(a=1, arguments={}) == "aaa 1" + def test_class_without_user_provided_attributes(self): + class SampleClass( + AlgorithmDescribeBase, calculation_method="calculate", calculation_method_params_name="parameters" + ): + @classmethod + @abstractmethod + def calculate(cls, a: int, b: int) -> int: + raise NotImplementedError() + + @SampleClass.from_function() + def calc(a: int, b: int) -> int: + return a + b + + assert calc.calculate(a=1, b=2) == 3 + def test_roi_extraction_profile(): ROIExtractionProfile(name="aaa", algorithm="aaa", values={}) From bebab4cc5a5a57757ec23c06e9f53244b8f935c2 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Wed, 4 Jan 2023 17:35:52 +0100 Subject: [PATCH 16/25] try to fix warnings --- package/PartSegCore/algorithm_describe_base.py | 3 ++- .../tests/test_PartSegCore/test_algorithm_describe_base.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index 82b31de3f..8752ea07b 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -252,10 +252,11 @@ def _func(): return _func + parameters_order = self._get_parameters_from_signature() + def _class_generator(func_): drop_attr = self._validate_function_parameters(func_) - parameters_order = self._get_parameters_from_signature() @wraps(func_) def _calculate_method(*args, **kwargs_): diff --git a/package/tests/test_PartSegCore/test_algorithm_describe_base.py b/package/tests/test_PartSegCore/test_algorithm_describe_base.py index 56d95e91e..460b93f27 100644 --- a/package/tests/test_PartSegCore/test_algorithm_describe_base.py +++ b/package/tests/test_PartSegCore/test_algorithm_describe_base.py @@ -407,8 +407,8 @@ def sample_function(params: dict) -> dict: new_cls.calculate({"a": 1}, params={}) assert new_cls.__argument_class__ == dict assert new_cls.__name__ == "SampleFunction" - assert new_cls(params={"b": 2}, scalar=1) == {"b": 2, "a": 1} - assert new_cls({"b": 2}) == {"b": 2, "a": 1} + assert new_cls(params={"b": 2}, scalar=1) == {"b": 2, "a": 1} # pylint: disable=not-callable + assert new_cls({"b": 2}) == {"b": 2, "a": 1} # pylint: disable=not-callable def test_generate_class_from_function_without_params(self): @ClassForTestFromFunc.from_function(info="sample2", alpha=2.0) From 2beb83c4bdc10cd3046ee5a0defb00e67325ce8f Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Wed, 4 Jan 2023 20:45:30 +0100 Subject: [PATCH 17/25] check if warning is resolved --- package/PartSegCore/algorithm_describe_base.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index 8752ea07b..918b71c43 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -177,8 +177,9 @@ def _get_calculation_method_params_name(cls2) -> typing.Optional[str]: return "parameters" raise RuntimeError(f"Cannot determine arguments parameter name in {cls2.__calculation_method__}") - def _validate_if_all_abstract_getters_are_defined(self, kwargs): - abstract_getters_set = set(self.__abstract_getters__) + @staticmethod + def _validate_if_all_abstract_getters_are_defined(abstract_getters, kwargs): + abstract_getters_set = set(abstract_getters) kwargs_set = set(kwargs.keys()) if abstract_getters_set != kwargs_set: @@ -227,8 +228,9 @@ def _get_argument_class_from_signature(self, func): return BaseModel return signature.parameters[self.__calculation_method_params_name__].annotation - def _get_parameters_from_signature(self): - signature = inspect.signature(getattr(self, self.__calculation_method__)) + @staticmethod + def _get_parameters_from_signature(func): + signature = inspect.signature(func) return [parameters.name for parameters in signature.parameters.values()] def from_function(self, func=None, **kwargs): @@ -239,7 +241,7 @@ def from_function(self, func=None, **kwargs): if not self.__support_from_function__: raise RuntimeError("This class does not support from_function method") - self._validate_if_all_abstract_getters_are_defined(kwargs) + self._validate_if_all_abstract_getters_are_defined(self.__abstract_getters__, kwargs) # check if all values have correct type for key, value in kwargs.items(): @@ -252,7 +254,7 @@ def _func(): return _func - parameters_order = self._get_parameters_from_signature() + parameters_order = self._get_parameters_from_signature(getattr(self, self.__calculation_method__)) def _class_generator(func_): From c0a37704a2941606364671f4fa92e5119ef54314 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Wed, 4 Jan 2023 20:55:05 +0100 Subject: [PATCH 18/25] fix remaining names --- .../PartSegCore/algorithm_describe_base.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index 918b71c43..7e4a5991a 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -118,9 +118,7 @@ def _partial_abstractmethod(funcobj): class AlgorithmDescribeBaseMeta(ABCMeta): - def __new__(cls, name, bases, attrs, **kwargs): - calculation_method = kwargs.pop("calculation_method", None) - calculation_method_params_name = kwargs.pop("calculation_method_params_name", None) + def __new__(cls, name, bases, attrs, calculation_method=None, calculation_method_params_name=None, **kwargs): cls2 = super().__new__(cls, name, bases, attrs, **kwargs) if ( not inspect.isabstract(cls2) @@ -197,7 +195,8 @@ def _validate_if_all_abstract_getters_are_defined(abstract_getters, kwargs): raise ValueError(f"{missing_text} {extra_text}") - def _validate_function_parameters(self, func) -> set: + @staticmethod + def _validate_function_parameters(func, method, method_name) -> set: """ Validate if all parameters without default values are defined in self.__calculation_method__ @@ -205,7 +204,7 @@ def _validate_function_parameters(self, func) -> set: :return: set of parameters that should be dropped """ signature = inspect.signature(func) - base_method_signature = inspect.signature(getattr(self, self.__calculation_method__)) + base_method_signature = inspect.signature(method) for parameters in signature.parameters.values(): if parameters.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.POSITIONAL_ONLY}: @@ -214,7 +213,7 @@ def _validate_function_parameters(self, func) -> set: parameters.default is inspect.Parameter.empty and parameters.name not in base_method_signature.parameters ): - raise ValueError(f"Parameter {parameters.name} is not defined in {self.__calculation_method__} method") + raise ValueError(f"Parameter {parameters.name} is not defined in {method_name} method") return { parameters.name @@ -222,11 +221,12 @@ def _validate_function_parameters(self, func) -> set: if parameters.name not in signature.parameters } - def _get_argument_class_from_signature(self, func): + @staticmethod + def _get_argument_class_from_signature(func, argument_name: str): signature = inspect.signature(func) - if self.__calculation_method_params_name__ not in signature.parameters: + if argument_name not in signature.parameters: return BaseModel - return signature.parameters[self.__calculation_method_params_name__].annotation + return signature.parameters[argument_name].annotation @staticmethod def _get_parameters_from_signature(func): @@ -258,7 +258,9 @@ def _func(): def _class_generator(func_): - drop_attr = self._validate_function_parameters(func_) + drop_attr = self._validate_function_parameters( + func_, getattr(self, self.__calculation_method__), self.__calculation_method__ + ) @wraps(func_) def _calculate_method(*args, **kwargs_): @@ -274,7 +276,9 @@ def _calculate_method(*args, **kwargs_): class_dkt = {f"get_{name}": _getter_by_name(name) for name in self.__abstract_getters__} class_dkt[self.__calculation_method__] = _calculate_method - class_dkt["__argument_class__"] = self._get_argument_class_from_signature(func_) + class_dkt["__argument_class__"] = self._get_argument_class_from_signature( + func_, self.__calculation_method_params_name__ + ) class_dkt["__from_function__"] = True return type(func_.__name__.replace("_", " ").title().replace(" ", ""), (self,), class_dkt) From e7480a533dfd16b284068f0d63ca4edfe574cf18 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Wed, 4 Jan 2023 21:45:54 +0100 Subject: [PATCH 19/25] use better names --- .../PartSegCore/algorithm_describe_base.py | 32 +++++++++---------- package/PartSegCore/io_utils.py | 4 +-- .../segmentation/border_smoothing.py | 2 +- .../segmentation/noise_filtering.py | 2 +- package/PartSegCore/segmentation/threshold.py | 2 +- package/PartSegCore/segmentation/watershed.py | 2 +- .../test_algorithm_describe_base.py | 8 ++--- 7 files changed, 25 insertions(+), 27 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index 7e4a5991a..389b95274 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -118,7 +118,7 @@ def _partial_abstractmethod(funcobj): class AlgorithmDescribeBaseMeta(ABCMeta): - def __new__(cls, name, bases, attrs, calculation_method=None, calculation_method_params_name=None, **kwargs): + def __new__(cls, name, bases, attrs, method_from_fun=None, additional_parameters=None, **kwargs): cls2 = super().__new__(cls, name, bases, attrs, **kwargs) if ( not inspect.isabstract(cls2) @@ -129,19 +129,19 @@ def __new__(cls, name, bases, attrs, calculation_method=None, calculation_method cls2.__new_style__ = getattr(cls2.get_fields, "__is_partial_abstractmethod__", False) cls2.__from_function__ = getattr(cls2, "__from_function__", False) cls2.__abstract_getters__ = {} - cls2.__calculation_method__ = calculation_method or getattr(cls2, "__calculation_method__", None) - cls2.__calculation_method_params_name__ = calculation_method_params_name or getattr( - cls2, "__calculation_method_params_name__", None + cls2.__method_name__ = method_from_fun or getattr(cls2, "__method_name__", None) + cls2.__additional_parameters_name__ = additional_parameters or getattr( + cls2, "__additional_parameters_name__", None ) - if cls2.__calculation_method_params_name__ is None: - cls2.__calculation_method_params_name__ = cls._get_calculation_method_params_name(cls2) + if cls2.__additional_parameters_name__ is None: + cls2.__additional_parameters_name__ = cls._get_calculation_method_params_name(cls2) cls2.__support_from_function__ = ( - cls2.__calculation_method__ is not None and cls2.__calculation_method_params_name__ is not None + cls2.__method_name__ is not None and cls2.__additional_parameters_name__ is not None ) cls2.__abstract_getters__, cls2.__support_from_function__ = cls._get_abstract_getters( - cls2, cls2.__support_from_function__, calculation_method + cls2, cls2.__support_from_function__, method_from_fun ) return cls2 @@ -164,16 +164,16 @@ def _get_abstract_getters( @staticmethod def _get_calculation_method_params_name(cls2) -> typing.Optional[str]: - if cls2.__calculation_method__ is None: + if cls2.__method_name__ is None: return None - signature = inspect.signature(getattr(cls2, cls2.__calculation_method__)) + signature = inspect.signature(getattr(cls2, cls2.__method_name__)) if "arguments" in signature.parameters: return "arguments" if "params" in signature.parameters: return "params" if "parameters" in signature.parameters: return "parameters" - raise RuntimeError(f"Cannot determine arguments parameter name in {cls2.__calculation_method__}") + raise RuntimeError(f"Cannot determine arguments parameter name in {cls2.__method_name__}") @staticmethod def _validate_if_all_abstract_getters_are_defined(abstract_getters, kwargs): @@ -254,12 +254,12 @@ def _func(): return _func - parameters_order = self._get_parameters_from_signature(getattr(self, self.__calculation_method__)) + parameters_order = self._get_parameters_from_signature(getattr(self, self.__method_name__)) def _class_generator(func_): drop_attr = self._validate_function_parameters( - func_, getattr(self, self.__calculation_method__), self.__calculation_method__ + func_, getattr(self, self.__method_name__), self.__method_name__ ) @wraps(func_) @@ -275,9 +275,9 @@ def _calculate_method(*args, **kwargs_): class_dkt = {f"get_{name}": _getter_by_name(name) for name in self.__abstract_getters__} - class_dkt[self.__calculation_method__] = _calculate_method + class_dkt[self.__method_name__] = _calculate_method class_dkt["__argument_class__"] = self._get_argument_class_from_signature( - func_, self.__calculation_method_params_name__ + func_, self.__additional_parameters_name__ ) class_dkt["__from_function__"] = True @@ -300,7 +300,7 @@ class AlgorithmDescribeBase(ABC, metaclass=AlgorithmDescribeBaseMeta): def __new__(cls, *args, **kwargs): if cls.__from_function__: - return getattr(cls, cls.__calculation_method__)(*args, **kwargs) + return getattr(cls, cls.__method_name__)(*args, **kwargs) return super().__new__(cls) @classmethod diff --git a/package/PartSegCore/io_utils.py b/package/PartSegCore/io_utils.py index a0510277e..c9c46d15a 100644 --- a/package/PartSegCore/io_utils.py +++ b/package/PartSegCore/io_utils.py @@ -56,7 +56,7 @@ def get_tarinfo(name, buffer: typing.Union[BytesIO, StringIO]): return tar_info -class SaveBase(AlgorithmDescribeBase, ABC, calculation_method="save"): +class SaveBase(AlgorithmDescribeBase, ABC, method_from_fun="save"): need_functions = [ "save", "get_short_name", @@ -119,7 +119,7 @@ def get_extensions(cls) -> typing.List[str]: return [x[1:] for x in extensions] -class LoadBase(AlgorithmDescribeBase, ABC, calculation_method="load", calculation_method_params_name="parameters"): +class LoadBase(AlgorithmDescribeBase, ABC, method_from_fun="load", additional_parameters="parameters"): need_functions = [ "load", "get_short_name", diff --git a/package/PartSegCore/segmentation/border_smoothing.py b/package/PartSegCore/segmentation/border_smoothing.py index ef82f21c8..ba983d07a 100644 --- a/package/PartSegCore/segmentation/border_smoothing.py +++ b/package/PartSegCore/segmentation/border_smoothing.py @@ -11,7 +11,7 @@ from PartSegCore.utils import BaseModel -class BaseSmoothing(AlgorithmDescribeBase, ABC, calculation_method="smooth"): +class BaseSmoothing(AlgorithmDescribeBase, ABC, method_from_fun="smooth"): __argument_class__ = BaseModel @classmethod diff --git a/package/PartSegCore/segmentation/noise_filtering.py b/package/PartSegCore/segmentation/noise_filtering.py index 618f10d6f..504a3a524 100644 --- a/package/PartSegCore/segmentation/noise_filtering.py +++ b/package/PartSegCore/segmentation/noise_filtering.py @@ -22,7 +22,7 @@ def __str__(self): return self.name.replace("_", " ") -class NoiseFilteringBase(AlgorithmDescribeBase, ABC, calculation_method="noise_filter"): +class NoiseFilteringBase(AlgorithmDescribeBase, ABC, method_from_fun="noise_filter"): """Base class for noise filtering operations""" @classmethod diff --git a/package/PartSegCore/segmentation/threshold.py b/package/PartSegCore/segmentation/threshold.py index 09adbbeef..72681d89d 100644 --- a/package/PartSegCore/segmentation/threshold.py +++ b/package/PartSegCore/segmentation/threshold.py @@ -36,7 +36,7 @@ class MultipleOtsuThresholdParams(BaseModel): bins: int = Field(128, title="Number of histogram bins", ge=8, le=2**16) -class BaseThreshold(AlgorithmDescribeBase, ABC, calculation_method="calculate_mask"): +class BaseThreshold(AlgorithmDescribeBase, ABC, method_from_fun="calculate_mask"): @classmethod def calculate_mask( cls, diff --git a/package/PartSegCore/segmentation/watershed.py b/package/PartSegCore/segmentation/watershed.py index 4391aeb1d..6d253eb48 100644 --- a/package/PartSegCore/segmentation/watershed.py +++ b/package/PartSegCore/segmentation/watershed.py @@ -22,7 +22,7 @@ ) -class BaseWatershed(AlgorithmDescribeBase, ABC, calculation_method="sprawl"): +class BaseWatershed(AlgorithmDescribeBase, ABC, method_from_fun="sprawl"): """base class for all sprawl interface""" __argument_class__ = BaseModel diff --git a/package/tests/test_PartSegCore/test_algorithm_describe_base.py b/package/tests/test_PartSegCore/test_algorithm_describe_base.py index 460b93f27..85e002471 100644 --- a/package/tests/test_PartSegCore/test_algorithm_describe_base.py +++ b/package/tests/test_PartSegCore/test_algorithm_describe_base.py @@ -446,7 +446,7 @@ def get_sample(cls) -> str: def get_fields(cls): raise NotImplementedError() - class SampleClass2(SampleClass, calculation_method="calculate"): + class SampleClass2(SampleClass, method_from_fun="calculate"): @classmethod @abstractmethod def calculate(cls, a: int, arguments: dict) -> str: @@ -459,9 +459,7 @@ def calc(a: int) -> str: assert calc.calculate(a=1, arguments={}) == "aaa 1" def test_class_without_user_provided_attributes(self): - class SampleClass( - AlgorithmDescribeBase, calculation_method="calculate", calculation_method_params_name="parameters" - ): + class SampleClass(AlgorithmDescribeBase, method_from_fun="calculate", additional_parameters="parameters"): @classmethod @abstractmethod def calculate(cls, a: int, b: int) -> int: @@ -487,7 +485,7 @@ def get_alpha(cls) -> float: raise NotImplementedError() -class ClassForTestFromFunc(ClassForTestFromFuncBase, calculation_method="calculate"): +class ClassForTestFromFunc(ClassForTestFromFuncBase, method_from_fun="calculate"): @classmethod @abstractmethod def get_info(cls) -> str: From ef72b97f8735966a499ecf16d377764cff19f66b Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Thu, 5 Jan 2023 12:36:31 +0100 Subject: [PATCH 20/25] allow kwargs function argument --- package/PartSegCore/algorithm_describe_base.py | 18 +++++++++++++----- .../test_algorithm_describe_base.py | 8 ++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index 389b95274..92e0f7d24 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -205,15 +205,23 @@ def _validate_function_parameters(func, method, method_name) -> set: """ signature = inspect.signature(func) base_method_signature = inspect.signature(method) + take_all = False - for parameters in signature.parameters.values(): - if parameters.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.POSITIONAL_ONLY}: + for parameter in signature.parameters.values(): + if parameter.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.POSITIONAL_ONLY}: raise ValueError(f"Function {func} should not have positional only parameters") if ( - parameters.default is inspect.Parameter.empty - and parameters.name not in base_method_signature.parameters + parameter.default is inspect.Parameter.empty + and parameter.name not in base_method_signature.parameters + and parameter.kind != inspect.Parameter.VAR_KEYWORD ): - raise ValueError(f"Parameter {parameters.name} is not defined in {method_name} method") + raise ValueError(f"Parameter {parameter.name} is not defined in {method_name} method") + + if parameter.kind == inspect.Parameter.VAR_KEYWORD: + take_all = True + + if take_all: + return set() return { parameters.name diff --git a/package/tests/test_PartSegCore/test_algorithm_describe_base.py b/package/tests/test_PartSegCore/test_algorithm_describe_base.py index 85e002471..8f83f59eb 100644 --- a/package/tests/test_PartSegCore/test_algorithm_describe_base.py +++ b/package/tests/test_PartSegCore/test_algorithm_describe_base.py @@ -471,6 +471,14 @@ def calc(a: int, b: int) -> int: assert calc.calculate(a=1, b=2) == 3 + def test_functions_with_kwargs(self): + @ClassForTestFromFunc.from_function(info="sample2", alpha=2.0) + def sample_function(params: dict, **kwargs) -> dict: + params["scalar"] = kwargs["scalar"] + return params + + assert sample_function.calculate(params={"b": 2}, scalar=1) == {"b": 2, "scalar": 1} + def test_roi_extraction_profile(): ROIExtractionProfile(name="aaa", algorithm="aaa", values={}) From 05e2a86c2f746b8f9a665be0168c0101b9c39d9d Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Thu, 5 Jan 2023 12:41:53 +0100 Subject: [PATCH 21/25] improve type annotation --- package/PartSegCore/algorithm_describe_base.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index 92e0f7d24..c17844cfa 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -19,6 +19,11 @@ from pydantic.fields import ModelField +T = typing.TypeVar("T", bound="AlgorithmDescribeBase") + +type_T = typing.Type[T] + + class AlgorithmDescribeNotFound(Exception): """ When algorithm description not found @@ -325,22 +330,18 @@ def get_doc_from_fields(cls): @classmethod @typing.overload - def from_function(cls, func: typing.Callable[..., typing.Any], **kwargs) -> typing.Type["AlgorithmDescribeBase"]: + def from_function(cls: type_T, func: typing.Callable[..., typing.Any], **kwargs) -> type_T: ... @classmethod @typing.overload - def from_function( - cls, **kwargs - ) -> typing.Callable[[typing.Callable[..., typing.Any]], typing.Type["AlgorithmDescribeBase"]]: + def from_function(cls: type_T, **kwargs) -> typing.Callable[[typing.Callable[..., typing.Any]], type_T]: ... @classmethod def from_function( - cls, func=None, **kwargs - ) -> typing.Union[ - typing.Type["AlgorithmDescribeBase"], typing.Callable[[typing.Callable], typing.Type["AlgorithmDescribeBase"]] - ]: + cls: type_T, func=None, **kwargs + ) -> typing.Union[type_T, typing.Callable[[typing.Callable], type_T]]: def _from_function(func_) -> typing.Type["AlgorithmDescribeBase"]: if "name" not in kwargs: kwargs["name"] = func_.__name__.replace("_", " ").title() From 7749300e185c15ad3ae62d576e6fce238b5bc28a Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Wed, 11 Jan 2023 18:41:03 +0100 Subject: [PATCH 22/25] install libhdf --- .github/workflows/base_test_workflow.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/base_test_workflow.yml b/.github/workflows/base_test_workflow.yml index 6c6b31420..695cb406a 100644 --- a/.github/workflows/base_test_workflow.yml +++ b/.github/workflows/base_test_workflow.yml @@ -47,7 +47,7 @@ jobs: if: runner.os == 'Linux' run: | sudo apt update - sudo apt-get install -y libegl1 libdbus-1-3 libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-xinput0 libxcb-xfixes0 x11-utils + sudo apt-get install -y libegl1 libdbus-1-3 libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-xinput0 libxcb-xfixes0 x11-utils libhdf5-dev - name: Install Windows OpenGL if: runner.os == 'Windows' From ac7e3523fa46734db2c0ca0dbddfae74c68cbab3 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Wed, 11 Jan 2023 19:49:40 +0100 Subject: [PATCH 23/25] fix type name --- package/PartSegCore/algorithm_describe_base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index c17844cfa..32b3b86cb 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -21,7 +21,7 @@ T = typing.TypeVar("T", bound="AlgorithmDescribeBase") -type_T = typing.Type[T] +TypeT = typing.Type[T] class AlgorithmDescribeNotFound(Exception): @@ -330,18 +330,18 @@ def get_doc_from_fields(cls): @classmethod @typing.overload - def from_function(cls: type_T, func: typing.Callable[..., typing.Any], **kwargs) -> type_T: + def from_function(cls: TypeT, func: typing.Callable[..., typing.Any], **kwargs) -> TypeT: ... @classmethod @typing.overload - def from_function(cls: type_T, **kwargs) -> typing.Callable[[typing.Callable[..., typing.Any]], type_T]: + def from_function(cls: TypeT, **kwargs) -> typing.Callable[[typing.Callable[..., typing.Any]], TypeT]: ... @classmethod def from_function( - cls: type_T, func=None, **kwargs - ) -> typing.Union[type_T, typing.Callable[[typing.Callable], type_T]]: + cls: TypeT, func=None, **kwargs + ) -> typing.Union[TypeT, typing.Callable[[typing.Callable], TypeT]]: def _from_function(func_) -> typing.Type["AlgorithmDescribeBase"]: if "name" not in kwargs: kwargs["name"] = func_.__name__.replace("_", " ").title() From be6bef094585c26726992d00baea0180f25b6173 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Tue, 7 Feb 2023 22:16:25 +0100 Subject: [PATCH 24/25] add `kwargs_to_model` decorator for simply add backward compatybility --- package/PartSegCore/utils.py | 84 ++++++++++++++++++++ package/tests/test_PartSegCore/test_utils.py | 32 ++++++++ 2 files changed, 116 insertions(+) diff --git a/package/PartSegCore/utils.py b/package/PartSegCore/utils.py index 49b650f92..3e3ebb4d6 100644 --- a/package/PartSegCore/utils.py +++ b/package/PartSegCore/utils.py @@ -7,6 +7,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from contextlib import suppress +from functools import wraps from types import MethodType import numpy as np @@ -466,3 +467,86 @@ def iterate_names(base_name: str, data_dict, max_length=None) -> typing.Optional if res_name not in data_dict: return res_name return None + + +def _get_translation_dicts_from_signature(signature: inspect.Signature) -> typing.Tuple[dict, dict]: + kwargs_to_field = {} + par_to_class = {} + + for par in signature.parameters.values(): + if par.annotation is par.empty or not issubclass(par.annotation, BaseModel): + raise TypeError(f"This decorator can be used only with functions with BaseModel parameters. {par} is not.") + par_to_class[par.name] = par.annotation + for field in par.annotation.__fields__.values(): + if field.name in kwargs_to_field: + raise TypeError( + f"Two parameters have same field name {field.name}. " + f"{kwargs_to_field[field.name]} and {par.name}." + ) + kwargs_to_field[field.name] = par.name + + return kwargs_to_field, par_to_class + + +def _get_kwargs_from_old_args(old_args_order: typing.List[str], args, kwargs): + + if len(args) > len(old_args_order): + raise TypeError("Too many positional arguments, please use keyword argument or update old_args_order") + kwargs = kwargs.copy() + for name, arg in zip(old_args_order, args): + kwargs[name] = arg + return kwargs + + +def _map_kwargs_to_model(kwargs_to_field, par_to_class, kwargs): + class_kwargs = {name: {} for name in par_to_class} + + for key, value in kwargs.items(): + if key in kwargs_to_field: + class_kwargs[kwargs_to_field[key]][key] = value + else: + raise TypeError(f"Unexpected keyword argument {key}") + + return {name: par_to_class[name](**class_kwargs[name]) for name in par_to_class} + + +def kwargs_to_model(func=None, old_args_order: typing.Optional[typing.List[str]] = None): + """ + Decorator for converting kwargs to base models before calling function. + + This decorator is to provide backward compatibility layer. + """ + + if old_args_order is None: + old_args_order = [] + + def decorator(func_): + + # get functions parameters type annotation to checks its model fields. + + signature = inspect.signature(func_) + + kwargs_to_field, par_to_class = _get_translation_dicts_from_signature(signature) + + @wraps(func_) + def wrapper(*args, **kwargs): + try: + return func_(*args, **kwargs) + except TypeError as e: + if "unexpected keyword argument" not in str(e): + raise e + kwargs = _get_kwargs_from_old_args(old_args_order, args, kwargs) + kwargs = _map_kwargs_to_model(kwargs_to_field, par_to_class, kwargs) + warnings.warn( + "The function is called using old signature. There is try to fix it. Please update function call", + FutureWarning, + stacklevel=2, + ) + + return func_(**kwargs) + + return wrapper + + if func: + return decorator(func) + return decorator diff --git a/package/tests/test_PartSegCore/test_utils.py b/package/tests/test_PartSegCore/test_utils.py index c6891b2c2..3fe282633 100644 --- a/package/tests/test_PartSegCore/test_utils.py +++ b/package/tests/test_PartSegCore/test_utils.py @@ -13,6 +13,7 @@ ProfileDict, get_callback, iterate_names, + kwargs_to_model, recursive_update_dict, ) @@ -357,3 +358,34 @@ class SampleModel(BaseModel): assert ob["c"] == "3" with pytest.raises(KeyError): ob["d"] # pylint: disable=pointless-statement + + +def test_kwargs_to_model(): + class SampleModel(BaseModel): + a: int + b: float + c: str + + class SampleModel2(BaseModel): + d: int + + @kwargs_to_model(old_args_order=["a", "b"]) + def func(par1: SampleModel, par2: SampleModel2): + return f"{par1.a} {par1.b} {par1.c} {par2.d}" + + assert func(SampleModel(a=1, b=2.0, c="a"), SampleModel2(d=3)) == "1 2.0 a 3" + + with pytest.warns(FutureWarning, match="called using old signature"): + assert func(a=1, b=2.0, c="a", d=4) == "1 2.0 a 4" + + with pytest.warns(FutureWarning, match="called using old signature"): + assert func(1, b=2.0, c="a", d=5) == "1 2.0 a 5" + + with pytest.warns(FutureWarning, match="called using old signature"): + assert func(1, 2.0, c="a", d=6) == "1 2.0 a 6" + + with pytest.raises(TypeError, match="Too many positional arguments"): + func(1, 2.0, "a", d=4) + + with pytest.raises(TypeError, match="Unexpected keyword argument e"): + assert func(a=1, b=2.0, c="a", d=4, e=1) == "1 2.0 a 4" From 28b697aa7eee1a69c22d3dff876f665777691717 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Fri, 26 Jan 2024 17:04:20 +0100 Subject: [PATCH 25/25] fix tests --- package/PartSegCore/algorithm_describe_base.py | 12 ++++++++++-- package/PartSegCore/io_utils.py | 4 ++-- package/PartSegCore/segmentation/algorithm_base.py | 2 +- .../test_PartSegCore/test_algorithm_describe_base.py | 2 +- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index c3981bacd..9548f3567 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -156,8 +156,16 @@ def _get_abstract_getters( # get all abstract methods that starts with `get_` for method_name in cls2.__abstractmethods__: if method_name.startswith("get_"): - if "return" not in getattr(cls2, method_name).__annotations__: - raise RuntimeError(f"Method {method_name} should have return annotation") + method = getattr(cls2, method_name) + if "return" not in method.__annotations__: + msg = f"Method {method_name} of {cls2.__qualname__} need to have return type defined" + try: + file_name = inspect.getsourcefile(method) + line = inspect.getsourcelines(method)[1] + msg += f" in {file_name}:{line}" + except TypeError: + pass + raise RuntimeError(msg) abstract_getters[method_name[4:]] = getattr(cls2, method_name).__annotations__["return"] elif method_name != calculation_method: diff --git a/package/PartSegCore/io_utils.py b/package/PartSegCore/io_utils.py index 704b37ac7..f3fce5496 100644 --- a/package/PartSegCore/io_utils.py +++ b/package/PartSegCore/io_utils.py @@ -68,7 +68,7 @@ class SaveBase(AlgorithmDescribeBase, ABC): @classmethod @abstractmethod - def get_short_name(cls): + def get_short_name(cls) -> str: raise NotImplementedError @classmethod @@ -134,7 +134,7 @@ class LoadBase(AlgorithmDescribeBase, ABC): @classmethod @abstractmethod - def get_short_name(cls): + def get_short_name(cls) -> str: raise NotImplementedError @classmethod diff --git a/package/PartSegCore/segmentation/algorithm_base.py b/package/PartSegCore/segmentation/algorithm_base.py index 09106c2e5..6566c0f67 100644 --- a/package/PartSegCore/segmentation/algorithm_base.py +++ b/package/PartSegCore/segmentation/algorithm_base.py @@ -204,7 +204,7 @@ def calculation_run(self, report_fun: Callable[[str, int], None]) -> ROIExtracti raise NotImplementedError @abstractmethod - def get_info_text(self): + def get_info_text(self) -> str: raise NotImplementedError def get_channel(self, channel_idx): diff --git a/package/tests/test_PartSegCore/test_algorithm_describe_base.py b/package/tests/test_PartSegCore/test_algorithm_describe_base.py index ea683f27d..d0794bdd8 100644 --- a/package/tests/test_PartSegCore/test_algorithm_describe_base.py +++ b/package/tests/test_PartSegCore/test_algorithm_describe_base.py @@ -376,7 +376,7 @@ def sample_function(params: dict) -> dict: ClassForTestFromFunc.from_function(sample_function, info="sample", alpha=1.0, additions="sample3") def test_missing_return_annotation(self): - with pytest.raises(RuntimeError, match="Method get_sample should have return annotation"): + with pytest.raises(RuntimeError, match="Method get_sample of .*SampleClass need to have return type defined.*"): class SampleClass(AlgorithmDescribeBase): # pylint: disable=unused-variable @classmethod