diff --git a/baybe/searchspace/core.py b/baybe/searchspace/core.py index 5510af704f..82273b4160 100644 --- a/baybe/searchspace/core.py +++ b/baybe/searchspace/core.py @@ -5,7 +5,7 @@ import gc from collections.abc import Iterable, Sequence from enum import Enum -from typing import cast +from typing import TYPE_CHECKING, cast import pandas as pd from attrs import define, field @@ -28,6 +28,9 @@ from baybe.serialization import SerialMixin, converter, select_constructor_hook from baybe.utils.conversion import to_string +if TYPE_CHECKING: + from baybe.parameters.selectors import ParameterSelectorProtocol + class SearchSpaceType(Enum): """Enum class for different types of search spaces and respective compatibility.""" @@ -287,35 +290,38 @@ def n_tasks(self) -> int: return 1 return len(task_param.values) - def get_comp_rep_parameter_indices(self, name: str, /) -> tuple[int, ...]: - """Find a parameter's column indices in the computational representation. + def get_comp_rep_parameter_indices( + self, + name_or_selector: str | ParameterSelectorProtocol, + /, + ) -> tuple[int, ...]: + """Find comp-rep column indices for a parameter selection. + + When called with a parameter name, returns the indices for that single + parameter. When called with a + :class:`~baybe.parameters.selectors.ParameterSelectorProtocol`, + returns the combined indices for all matching parameters. Args: - name: The name of the parameter whose columns indices are to be retrieved. - - Raises: - ValueError: If no parameter with the provided name exists. - ValueError: If more than one parameter with the provided name exists. + name_or_selector: Either the name of a single parameter or a selector + that filters parameters to be included. Returns: A tuple containing the integer indices of the columns in the computational - representation associated with the parameter. When the parameter is not part - of the computational representation, an empty tuple is returned. + representation associated with the selected parameter(s). When a selected + parameter is not part of the computational representation, it contributes + no indices. """ - params = self.get_parameters_by_name([name]) - if len(params) < 1: - raise ValueError( - f"There exists no parameter named '{name}' in the search space." - ) - if len(params) > 1: - raise ValueError( - f"There exist multiple parameter matches for '{name}' in the search " - f"space." - ) - p = params[0] + if isinstance(name_or_selector, str): + params: list[Parameter] = [ + p for p in self.parameters if p.name == name_or_selector + ] + else: + params = [p for p in self.parameters if name_or_selector(p)] return tuple( i + for p in params for i, col in enumerate(self.comp_rep_columns) if col in p.comp_rep_columns ) diff --git a/baybe/surrogates/gaussian_process/components/kernel.py b/baybe/surrogates/gaussian_process/components/kernel.py index 5bc0415da7..2bd2bdeafe 100644 --- a/baybe/surrogates/gaussian_process/components/kernel.py +++ b/baybe/surrogates/gaussian_process/components/kernel.py @@ -2,6 +2,7 @@ from __future__ import annotations +import functools from abc import ABC, abstractmethod from collections.abc import Iterable from functools import partial @@ -22,6 +23,7 @@ to_parameter_selector, ) from baybe.searchspace.core import SearchSpace +from baybe.serialization.mixin import SerialMixin from baybe.surrogates.gaussian_process.components.generic import ( GPComponentFactoryProtocol, GPComponentType, @@ -44,7 +46,7 @@ @define -class _PureKernelFactory(KernelFactoryProtocol, ABC): +class _PureKernelFactory(KernelFactoryProtocol, SerialMixin, ABC): """Base class for pure kernel factories.""" # For internal use only: sanity check mechanism to remind developers of new @@ -75,6 +77,14 @@ def get_parameter_names(self, searchspace: SearchSpace) -> tuple[str, ...]: selector = self.parameter_selector or (lambda _: True) return tuple(p.name for p in searchspace.parameters if selector(p)) + def _get_effective_dimensionality(self, searchspace: SearchSpace) -> int: + """Get the number of computational columns for the selected parameters.""" + return len( + searchspace.get_comp_rep_parameter_indices( + self.parameter_selector or (lambda _: True) + ) + ) + def _validate_parameter_kinds(self, parameters: Iterable[Parameter]) -> None: """Validate that the given parameters are supported by the factory. @@ -115,6 +125,94 @@ def _make( """Construct the kernel.""" +def _enable_transfer_learning( + cls: type[_PureKernelFactory], name: str | None = None, / +) -> type[_PureKernelFactory]: + """Class decorator enabling BayBE's default transfer learning mechanism. + + When the search space contains a task parameter, the decorated factory + automatically composes its kernel with BayBE's default task kernel. + Otherwise, the factory behaves unchanged. + + When used as a decorator (without ``name``), the class is modified in-place. + When called with a ``name`` argument, a new subclass is created so that the + original class remains unmodified. The latter form is intended for cases where + the original class is reused independently elsewhere. + + Args: + cls: The kernel factory class to decorate. + name: Optional name for the created class. If provided, a new subclass is + created instead of modifying ``cls`` in-place. + + Raises: + TypeError: If the factory already supports task parameters. + + Returns: + The decorated kernel factory class with transfer learning enabled. + """ + if cls._supported_parameter_kinds & _ParameterKind.TASK: + raise TypeError(f"'{cls.__name__}' already supports task parameters.") + + # This distinction is important for serialization so that the classes can be + # correctly identified by their names in the subclass registry + if name is not None: + # Create a sibling class so the original class remains unmodified. + # We use cls.__bases__ (not (cls,)) because the new class is conceptually + # an equivalent variant, not a specialization. Concrete (non-dunder) + # attributes are copied so the sibling has the same behavior. + # __module__ must be set explicitly because the Protocol metaclass + # would otherwise default it to "abc". + ns = { + k: v + for k, v in cls.__dict__.items() + if not (k.startswith("__") and k.endswith("__")) + } + ns["__doc__"] = cls.__doc__ + ns["__module__"] = cls.__module__ + target_cls = type(name, cls.__bases__, ns) + else: + # Modify the class in-place (avoids name collision in subclass registry) + target_cls = cls + + original_call = cls.__call__ + original_supported_kinds = cls._supported_parameter_kinds + _task_exclude_selector = TypeSelector((TaskParameter,), exclude=True) + + @functools.wraps(original_call) + def __call__(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor): + # Temporarily narrow the supported parameter kinds to those of the original + # class. If the decorator logic is correct, the original factory should never + # see the extended scope, but this acts as a sanity check to prevent regressions + broadened_kinds = target_cls._supported_parameter_kinds # type: ignore[attr-defined] + target_cls._supported_parameter_kinds = original_supported_kinds # type: ignore[attr-defined] + + # Split off the task parameters + original_selector = self.parameter_selector + if original_selector is None: + self.parameter_selector = _task_exclude_selector + else: + self.parameter_selector = lambda p: ( + _task_exclude_selector(p) and original_selector(p) + ) + + try: + base_kernel = original_call(self, searchspace, train_x, train_y) + finally: + target_cls._supported_parameter_kinds = broadened_kinds # type: ignore[attr-defined] + self.parameter_selector = original_selector + + if searchspace.task_idx is not None: + icm = ICMKernelFactory(base_kernel_or_factory=base_kernel) + return icm(searchspace, train_x, train_y) + return base_kernel + + target_cls.__call__ = __call__ # type: ignore[method-assign] + target_cls._supported_parameter_kinds = ( # type: ignore[attr-defined] + cls._supported_parameter_kinds | _ParameterKind.TASK + ) + return target_cls + + @define class _MetaKernelFactory(KernelFactoryProtocol, ABC): """Base class for meta kernel factories that orchestrate other kernel factories.""" @@ -150,18 +248,25 @@ class ICMKernelFactory(_MetaKernelFactory): @base_kernel_factory.default def _default_base_kernel_factory(self) -> KernelFactoryProtocol: from baybe.surrogates.gaussian_process.presets.baybe import ( - BayBENumericalKernelFactory, + _BayBENumericalKernelFactory, ) - return BayBENumericalKernelFactory(TypeSelector((TaskParameter,), exclude=True)) + assert ( + _BayBENumericalKernelFactory._supported_parameter_kinds + is _ParameterKind.REGULAR + ) + return _BayBENumericalKernelFactory( + TypeSelector((TaskParameter,), exclude=True) + ) @task_kernel_factory.default def _default_task_kernel_factory(self) -> KernelFactoryProtocol: from baybe.surrogates.gaussian_process.presets.baybe import ( - BayBETaskKernelFactory, + _BayBETaskKernelFactory, ) - return BayBETaskKernelFactory(TypeSelector((TaskParameter,))) + assert _BayBETaskKernelFactory._supported_parameter_kinds is _ParameterKind.TASK + return _BayBETaskKernelFactory() @override def __call__( diff --git a/baybe/surrogates/gaussian_process/presets/baybe.py b/baybe/surrogates/gaussian_process/presets/baybe.py index 0cf6686064..520fb00acd 100644 --- a/baybe/surrogates/gaussian_process/presets/baybe.py +++ b/baybe/surrogates/gaussian_process/presets/baybe.py @@ -26,39 +26,24 @@ from baybe.surrogates.gaussian_process.presets.edbo_smoothed import ( SmoothedEDBOKernelFactory, SmoothedEDBOLikelihoodFactory, + _SmoothedEDBONumericalKernelFactory, ) if TYPE_CHECKING: from torch import Tensor -@define -class BayBEKernelFactory(_PureKernelFactory): - """The default kernel factory for Gaussian process surrogates.""" - - _supported_parameter_kinds: ClassVar[_ParameterKind] = ( - _ParameterKind.REGULAR | _ParameterKind.TASK - ) - # See base class. - - @override - def _make( - self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor - ) -> Kernel: - from baybe.surrogates.gaussian_process.components.kernel import ICMKernelFactory +class _BayBENumericalKernelFactory(_SmoothedEDBONumericalKernelFactory): + """The default numerical kernel factory for GP surrogates.""" - is_multitask = searchspace.task_idx is not None - factory = ICMKernelFactory if is_multitask else BayBENumericalKernelFactory - return factory()(searchspace, train_x, train_y) - -BayBENumericalKernelFactory = SmoothedEDBOKernelFactory -"""The factory providing the default numerical kernel for Gaussian process surrogates.""" # noqa: E501 +class BayBEKernelFactory(SmoothedEDBOKernelFactory): # type: ignore[valid-type, misc] + """The default kernel factory for GP surrogates.""" @define -class BayBETaskKernelFactory(_PureKernelFactory): - """The factory providing the default task kernel for Gaussian process surrogates.""" +class _BayBETaskKernelFactory(_PureKernelFactory): + """The default task kernel factory for GP surrogates.""" _uses_parameter_names: ClassVar[bool] = True # See base class. @@ -83,11 +68,13 @@ def _make( ) -BayBEMeanFactory = LazyConstantMeanFactory -"""The factory providing the default mean function for Gaussian process surrogates.""" +class BayBEMeanFactory(LazyConstantMeanFactory): + """The default mean factory for GP surrogates.""" + + +class BayBELikelihoodFactory(SmoothedEDBOLikelihoodFactory): + """The default likelihood factory for GP surrogates.""" -BayBELikelihoodFactory = SmoothedEDBOLikelihoodFactory -"""The factory providing the default likelihood for Gaussian process surrogates.""" @define diff --git a/baybe/surrogates/gaussian_process/presets/chen.py b/baybe/surrogates/gaussian_process/presets/chen.py index e461bc0750..0122004beb 100644 --- a/baybe/surrogates/gaussian_process/presets/chen.py +++ b/baybe/surrogates/gaussian_process/presets/chen.py @@ -6,22 +6,17 @@ import math from typing import TYPE_CHECKING, ClassVar -from attrs import define, field +from attrs import define from typing_extensions import override from baybe.kernels.basic import MaternKernel from baybe.kernels.composite import ScaleKernel -from baybe.parameters.categorical import TaskParameter -from baybe.parameters.selectors import ( - ParameterSelectorProtocol, - TypeSelector, - to_parameter_selector, -) from baybe.priors.basic import GammaPrior from baybe.surrogates.gaussian_process.components.fit_criterion import ( _MLLForNonTLFitCriterionFactory, ) from baybe.surrogates.gaussian_process.components.kernel import ( + _enable_transfer_learning, _PureKernelFactory, ) from baybe.surrogates.gaussian_process.components.likelihood import ( @@ -36,6 +31,7 @@ from baybe.searchspace.core import SearchSpace +@_enable_transfer_learning @define class CHENKernelFactory(_PureKernelFactory): """A factory providing adaptive hyperprior kernels as proposed by :cite:p:`Chen2026`.""" # noqa: E501 @@ -43,17 +39,12 @@ class CHENKernelFactory(_PureKernelFactory): _uses_parameter_names: ClassVar[bool] = True # See base class. - parameter_selector: ParameterSelectorProtocol | None = field( - factory=lambda: TypeSelector([TaskParameter], exclude=True), - converter=to_parameter_selector, - ) - # TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429) - @override def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor ) -> Kernel: - lengthscale = 0.4 * math.sqrt(train_x.shape[-1]) + 4.0 + n_dimensions = self._get_effective_dimensionality(searchspace) + lengthscale = 0.4 * math.sqrt(n_dimensions) + 4.0 lengthscale_prior = GammaPrior(2.0 * lengthscale, 2.0) lengthscale_initial_value = lengthscale outputscale_prior = GammaPrior(1.0 * lengthscale, 1.0) diff --git a/baybe/surrogates/gaussian_process/presets/edbo.py b/baybe/surrogates/gaussian_process/presets/edbo.py index 1ff8d2bf80..c701d4dbb4 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo.py +++ b/baybe/surrogates/gaussian_process/presets/edbo.py @@ -6,18 +6,12 @@ from collections.abc import Collection from typing import TYPE_CHECKING, ClassVar -from attrs import define, field +from attrs import define from typing_extensions import override from baybe.kernels.basic import MaternKernel from baybe.kernels.composite import ScaleKernel -from baybe.parameters import TaskParameter -from baybe.parameters.enum import SubstanceEncoding -from baybe.parameters.selectors import ( - ParameterSelectorProtocol, - TypeSelector, - to_parameter_selector, -) +from baybe.parameters.enum import SubstanceEncoding, _ParameterKind from baybe.parameters.substance import SubstanceParameter from baybe.priors.basic import GammaPrior from baybe.searchspace.discrete import SubspaceDiscrete @@ -25,6 +19,7 @@ _MLLForNonTLFitCriterionFactory, ) from baybe.surrogates.gaussian_process.components.kernel import ( + _enable_transfer_learning, _PureKernelFactory, ) from baybe.surrogates.gaussian_process.components.likelihood import ( @@ -59,6 +54,7 @@ def _contains_encoding( """Encodings relevant to EDBO logic.""" +@_enable_transfer_learning @define class EDBOKernelFactory(_PureKernelFactory): """A factory providing EDBO kernels, as proposed by :cite:p:`Shields2021`. @@ -70,17 +66,11 @@ class EDBOKernelFactory(_PureKernelFactory): _uses_parameter_names: ClassVar[bool] = True # See base class. - parameter_selector: ParameterSelectorProtocol | None = field( - factory=lambda: TypeSelector([TaskParameter], exclude=True), - converter=to_parameter_selector, - ) - # TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429) - @override def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor ) -> Kernel: - effective_dims = train_x.shape[-1] + effective_dims = self._get_effective_dimensionality(searchspace) switching_condition = _contains_encoding( searchspace.discrete, _EDBO_ENCODINGS @@ -126,8 +116,8 @@ def _make( ) -EDBOMeanFactory = LazyConstantMeanFactory -"""A factory providing mean functions for the EDBO preset.""" +class EDBOMeanFactory(LazyConstantMeanFactory): + """A factory providing mean functions for the EDBO preset.""" @define @@ -145,8 +135,10 @@ def __call__( import torch from gpytorch.likelihoods import GaussianLikelihood - effective_dims = train_x.shape[-1] - len( - [p for p in searchspace.parameters if isinstance(p, TaskParameter)] + effective_dims = len( + searchspace.get_comp_rep_parameter_indices( + lambda p: bool(p._kind & _ParameterKind.REGULAR) + ) ) switching_condition = _contains_encoding( diff --git a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py index 519713f9dc..0ac1c6108e 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py +++ b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py @@ -6,22 +6,18 @@ from typing import TYPE_CHECKING, ClassVar import numpy as np -from attrs import define, field +from attrs import define from typing_extensions import override from baybe.kernels.basic import MaternKernel from baybe.kernels.composite import ScaleKernel -from baybe.parameters import TaskParameter -from baybe.parameters.selectors import ( - ParameterSelectorProtocol, - TypeSelector, - to_parameter_selector, -) +from baybe.parameters.enum import _ParameterKind from baybe.priors.basic import GammaPrior from baybe.surrogates.gaussian_process.components.fit_criterion import ( _MLLForNonTLFitCriterionFactory, ) from baybe.surrogates.gaussian_process.components.kernel import ( + _enable_transfer_learning, _PureKernelFactory, ) from baybe.surrogates.gaussian_process.components.likelihood import ( @@ -42,28 +38,17 @@ @define -class SmoothedEDBOKernelFactory(_PureKernelFactory): - """A factory providing smoothed versions of EDBO kernels (adapted from :cite:p:`Shields2021`). - - Takes the low and high dimensional limits of - :class:`baybe.surrogates.gaussian_process.presets.edbo.EDBOKernelFactory` - and interpolates the prior moments linearly in between. - """ # noqa: E501 +class _SmoothedEDBONumericalKernelFactory(_PureKernelFactory): + """A factory providing the core numerical kernel for the smoothed EDBO preset.""" _uses_parameter_names: ClassVar[bool] = True # See base class. - parameter_selector: ParameterSelectorProtocol | None = field( - factory=lambda: TypeSelector([TaskParameter], exclude=True), - converter=to_parameter_selector, - ) - # TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429) - @override def _make( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor ) -> Kernel: - effective_dims = train_x.shape[-1] + effective_dims = self._get_effective_dimensionality(searchspace) # Interpolate prior moments linearly between low D and high D regime. # The high D regime itself is the average of the EDBO OHE and Mordred regime. @@ -91,8 +76,19 @@ def _make( ) -SmoothedEDBOMeanFactory = LazyConstantMeanFactory -"""A factory providing mean functions for the smoothed EDBO preset.""" +SmoothedEDBOKernelFactory = _enable_transfer_learning( + _SmoothedEDBONumericalKernelFactory, "SmoothedEDBOKernelFactory" +) +"""A factory providing smoothed versions of EDBO kernels (adapted from :cite:p:`Shields2021`). + +Takes the low and high dimensional limits of +:class:`baybe.surrogates.gaussian_process.presets.edbo.EDBOKernelFactory` +and interpolates the prior moments linearly in between. +""" # noqa: E501 + + +class SmoothedEDBOMeanFactory(LazyConstantMeanFactory): + """A factory providing mean functions for the smoothed EDBO preset.""" @define @@ -114,8 +110,10 @@ def __call__( # Interpolate prior moments linearly between low D and high D regime. # The high D regime itself is the average of the EDBO OHE and Mordred regime. # Values outside the dimension limits will get the border value assigned. - effective_dims = train_x.shape[-1] - len( - [p for p in searchspace.parameters if isinstance(p, TaskParameter)] + effective_dims = len( + searchspace.get_comp_rep_parameter_indices( + lambda p: bool(p._kind & _ParameterKind.REGULAR) + ) ) prior = GammaPrior( diff --git a/tests/serialization/test_kernel_factory_serialization.py b/tests/serialization/test_kernel_factory_serialization.py new file mode 100644 index 0000000000..0bdc20af93 --- /dev/null +++ b/tests/serialization/test_kernel_factory_serialization.py @@ -0,0 +1,23 @@ +"""Kernel factory serialization tests.""" + +import pytest + +from baybe.surrogates.gaussian_process.components.kernel import _PureKernelFactory +from baybe.surrogates.gaussian_process.presets import * # noqa: F401, F403 +from baybe.utils.basic import get_subclasses +from tests.serialization.utils import assert_roundtrip_consistency + +_KERNEL_FACTORIES = [ + cls + for cls in get_subclasses(_PureKernelFactory) + if not cls.__name__.startswith("_") +] + + +@pytest.mark.parametrize( + "factory", + [pytest.param(cls(), id=cls.__name__) for cls in _KERNEL_FACTORIES], +) +def test_roundtrip(factory): + """A serialization roundtrip yields an equivalent object.""" + assert_roundtrip_consistency(factory) diff --git a/tests/test_kernel_factories.py b/tests/test_kernel_factories.py index 6a8acda6a6..45f3ccc109 100644 --- a/tests/test_kernel_factories.py +++ b/tests/test_kernel_factories.py @@ -15,8 +15,8 @@ from baybe.searchspace.core import SearchSpace from baybe.surrogates.gaussian_process.presets.baybe import ( BayBEKernelFactory, - BayBENumericalKernelFactory, - BayBETaskKernelFactory, + _BayBENumericalKernelFactory, + _BayBETaskKernelFactory, ) # A selector that accepts all parameters @@ -27,25 +27,25 @@ ("factory", "parameters", "error"), [ param( - BayBENumericalKernelFactory(parameter_selector=_SELECT_ALL), + _BayBENumericalKernelFactory(parameter_selector=_SELECT_ALL), [TaskParameter("task", ["t1", "t2"])], IncompatibleSearchSpaceError, id="regular_rejects_task", ), param( - BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), + _BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), [CategoricalParameter("cat", ["a", "b"])], IncompatibleSearchSpaceError, id="task_rejects_categorical", ), param( - BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), + _BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), [NumericalDiscreteParameter("num", [1, 2, 3])], IncompatibleSearchSpaceError, id="task_rejects_numerical_discrete", ), param( - BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), + _BayBETaskKernelFactory(parameter_selector=_SELECT_ALL), [NumericalContinuousParameter("cont", (0, 1))], IncompatibleSearchSpaceError, id="task_rejects_numerical_continuous",