From b84af1d9c0cff7069ad5a2e103dd5df76915cb82 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 12 Feb 2026 09:07:41 +0100 Subject: [PATCH 01/20] Absorb index kernel construction into ICMKernelFactory --- .../gaussian_process/components/kernel.py | 44 +++++++++++++++++++ baybe/surrogates/gaussian_process/core.py | 9 +--- .../gaussian_process/presets/baybe.py | 44 ++++++++++++++++++- 3 files changed, 87 insertions(+), 10 deletions(-) diff --git a/baybe/surrogates/gaussian_process/components/kernel.py b/baybe/surrogates/gaussian_process/components/kernel.py index c5c4c85e09..42704dfed9 100644 --- a/baybe/surrogates/gaussian_process/components/kernel.py +++ b/baybe/surrogates/gaussian_process/components/kernel.py @@ -4,7 +4,12 @@ from typing import TYPE_CHECKING +from attrs import define, field +from typing_extensions import override + from baybe.kernels.base import Kernel +from baybe.kernels.composite import ProductKernel +from baybe.searchspace.core import SearchSpace from baybe.surrogates.gaussian_process.components.generic import ( GPComponentFactory, PlainGPComponentFactory, @@ -12,6 +17,7 @@ if TYPE_CHECKING: from gpytorch.kernels import Kernel as GPyTorchKernel + from torch import Tensor KernelFactory = GPComponentFactory[Kernel | GPyTorchKernel] PlainKernelFactory = PlainGPComponentFactory[Kernel | GPyTorchKernel] @@ -19,3 +25,41 @@ # At runtime, we use only the BayBE type for serialization compatibility KernelFactory = GPComponentFactory[Kernel] PlainKernelFactory = PlainGPComponentFactory[Kernel] + + +@define +class ICMKernelFactory(KernelFactory): + """A kernel factory that constructs an ICM kernel for transfer learning. 
+ + ICM: Intrinsic model of coregionalization + """ + + base_kernel_factory: KernelFactory = field(alias="base_kernel_or_factory") + """The factory for the base kernel operating on numerical input features.""" + + task_kernel_factory: KernelFactory = field(alias="task_kernel_or_factory") + """The factory for the task kernel operating on the task indices.""" + + @base_kernel_factory.default + def _default_base_kernel_factory(self) -> KernelFactory: + from baybe.surrogates.gaussian_process.presets.baybe import ( + BayBENumericalKernelFactory, + ) + + return BayBENumericalKernelFactory() + + @task_kernel_factory.default + def _default_task_kernel_factory(self) -> KernelFactory: + from baybe.surrogates.gaussian_process.presets.baybe import ( + BayBETaskKernelFactory, + ) + + return BayBETaskKernelFactory() + + @override + def __call__( + self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor + ) -> Kernel: + base_kernel = self.base_kernel_factory(searchspace, train_x, train_y) + task_kernel = self.task_kernel_factory(searchspace, train_x, train_y) + return ProductKernel([base_kernel, task_kernel]) diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index bb2316d95d..d6cd81e959 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -12,7 +12,6 @@ from typing_extensions import Self, override from baybe.kernels.base import Kernel -from baybe.kernels.basic import IndexKernel from baybe.parameters.base import Parameter from baybe.searchspace.core import SearchSpace from baybe.surrogates.base import Surrogate @@ -226,19 +225,13 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None: ### Kernel kernel = self.kernel_factory(context.searchspace, train_x, train_y) + raise NotImplementedError("The active dimensions are not yet implemented!") if isinstance(kernel, Kernel): kernel_num_dims = train_x.shape[-1] - context.n_task_dimensions kernel = kernel.to_gpytorch( 
ard_num_dims=kernel_num_dims, active_dims=context.numerical_indices, ) - if context.is_multitask: - assert context.task_idx is not None - task_kernel = IndexKernel( - num_tasks=context.n_tasks, - rank=context.n_tasks, # TODO: make controllable - ).to_gpytorch(active_dims=[context.task_idx]) - kernel = kernel * task_kernel ### Likelihood likelihood = self.likelihood_factory(context.searchspace, train_x, train_y) diff --git a/baybe/surrogates/gaussian_process/presets/baybe.py b/baybe/surrogates/gaussian_process/presets/baybe.py index a0867b45d7..bd7d86f880 100644 --- a/baybe/surrogates/gaussian_process/presets/baybe.py +++ b/baybe/surrogates/gaussian_process/presets/baybe.py @@ -2,14 +2,54 @@ from __future__ import annotations +from typing import TYPE_CHECKING + +from attrs import define +from typing_extensions import override + +from baybe.kernels.base import Kernel +from baybe.kernels.basic import IndexKernel +from baybe.searchspace.core import SearchSpace +from baybe.surrogates.gaussian_process.components.kernel import KernelFactory from baybe.surrogates.gaussian_process.components.mean import LazyConstantMeanFactory from baybe.surrogates.gaussian_process.presets.edbo_smoothed import ( SmoothedEDBOKernelFactory, SmoothedEDBOLikelihoodFactory, ) -BayBEKernelFactory = SmoothedEDBOKernelFactory -"""The factory providing the default kernel for Gaussian process surrogates.""" +if TYPE_CHECKING: + from torch import Tensor + + +@define +class BayBEKernelFactory(KernelFactory): + """The default kernel factory for Gaussian process surrogates.""" + + @override + def __call__( + self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor + ) -> Kernel: + from baybe.surrogates.gaussian_process.components.kernel import ICMKernelFactory + + is_multitask = searchspace.n_tasks > 0 + factory = ICMKernelFactory if is_multitask else BayBENumericalKernelFactory + return factory()(searchspace, train_x, train_y) + + +BayBENumericalKernelFactory = SmoothedEDBOKernelFactory +"""The 
factory providing the default numerical kernel for Gaussian process surrogates.""" # noqa: E501
+
+
+@define
+class BayBETaskKernelFactory(KernelFactory):
+    """The factory providing the default task kernel for Gaussian process surrogates."""
+
+    @override
+    def __call__(
+        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
+    ) -> Kernel:
+        return IndexKernel(num_tasks=searchspace.n_tasks, rank=searchspace.n_tasks)
+
 
 BayBEMeanFactory = LazyConstantMeanFactory
 """The factory providing the default mean function for Gaussian process surrogates."""

From e365862de0d6ad9d8bd5d6d2e56b9d42f4bffd9a Mon Sep 17 00:00:00 2001
From: AdrianSosic
Date: Thu, 12 Feb 2026 12:54:44 +0100
Subject: [PATCH 02/20] Add parameter selectors

---
 baybe/parameters/selector.py | 47 ++++++++++++++++++++++++++++++++++++
 1 file changed, 47 insertions(+)
 create mode 100644 baybe/parameters/selector.py

diff --git a/baybe/parameters/selector.py b/baybe/parameters/selector.py
new file mode 100644
index 0000000000..4de9e89856
--- /dev/null
+++ b/baybe/parameters/selector.py
@@ -0,0 +1,47 @@
+"""Parameter selectors."""
+
+from abc import abstractmethod
+from typing import Protocol
+
+from attrs import define, field
+from attrs.validators import instance_of
+from typing_extensions import override
+
+from baybe.parameters.base import Parameter
+
+
+class ParameterSelectorProtocol(Protocol):
+    """Type protocol specifying the interface parameter selectors need to implement."""
+
+    def __call__(self, parameter: Parameter) -> bool:
+        """Determine if a parameter should be included in the selection."""
+
+
+@define
+class ParameterSelector(ParameterSelectorProtocol):
+    """Base class for parameter selectors."""
+
+    exclude: bool = field(default=False, validator=instance_of(bool), kw_only=True)
+    """Boolean flag indicating whether to invert the selection criterion."""
+
+    @abstractmethod
+    def _is_match(self, parameter: Parameter) -> bool:
+        """Determine if a parameter meets the selection criterion."""
+ + @override + def __call__(self, parameter: Parameter) -> bool: + """Determine if a parameter should be included in the selection.""" + result = self._is_match(parameter) + return not result if self.exclude else result + + +@define +class TypeSelector(ParameterSelector): + """Select parameters by type.""" + + parameter_types: tuple[type[Parameter], ...] = field(converter=tuple) + """The parameter types to be selected.""" + + @override + def _is_match(self, parameter: Parameter) -> bool: + return isinstance(parameter, self.parameter_types) From dc9ee907f1ca97914eefbed450e6214684b0964d Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 12 Feb 2026 13:00:36 +0100 Subject: [PATCH 03/20] Rename protocols --- .../gaussian_process/components/__init__.py | 12 ++++----- .../gaussian_process/components/generic.py | 8 +++--- .../gaussian_process/components/kernel.py | 16 ++++++------ .../gaussian_process/components/likelihood.py | 6 ++--- .../gaussian_process/components/mean.py | 8 +++--- baybe/surrogates/gaussian_process/core.py | 25 ++++++++++++------- .../gaussian_process/presets/baybe.py | 6 ++--- .../gaussian_process/presets/edbo.py | 12 ++++++--- .../gaussian_process/presets/edbo_smoothed.py | 12 ++++++--- 9 files changed, 60 insertions(+), 45 deletions(-) diff --git a/baybe/surrogates/gaussian_process/components/__init__.py b/baybe/surrogates/gaussian_process/components/__init__.py index 1a7131b825..a9e11f4afe 100644 --- a/baybe/surrogates/gaussian_process/components/__init__.py +++ b/baybe/surrogates/gaussian_process/components/__init__.py @@ -1,28 +1,28 @@ """Gaussian process surrogate components.""" from baybe.surrogates.gaussian_process.components.kernel import ( - KernelFactory, + KernelFactoryProtocol, PlainKernelFactory, ) from baybe.surrogates.gaussian_process.components.likelihood import ( - LikelihoodFactory, + LikelihoodFactoryProtocol, PlainLikelihoodFactory, ) from baybe.surrogates.gaussian_process.components.mean import ( LazyConstantMeanFactory, - 
MeanFactory, + MeanFactoryProtocol, PlainMeanFactory, ) __all__ = [ # Kernel - "KernelFactory", + "KernelFactoryProtocol", "PlainKernelFactory", # Likelihood - "LikelihoodFactory", + "LikelihoodFactoryProtocol", "PlainLikelihoodFactory", # Mean "LazyConstantMeanFactory", - "MeanFactory", + "MeanFactoryProtocol", "PlainMeanFactory", ] diff --git a/baybe/surrogates/gaussian_process/components/generic.py b/baybe/surrogates/gaussian_process/components/generic.py index 1f4a171e1d..c1977146e5 100644 --- a/baybe/surrogates/gaussian_process/components/generic.py +++ b/baybe/surrogates/gaussian_process/components/generic.py @@ -95,7 +95,7 @@ def _validate_component(instance: Any, attribute: Attribute, value: Any) -> None ) -class GPComponentFactory(Protocol, Generic[_T_co]): +class GPComponentFactoryProtocol(Protocol, Generic[_T_co]): """A protocol defining the interface expected for GP component factories.""" def __call__( @@ -105,7 +105,7 @@ def __call__( @define(frozen=True) -class PlainGPComponentFactory(GPComponentFactory[_T_co], SerialMixin): +class PlainGPComponentFactory(GPComponentFactoryProtocol[_T_co], SerialMixin): """A trivial factory that returns a fixed pre-defined component upon request.""" component: _T_co = field(validator=_validate_component) @@ -119,11 +119,11 @@ def __call__( def to_component_factory( - obj: GPComponent | GPComponentFactory, + obj: GPComponent | GPComponentFactoryProtocol, /, *, component_type: GPComponentType | None = None, -) -> GPComponentFactory: +) -> GPComponentFactoryProtocol: """Wrap a component into a plain component factory (with factory passthrough). 
Args: diff --git a/baybe/surrogates/gaussian_process/components/kernel.py b/baybe/surrogates/gaussian_process/components/kernel.py index 42704dfed9..0fa7cd490f 100644 --- a/baybe/surrogates/gaussian_process/components/kernel.py +++ b/baybe/surrogates/gaussian_process/components/kernel.py @@ -11,7 +11,7 @@ from baybe.kernels.composite import ProductKernel from baybe.searchspace.core import SearchSpace from baybe.surrogates.gaussian_process.components.generic import ( - GPComponentFactory, + GPComponentFactoryProtocol, PlainGPComponentFactory, ) @@ -19,29 +19,29 @@ from gpytorch.kernels import Kernel as GPyTorchKernel from torch import Tensor - KernelFactory = GPComponentFactory[Kernel | GPyTorchKernel] + KernelFactoryProtocol = GPComponentFactoryProtocol[Kernel | GPyTorchKernel] PlainKernelFactory = PlainGPComponentFactory[Kernel | GPyTorchKernel] else: # At runtime, we use only the BayBE type for serialization compatibility - KernelFactory = GPComponentFactory[Kernel] + KernelFactoryProtocol = GPComponentFactoryProtocol[Kernel] PlainKernelFactory = PlainGPComponentFactory[Kernel] @define -class ICMKernelFactory(KernelFactory): +class ICMKernelFactory(KernelFactoryProtocol): """A kernel factory that constructs an ICM kernel for transfer learning. 
ICM: Intrinsic model of coregionalization """ - base_kernel_factory: KernelFactory = field(alias="base_kernel_or_factory") + base_kernel_factory: KernelFactoryProtocol = field(alias="base_kernel_or_factory") """The factory for the base kernel operating on numerical input features.""" - task_kernel_factory: KernelFactory = field(alias="task_kernel_or_factory") + task_kernel_factory: KernelFactoryProtocol = field(alias="task_kernel_or_factory") """The factory for the task kernel operating on the task indices.""" @base_kernel_factory.default - def _default_base_kernel_factory(self) -> KernelFactory: + def _default_base_kernel_factory(self) -> KernelFactoryProtocol: from baybe.surrogates.gaussian_process.presets.baybe import ( BayBENumericalKernelFactory, ) @@ -49,7 +49,7 @@ def _default_base_kernel_factory(self) -> KernelFactory: return BayBENumericalKernelFactory() @task_kernel_factory.default - def _default_task_kernel_factory(self) -> KernelFactory: + def _default_task_kernel_factory(self) -> KernelFactoryProtocol: from baybe.surrogates.gaussian_process.presets.baybe import ( BayBETaskKernelFactory, ) diff --git a/baybe/surrogates/gaussian_process/components/likelihood.py b/baybe/surrogates/gaussian_process/components/likelihood.py index 29ae203f19..2f8007ed59 100644 --- a/baybe/surrogates/gaussian_process/components/likelihood.py +++ b/baybe/surrogates/gaussian_process/components/likelihood.py @@ -5,16 +5,16 @@ from typing import TYPE_CHECKING, Any from baybe.surrogates.gaussian_process.components.generic import ( - GPComponentFactory, + GPComponentFactoryProtocol, PlainGPComponentFactory, ) if TYPE_CHECKING: from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood - LikelihoodFactory = GPComponentFactory[GPyTorchLikelihood] + LikelihoodFactoryProtocol = GPComponentFactoryProtocol[GPyTorchLikelihood] PlainLikelihoodFactory = PlainGPComponentFactory[GPyTorchLikelihood] else: # At runtime, we avoid loading GPyTorch eagerly for performance reasons - 
LikelihoodFactory = GPComponentFactory[Any] + LikelihoodFactoryProtocol = GPComponentFactoryProtocol[Any] PlainLikelihoodFactory = PlainGPComponentFactory[Any] diff --git a/baybe/surrogates/gaussian_process/components/mean.py b/baybe/surrogates/gaussian_process/components/mean.py index c0db8b1d7b..e4b7c70cf3 100644 --- a/baybe/surrogates/gaussian_process/components/mean.py +++ b/baybe/surrogates/gaussian_process/components/mean.py @@ -9,7 +9,7 @@ from baybe.searchspace.core import SearchSpace from baybe.surrogates.gaussian_process.components.generic import ( - GPComponentFactory, + GPComponentFactoryProtocol, PlainGPComponentFactory, ) @@ -17,16 +17,16 @@ from gpytorch.means import Mean as GPyTorchMean from torch import Tensor - MeanFactory = GPComponentFactory[GPyTorchMean] + MeanFactoryProtocol = GPComponentFactoryProtocol[GPyTorchMean] PlainMeanFactory = PlainGPComponentFactory[GPyTorchMean] else: # At runtime, we avoid loading GPyTorch eagerly for performance reasons - MeanFactory = GPComponentFactory[Any] + MeanFactoryProtocol = GPComponentFactoryProtocol[Any] PlainMeanFactory = PlainGPComponentFactory[Any] @define -class LazyConstantMeanFactory(MeanFactory): +class LazyConstantMeanFactory(MeanFactoryProtocol): """A factory providing constant mean functions using lazy loading.""" @override diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index d6cd81e959..156b412dbb 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -20,10 +20,12 @@ to_component_factory, ) from baybe.surrogates.gaussian_process.components.kernel import ( - KernelFactory, + KernelFactoryProtocol, ) -from baybe.surrogates.gaussian_process.components.likelihood import LikelihoodFactory -from baybe.surrogates.gaussian_process.components.mean import MeanFactory +from baybe.surrogates.gaussian_process.components.likelihood import ( + LikelihoodFactoryProtocol, +) +from 
baybe.surrogates.gaussian_process.components.mean import MeanFactoryProtocol from baybe.surrogates.gaussian_process.presets import ( GaussianProcessPreset, ) @@ -112,7 +114,7 @@ class GaussianProcessSurrogate(Surrogate): supports_transfer_learning: ClassVar[bool] = True # See base class. - kernel_factory: KernelFactory = field( + kernel_factory: KernelFactoryProtocol = field( alias="kernel_or_factory", factory=BayBEKernelFactory, converter=partial(to_component_factory, component_type=GPComponentType.KERNEL), # type: ignore[misc] @@ -125,7 +127,7 @@ class GaussianProcessSurrogate(Surrogate): * :class:`gpytorch.kernels.Kernel` """ - mean_factory: MeanFactory = field( + mean_factory: MeanFactoryProtocol = field( alias="mean_or_factory", factory=BayBEMeanFactory, converter=partial(to_component_factory, component_type=GPComponentType.MEAN), # type: ignore[misc] @@ -137,7 +139,7 @@ class GaussianProcessSurrogate(Surrogate): * :class:`gpytorch.means.Mean` """ - likelihood_factory: LikelihoodFactory = field( + likelihood_factory: LikelihoodFactoryProtocol = field( alias="likelihood_or_factory", factory=BayBELikelihoodFactory, converter=partial( # type: ignore[misc] @@ -160,9 +162,14 @@ class GaussianProcessSurrogate(Surrogate): def from_preset( cls, preset: GaussianProcessPreset | str, - kernel_or_factory: KernelFactory | Kernel | GPyTorchKernel | None = None, - mean_or_factory: MeanFactory | GPyTorchMean | None = None, - likelihood_or_factory: LikelihoodFactory | GPyTorchLikelihood | None = None, + kernel_or_factory: KernelFactoryProtocol + | Kernel + | GPyTorchKernel + | None = None, + mean_or_factory: MeanFactoryProtocol | GPyTorchMean | None = None, + likelihood_or_factory: LikelihoodFactoryProtocol + | GPyTorchLikelihood + | None = None, ) -> Self: """Create a Gaussian process surrogate from one of the defined presets.""" preset = GaussianProcessPreset(preset) diff --git a/baybe/surrogates/gaussian_process/presets/baybe.py 
b/baybe/surrogates/gaussian_process/presets/baybe.py index bd7d86f880..4ea36d6113 100644 --- a/baybe/surrogates/gaussian_process/presets/baybe.py +++ b/baybe/surrogates/gaussian_process/presets/baybe.py @@ -10,7 +10,7 @@ from baybe.kernels.base import Kernel from baybe.kernels.basic import IndexKernel from baybe.searchspace.core import SearchSpace -from baybe.surrogates.gaussian_process.components.kernel import KernelFactory +from baybe.surrogates.gaussian_process.components.kernel import KernelFactoryProtocol from baybe.surrogates.gaussian_process.components.mean import LazyConstantMeanFactory from baybe.surrogates.gaussian_process.presets.edbo_smoothed import ( SmoothedEDBOKernelFactory, @@ -22,7 +22,7 @@ @define -class BayBEKernelFactory(KernelFactory): +class BayBEKernelFactory(KernelFactoryProtocol): """The default kernel factory for Gaussian process surrogates.""" @override @@ -41,7 +41,7 @@ def __call__( @define -class BayBETaskKernelFactory(KernelFactory): +class BayBETaskKernelFactory(KernelFactoryProtocol): """The factory providing the default task kernel for Gaussian process surrogates.""" @override diff --git a/baybe/surrogates/gaussian_process/presets/edbo.py b/baybe/surrogates/gaussian_process/presets/edbo.py index 2e1321e208..48e4ced508 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo.py +++ b/baybe/surrogates/gaussian_process/presets/edbo.py @@ -16,8 +16,12 @@ from baybe.parameters.substance import SubstanceParameter from baybe.priors.basic import GammaPrior from baybe.searchspace.discrete import SubspaceDiscrete -from baybe.surrogates.gaussian_process.components.kernel import KernelFactory -from baybe.surrogates.gaussian_process.components.likelihood import LikelihoodFactory +from baybe.surrogates.gaussian_process.components.kernel import ( + KernelFactoryProtocol, +) +from baybe.surrogates.gaussian_process.components.likelihood import ( + LikelihoodFactoryProtocol, +) from baybe.surrogates.gaussian_process.components.mean import 
LazyConstantMeanFactory if TYPE_CHECKING: @@ -48,7 +52,7 @@ def _contains_encoding( @define -class EDBOKernelFactory(KernelFactory): +class EDBOKernelFactory(KernelFactoryProtocol): """A factory providing EDBO kernels. References: @@ -112,7 +116,7 @@ def __call__( @define -class EDBOLikelihoodFactory(LikelihoodFactory): +class EDBOLikelihoodFactory(LikelihoodFactoryProtocol): """A factory providing EDBO likelihoods. References: diff --git a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py index b7027028d7..b70d546633 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py +++ b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py @@ -13,8 +13,12 @@ from baybe.kernels.composite import ScaleKernel from baybe.parameters import TaskParameter from baybe.priors.basic import GammaPrior -from baybe.surrogates.gaussian_process.components.kernel import KernelFactory -from baybe.surrogates.gaussian_process.components.likelihood import LikelihoodFactory +from baybe.surrogates.gaussian_process.components.kernel import ( + KernelFactoryProtocol, +) +from baybe.surrogates.gaussian_process.components.likelihood import ( + LikelihoodFactoryProtocol, +) from baybe.surrogates.gaussian_process.components.mean import LazyConstantMeanFactory if TYPE_CHECKING: @@ -30,7 +34,7 @@ @define -class SmoothedEDBOKernelFactory(KernelFactory): +class SmoothedEDBOKernelFactory(KernelFactoryProtocol): """A factory providing smoothed versions of EDBO kernels. Takes the low and high dimensional limits of @@ -76,7 +80,7 @@ def __call__( @define -class SmoothedEDBOLikelihoodFactory(LikelihoodFactory): +class SmoothedEDBOLikelihoodFactory(LikelihoodFactoryProtocol): """A factory providing smoothed versions of EDBO likelihoods. 
Takes the low and high dimensional limits of From 63e64fa6a925022f3a2b4b92852e3736b1ea1988 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 12 Feb 2026 13:13:35 +0100 Subject: [PATCH 04/20] Implement active kernel dimension control --- baybe/kernels/base.py | 53 +++++++++++++++---- .../gaussian_process/components/kernel.py | 39 ++++++++++++-- baybe/surrogates/gaussian_process/core.py | 7 +-- .../gaussian_process/presets/baybe.py | 18 +++++-- .../gaussian_process/presets/edbo.py | 12 +++-- .../gaussian_process/presets/edbo_smoothed.py | 12 +++-- 6 files changed, 108 insertions(+), 33 deletions(-) diff --git a/baybe/kernels/base.py b/baybe/kernels/base.py index 8ccc7c75fc..7df1159b31 100644 --- a/baybe/kernels/base.py +++ b/baybe/kernels/base.py @@ -4,13 +4,17 @@ import gc from abc import ABC -from collections.abc import Sequence +from itertools import chain from typing import TYPE_CHECKING, Any -from attrs import define +from attrs import define, field, fields +from attrs.converters import optional as optional_c +from attrs.validators import deep_iterable, instance_of +from attrs.validators import optional as optional_v from baybe.exceptions import UnmatchedAttributeError from baybe.priors.base import Prior +from baybe.searchspace.core import SearchSpace from baybe.serialization.mixin import SerialMixin from baybe.settings import active_settings from baybe.utils.basic import get_baseclasses, match_attributes @@ -25,6 +29,14 @@ class Kernel(ABC, SerialMixin): """Abstract base class for all kernels.""" + parameter_names: tuple[str, ...] 
| None = field(
+        default=None,
+        converter=optional_c(tuple),
+        validator=optional_v(deep_iterable(member_validator=instance_of(str))),
+        kw_only=True,
+    )
+    """An optional set of names specifying the parameters the kernel should act on."""
+
     def to_factory(self) -> PlainKernelFactory:
         """Wrap the kernel in a :class:`baybe.surrogates.gaussian_process.components.PlainKernelFactory`.""" # noqa: E501
         from baybe.surrogates.gaussian_process.components.kernel import (
@@ -35,21 +47,41 @@ def to_factory(self) -> PlainKernelFactory:
 
     def to_gpytorch(
         self,
+        searchspace: SearchSpace,
         *,
-        ard_num_dims: int | None = None,
         batch_shape: torch.Size | None = None,
-        active_dims: Sequence[int] | None = None,
     ):
         """Create the gpytorch representation of the kernel."""
         import gpytorch.kernels
 
+        # Extract the active dimensions for the gpytorch kernel
+        if self.parameter_names is not None:
+            active_dims = list(
+                chain(
+                    *[
+                        searchspace.get_comp_rep_parameter_indices(name)
+                        for name in self.parameter_names
+                    ]
+                )
+            )
+        else:
+            active_dims = None
+
+        # We use automatic relevance determination for all (non-composite) kernels
+        if isinstance(self, CompositeKernel):
+            ard_num_dims = None
+        else:
+            ard_num_dims = (
+                len(active_dims)
+                if active_dims is not None
+                else len(searchspace.comp_rep_columns)
+            )
+
         # Extract keywords with non-default values. This is required since gpytorch
         # makes use of kwargs, i.e. differentiates if certain keywords are explicitly
         # passed or not. For instance, `ard_num_dims = kwargs.get("ard_num_dims", 1)`
         # fails if we explicitly pass `ard_num_dims=None`.
- kw: dict[str, Any] = dict( - ard_num_dims=ard_num_dims, batch_shape=batch_shape, active_dims=active_dims - ) + kw: dict[str, Any] = dict(batch_shape=batch_shape, active_dims=active_dims) kw = {k: v for k, v in kw.items() if v is not None} # Get corresponding gpytorch kernel class and its base classes @@ -79,7 +111,8 @@ def to_gpytorch( # in the gpytorch kernel (otherwise, the BayBE kernel class is misconfigured). # Exception: initial values are not used during construction but are set # on the created object (see code at the end of the method). - missing = set(unmatched) - set(kernel_attrs) + exclude = {fields(Kernel).parameter_names.name} + missing = set(unmatched) - set(kernel_attrs) - exclude if leftover := {m for m in missing if not m.endswith("_initial_value")}: raise UnmatchedAttributeError(leftover) @@ -92,7 +125,7 @@ def to_gpytorch( # Convert specified inner kernels to gpytorch, if provided kernel_dict = { - key: value.to_gpytorch(**kw) + key: value.to_gpytorch(searchspace, **kw) for key, value in kernel_attrs.items() if isinstance(value, Kernel) } @@ -100,7 +133,7 @@ def to_gpytorch( # Create the kernel with all its inner gpytorch objects kernel_attrs.update(kernel_dict) kernel_attrs.update(prior_dict) - gpytorch_kernel = kernel_cls(**kernel_attrs, **kw) + gpytorch_kernel = kernel_cls(**kernel_attrs, ard_num_dims=ard_num_dims, **kw) # If the kernel has a lengthscale, set its initial value if kernel_cls.has_lengthscale: diff --git a/baybe/surrogates/gaussian_process/components/kernel.py b/baybe/surrogates/gaussian_process/components/kernel.py index 0fa7cd490f..337459f2aa 100644 --- a/baybe/surrogates/gaussian_process/components/kernel.py +++ b/baybe/surrogates/gaussian_process/components/kernel.py @@ -2,13 +2,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from attrs import define, field from typing_extensions import override from baybe.kernels.base import Kernel from 
baybe.kernels.composite import ProductKernel +from baybe.parameters.categorical import TaskParameter +from baybe.parameters.selector import ( + ParameterSelectorProtocol, + TypeSelector, +) from baybe.searchspace.core import SearchSpace from baybe.surrogates.gaussian_process.components.generic import ( GPComponentFactoryProtocol, @@ -27,6 +32,34 @@ PlainKernelFactory = PlainGPComponentFactory[Kernel] +@define +class KernelFactory(KernelFactoryProtocol): + """Base class for kernel factories.""" + + # For internal use only: sanity check mechanism to remind developers of new + # factories to actually use the parameter selector when it is provided + # TODO: Perhaps we can find a more elegant way to enforce this by design + _uses_parameter_names: ClassVar[bool] = False + + parameter_selector: ParameterSelectorProtocol | None = field(default=None) + """An optional selector to specify which parameters are considered by the kernel.""" + + def get_parameter_names(self, searchspace: SearchSpace) -> tuple[str, ...] | None: + """Get the names of the parameters to be considered by the kernel.""" + if self.parameter_selector is None: + return None + + return tuple( + p.name for p in searchspace.parameters if self.parameter_selector(p) + ) + + def __attrs_post_init__(self): + # This helps to ensure that new factories actually use the parameter selector + # by requiring the developer to explicitly set the flag to `True` + if self.parameter_selector is not None: + assert self._uses_parameter_names + + @define class ICMKernelFactory(KernelFactoryProtocol): """A kernel factory that constructs an ICM kernel for transfer learning. 
@@ -46,7 +79,7 @@ def _default_base_kernel_factory(self) -> KernelFactoryProtocol: BayBENumericalKernelFactory, ) - return BayBENumericalKernelFactory() + return BayBENumericalKernelFactory(TypeSelector((TaskParameter,), exclude=True)) @task_kernel_factory.default def _default_task_kernel_factory(self) -> KernelFactoryProtocol: @@ -54,7 +87,7 @@ def _default_task_kernel_factory(self) -> KernelFactoryProtocol: BayBETaskKernelFactory, ) - return BayBETaskKernelFactory() + return BayBETaskKernelFactory(TypeSelector((TaskParameter,))) @override def __call__( diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index 156b412dbb..79664d74aa 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -232,13 +232,8 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None: ### Kernel kernel = self.kernel_factory(context.searchspace, train_x, train_y) - raise NotImplementedError("The active dimensions are not yet implemented!") if isinstance(kernel, Kernel): - kernel_num_dims = train_x.shape[-1] - context.n_task_dimensions - kernel = kernel.to_gpytorch( - ard_num_dims=kernel_num_dims, - active_dims=context.numerical_indices, - ) + kernel = kernel.to_gpytorch(searchspace=context.searchspace) ### Likelihood likelihood = self.likelihood_factory(context.searchspace, train_x, train_y) diff --git a/baybe/surrogates/gaussian_process/presets/baybe.py b/baybe/surrogates/gaussian_process/presets/baybe.py index 4ea36d6113..41b2350c38 100644 --- a/baybe/surrogates/gaussian_process/presets/baybe.py +++ b/baybe/surrogates/gaussian_process/presets/baybe.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from attrs import define from typing_extensions import override @@ -10,7 +10,10 @@ from baybe.kernels.base import Kernel from baybe.kernels.basic import IndexKernel from baybe.searchspace.core import SearchSpace 
-from baybe.surrogates.gaussian_process.components.kernel import KernelFactoryProtocol +from baybe.surrogates.gaussian_process.components.kernel import ( + KernelFactory, + KernelFactoryProtocol, +) from baybe.surrogates.gaussian_process.components.mean import LazyConstantMeanFactory from baybe.surrogates.gaussian_process.presets.edbo_smoothed import ( SmoothedEDBOKernelFactory, @@ -41,14 +44,21 @@ def __call__( @define -class BayBETaskKernelFactory(KernelFactoryProtocol): +class BayBETaskKernelFactory(KernelFactory): """The factory providing the default task kernel for Gaussian process surrogates.""" + _uses_parameter_names: ClassVar[bool] = True + # See base class. + @override def __call__( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor ) -> Kernel: - return IndexKernel(num_tasks=searchspace.n_tasks, rank=searchspace.n_tasks) + return IndexKernel( + num_tasks=searchspace.n_tasks, + rank=searchspace.n_tasks, + parameter_names=self.get_parameter_names(searchspace), + ) BayBEMeanFactory = LazyConstantMeanFactory diff --git a/baybe/surrogates/gaussian_process/presets/edbo.py b/baybe/surrogates/gaussian_process/presets/edbo.py index 48e4ced508..b248e8e40c 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo.py +++ b/baybe/surrogates/gaussian_process/presets/edbo.py @@ -4,7 +4,7 @@ import gc from collections.abc import Collection -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from attrs import define from typing_extensions import override @@ -16,9 +16,7 @@ from baybe.parameters.substance import SubstanceParameter from baybe.priors.basic import GammaPrior from baybe.searchspace.discrete import SubspaceDiscrete -from baybe.surrogates.gaussian_process.components.kernel import ( - KernelFactoryProtocol, -) +from baybe.surrogates.gaussian_process.components.kernel import KernelFactory from baybe.surrogates.gaussian_process.components.likelihood import ( LikelihoodFactoryProtocol, ) @@ -52,7 +50,7 @@ def 
_contains_encoding( @define -class EDBOKernelFactory(KernelFactoryProtocol): +class EDBOKernelFactory(KernelFactory): """A factory providing EDBO kernels. References: @@ -60,6 +58,9 @@ class EDBOKernelFactory(KernelFactoryProtocol): * https://doi.org/10.1038/s41586-021-03213-y """ + _uses_parameter_names: ClassVar[bool] = True + # See base class. + @override def __call__( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor @@ -105,6 +106,7 @@ def __call__( nu=2.5, lengthscale_prior=lengthscale_prior, lengthscale_initial_value=lengthscale_initial_value, + parameter_names=self.get_parameter_names(searchspace), ), outputscale_prior=outputscale_prior, outputscale_initial_value=outputscale_initial_value, diff --git a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py index b70d546633..3afdec2c8a 100644 --- a/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py +++ b/baybe/surrogates/gaussian_process/presets/edbo_smoothed.py @@ -3,7 +3,7 @@ from __future__ import annotations import gc -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar import numpy as np from attrs import define @@ -13,9 +13,7 @@ from baybe.kernels.composite import ScaleKernel from baybe.parameters import TaskParameter from baybe.priors.basic import GammaPrior -from baybe.surrogates.gaussian_process.components.kernel import ( - KernelFactoryProtocol, -) +from baybe.surrogates.gaussian_process.components.kernel import KernelFactory from baybe.surrogates.gaussian_process.components.likelihood import ( LikelihoodFactoryProtocol, ) @@ -34,7 +32,7 @@ @define -class SmoothedEDBOKernelFactory(KernelFactoryProtocol): +class SmoothedEDBOKernelFactory(KernelFactory): """A factory providing smoothed versions of EDBO kernels. Takes the low and high dimensional limits of @@ -42,6 +40,9 @@ class SmoothedEDBOKernelFactory(KernelFactoryProtocol): and interpolates the prior moments linearly in between. 
""" + _uses_parameter_names: ClassVar[bool] = True + # See base class. + @override def __call__( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor @@ -69,6 +70,7 @@ def __call__( nu=2.5, lengthscale_prior=lengthscale_prior, lengthscale_initial_value=lengthscale_initial_value, + parameter_names=self.get_parameter_names(searchspace), ), outputscale_prior=outputscale_prior, outputscale_initial_value=outputscale_initial_value, From da8ffc0754f274d1c06a35e0c3645b1cd19c8eae Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 12 Feb 2026 13:47:55 +0100 Subject: [PATCH 05/20] Move parameter_names attribute down to BasicKernel subclass --- baybe/kernels/base.py | 85 ++++++++++++++++++++++++++----------------- 1 file changed, 51 insertions(+), 34 deletions(-) diff --git a/baybe/kernels/base.py b/baybe/kernels/base.py index 7df1159b31..c29abea813 100644 --- a/baybe/kernels/base.py +++ b/baybe/kernels/base.py @@ -3,21 +3,22 @@ from __future__ import annotations import gc -from abc import ABC +from abc import ABC, abstractmethod from itertools import chain from typing import TYPE_CHECKING, Any -from attrs import define, field, fields +from attrs import define, field from attrs.converters import optional as optional_c from attrs.validators import deep_iterable, instance_of from attrs.validators import optional as optional_v +from typing_extensions import override from baybe.exceptions import UnmatchedAttributeError from baybe.priors.base import Prior from baybe.searchspace.core import SearchSpace from baybe.serialization.mixin import SerialMixin from baybe.settings import active_settings -from baybe.utils.basic import get_baseclasses, match_attributes +from baybe.utils.basic import classproperty, get_baseclasses, match_attributes if TYPE_CHECKING: import torch @@ -29,13 +30,10 @@ class Kernel(ABC, SerialMixin): """Abstract base class for all kernels.""" - parameter_names: tuple[str, ...] 
| None = field( - default=None, - converter=optional_c(tuple), - validator=optional_v(deep_iterable(member_validator=instance_of(str))), - kw_only=True, - ) - """An optional set of names specifiying the parameters the kernel should act on.""" + @classproperty + def _whitelisted_attributes(cls) -> frozenset[str]: + """Attribute names to exclude from gpytorch matching.""" + return frozenset() def to_factory(self) -> PlainKernelFactory: """Wrap the kernel in a :class:`baybe.surrogates.gaussian_process.components.PlainKernelFactory`.""" # noqa: E501 @@ -45,6 +43,10 @@ def to_factory(self) -> PlainKernelFactory: return PlainKernelFactory(self) + @abstractmethod + def _get_dimensions(self, searchspace: SearchSpace) -> tuple[tuple[int, ...], int]: + """Get the active dimensions and the number of ARD dimensions.""" + def to_gpytorch( self, searchspace: SearchSpace, @@ -54,28 +56,7 @@ def to_gpytorch( """Create the gpytorch representation of the kernel.""" import gpytorch.kernels - # Extract the active dimensions for the gpytorch kernel - if self.parameter_names is not None: - active_dims = list( - chain( - *[ - searchspace.get_comp_rep_parameter_indices(name) - for name in self.parameter_names - ] - ) - ) - else: - active_dims = None - - # We use automatic relevance determination for all (non-composite) kernels - if isinstance(self, CompositeKernel): - ard_num_dims = None - else: - ard_num_dims = ( - len(active_dims) - if active_dims is not None - else len(searchspace.comp_rep_columns) - ) + active_dims, ard_num_dims = self._get_dimensions(searchspace) # Extract keywords with non-default values. This is required since gpytorch # makes use of kwargs, i.e. differentiates if certain keywords are explicitly @@ -111,8 +92,7 @@ def to_gpytorch( # in the gpytorch kernel (otherwise, the BayBE kernel class is misconfigured). # Exception: initial values are not used during construction but are set # on the created object (see code at the end of the method). 
- exclude = {fields(Kernel).parameter_names.name} - missing = set(unmatched) - set(kernel_attrs) - exclude + missing = set(unmatched) - set(kernel_attrs) - self._whitelisted_attributes if leftover := {m for m in missing if not m.endswith("_initial_value")}: raise UnmatchedAttributeError(leftover) @@ -156,11 +136,48 @@ def to_gpytorch( class BasicKernel(Kernel, ABC): """Abstract base class for all basic kernels.""" + parameter_names: tuple[str, ...] | None = field( + default=None, + converter=optional_c(tuple), + validator=optional_v(deep_iterable(member_validator=instance_of(str))), + kw_only=True, + ) + """An optional set of names specifiying the parameters the kernel should act on.""" + + @override + @classproperty + def _whitelisted_attributes(cls) -> frozenset[str]: + return frozenset({"parameter_names"}) + + def _get_dimensions(self, searchspace): + if self.parameter_names is None: + active_dims = None + else: + active_dims = list( + chain( + *[ + searchspace.get_comp_rep_parameter_indices(name) + for name in self.parameter_names + ] + ) + ) + + # We use automatic relevance determination for all kernels + ard_num_dims = ( + len(active_dims) + if active_dims is not None + else len(searchspace.comp_rep_columns) + ) + return active_dims, ard_num_dims + @define(frozen=True) class CompositeKernel(Kernel, ABC): """Abstract base class for all composite kernels.""" + def _get_dimensions(self, searchspace): + return None, None + # Collect leftover original slotted classes processed by `attrs.define` gc.collect() From 86e110bc6df9d490abea22c791104860f858c1d2 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 12 Feb 2026 15:01:05 +0100 Subject: [PATCH 06/20] Fix condition in BayBEKernelFactory --- baybe/surrogates/gaussian_process/presets/baybe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/surrogates/gaussian_process/presets/baybe.py b/baybe/surrogates/gaussian_process/presets/baybe.py index 41b2350c38..858fe0f3d7 100644 --- 
a/baybe/surrogates/gaussian_process/presets/baybe.py +++ b/baybe/surrogates/gaussian_process/presets/baybe.py @@ -34,7 +34,7 @@ def __call__( ) -> Kernel: from baybe.surrogates.gaussian_process.components.kernel import ICMKernelFactory - is_multitask = searchspace.n_tasks > 0 + is_multitask = searchspace.task_idx is not None factory = ICMKernelFactory if is_multitask else BayBENumericalKernelFactory return factory()(searchspace, train_x, train_y) From 3714c5ffca34f214bfa0407665668e9046b949e2 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 12 Feb 2026 15:04:42 +0100 Subject: [PATCH 07/20] Add deprecation mechanism for breaking change in kernel logic --- baybe/settings.py | 12 ++++++- baybe/surrogates/gaussian_process/core.py | 44 +++++++++++++++++++++-- tests/test_deprecations.py | 37 +++++++++++++++++++ 3 files changed, 90 insertions(+), 3 deletions(-) diff --git a/baybe/settings.py b/baybe/settings.py index f05ff913bb..36393fe429 100644 --- a/baybe/settings.py +++ b/baybe/settings.py @@ -20,7 +20,7 @@ from baybe._optional.info import FPSAMPLE_INSTALLED, POLARS_INSTALLED from baybe.exceptions import NotAllowedError, OptionalImportError from baybe.utils.basic import classproperty -from baybe.utils.boolean import AutoBool, to_bool +from baybe.utils.boolean import AutoBool, strtobool, to_bool from baybe.utils.random import _RandomState if TYPE_CHECKING: @@ -50,6 +50,16 @@ def _validate_whitelist_env_vars(vars: dict[str, str], /) -> None: f"Allowed values for 'BAYBE_TEST_ENV' are " f"'CORETEST', 'FULLTEST', and 'GPUTEST'. Given: '{value}'" ) + + if (value := vars.pop("BAYBE_DISABLE_CUSTOM_KERNEL_WARNING", None)) is not None: + try: + strtobool(value) + except ValueError as ex: + raise ValueError( + f"Invalid value for 'BAYBE_DISABLE_CUSTOM_KERNEL_WARNING'. " + f"Expected a truthy value to disable the error, but got '{value}'." 
+ ) from ex + if vars: raise RuntimeError(f"Unknown 'BAYBE_*' environment variables: {set(vars)}") diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index 79664d74aa..25c3355728 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -4,13 +4,16 @@ import gc import importlib +import os from functools import partial from typing import TYPE_CHECKING, ClassVar -from attrs import define, field +from attrs import Converter, define, field +from attrs.converters import pipe from attrs.validators import instance_of from typing_extensions import Self, override +from baybe.exceptions import DeprecationError from baybe.kernels.base import Kernel from baybe.parameters.base import Parameter from baybe.searchspace.core import SearchSpace @@ -20,6 +23,7 @@ to_component_factory, ) from baybe.surrogates.gaussian_process.components.kernel import ( + ICMKernelFactory, KernelFactoryProtocol, ) from baybe.surrogates.gaussian_process.components.likelihood import ( @@ -34,6 +38,7 @@ BayBELikelihoodFactory, BayBEMeanFactory, ) +from baybe.utils.boolean import strtobool from baybe.utils.conversion import to_string if TYPE_CHECKING: @@ -92,6 +97,17 @@ def numerical_indices(self) -> list[int]: ] +def _mark_custom_kernel( + value: Kernel | KernelFactoryProtocol | None, self: GaussianProcessSurrogate +) -> Kernel | KernelFactoryProtocol: + """Mark the surrogate as using a custom kernel (for deprecation purposes).""" + if value is None: + return BayBEKernelFactory() + + self._custom_kernel = True + return value + + @define class GaussianProcessSurrogate(Surrogate): """A Gaussian process surrogate model.""" @@ -114,10 +130,16 @@ class GaussianProcessSurrogate(Surrogate): supports_transfer_learning: ClassVar[bool] = True # See base class. + _custom_kernel: bool = field(init=False, default=False, repr=False, eq=False) + # For deprecation only! 
+ kernel_factory: KernelFactoryProtocol = field( alias="kernel_or_factory", + converter=pipe( # type: ignore[misc] + Converter(_mark_custom_kernel, takes_self=True), + partial(to_component_factory, component_type=GPComponentType.KERNEL), + ), factory=BayBEKernelFactory, - converter=partial(to_component_factory, component_type=GPComponentType.KERNEL), # type: ignore[misc] ) """The factory used to create the kernel for the Gaussian process. @@ -218,6 +240,24 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None: assert self._searchspace is not None # provided by base class context = _ModelContext(self._searchspace) + if ( + context.is_multitask + and self._custom_kernel + and not strtobool(os.getenv("BAYBE_DISABLE_CUSTOM_KERNEL_WARNING", "False")) + ): + raise DeprecationError( + f"We noticed that you are using a custom kernel architecture on a " + f"search space that includes a task parameter. Please note that the " + f"kernel logic of '{GaussianProcessSurrogate.__name__}' has changed: " + f"the task kernel is no longer automatically added and must now be " + f"explicitly included in your kernel (factory). " + f"The '{ICMKernelFactory.__name__}' provides a suitable interface " + f"for this purpose. If you are aware of this breaking change and wish " + f"to proceed with your current kernel architecture, you can disable " + f"this error by setting the 'BAYBE_DISABLE_CUSTOM_KERNEL_WARNING' " + f"environment variable to a truthy value." 
+ ) + ### Input/output scaling # NOTE: For GPs, we let BoTorch handle scaling (see [Scaling Workaround] above) input_transform = Normalize( diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py index 5268ae2013..20c27366b7 100644 --- a/tests/test_deprecations.py +++ b/tests/test_deprecations.py @@ -2,6 +2,7 @@ import os import warnings +from contextlib import nullcontext from itertools import pairwise from pathlib import Path from unittest.mock import patch @@ -21,8 +22,10 @@ ) from baybe.constraints.base import Constraint from baybe.exceptions import DeprecationError +from baybe.kernels.basic import MaternKernel from baybe.objectives.desirability import DesirabilityObjective from baybe.objectives.single import SingleTargetObjective +from baybe.parameters.categorical import TaskParameter from baybe.parameters.enum import SubstanceEncoding from baybe.parameters.numerical import ( NumericalDiscreteParameter, @@ -32,9 +35,11 @@ BotorchRecommender, ) from baybe.recommenders.pure.nonpredictive.sampling import RandomRecommender +from baybe.searchspace.core import SearchSpace from baybe.searchspace.discrete import SubspaceDiscrete from baybe.searchspace.validation import get_transform_parameters from baybe.settings import Settings +from baybe.surrogates.gaussian_process.core import GaussianProcessSurrogate from baybe.targets import NumericalTarget from baybe.targets import NumericalTarget as ModernTarget from baybe.targets._deprecated import ( @@ -45,6 +50,7 @@ ) from baybe.targets.binary import BinaryTarget from baybe.transformations.basic import AffineTransformation +from baybe.utils.dataframe import create_fake_input from baybe.utils.random import set_random_seed, temporary_seed @@ -557,3 +563,34 @@ def test_deprecated_cache_environment_variables(monkeypatch, value: str, expecte DeprecationWarning, match="'BAYBE_CACHE_DIR' has been deprecated" ): assert Settings(restore_environment=True).cache_directory == expected + + +@pytest.mark.parametrize("custom", 
[False, True], ids=["default", "custom"]) +@pytest.mark.parametrize("env", [False, True], ids=["no_env", "env"]) +@pytest.mark.parametrize( + "task", + [False, True], +) +def test_multitask_kernel_deprecation(monkeypatch, custom: bool, env: bool, task: bool): + """Providing a custom kernel in a transfer learning context raises a deprecation + error unless explicitly disabled via environment variable.""" # noqa + parameters = [NumericalDiscreteParameter("p", [0, 1])] + if task: + parameters.append(TaskParameter("task", ["a", "b"])) + searchspace = SearchSpace.from_product(parameters) + objective = NumericalTarget("t").to_objective() + measurements = create_fake_input( + searchspace.parameters, objective.targets, n_rows=2 + ) + kernel = MaternKernel() if custom else None + + if env: + monkeypatch.setenv("BAYBE_DISABLE_CUSTOM_KERNEL_WARNING", "True") + + context = ( + pytest.raises(DeprecationError) + if task and custom and not env + else nullcontext() + ) + with context: + GaussianProcessSurrogate(kernel).fit(searchspace, objective, measurements) From 66b3dfb5e9e13ff79890939551967836c61a399d Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 3 Mar 2026 09:15:26 +0100 Subject: [PATCH 08/20] Import KernelFactory to components/__init__.py --- baybe/surrogates/gaussian_process/components/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/baybe/surrogates/gaussian_process/components/__init__.py b/baybe/surrogates/gaussian_process/components/__init__.py index a9e11f4afe..7d2e8a4f11 100644 --- a/baybe/surrogates/gaussian_process/components/__init__.py +++ b/baybe/surrogates/gaussian_process/components/__init__.py @@ -1,6 +1,7 @@ """Gaussian process surrogate components.""" from baybe.surrogates.gaussian_process.components.kernel import ( + KernelFactory, KernelFactoryProtocol, PlainKernelFactory, ) @@ -16,6 +17,7 @@ __all__ = [ # Kernel + "KernelFactory", "KernelFactoryProtocol", "PlainKernelFactory", # Likelihood From 
d91623fd22b5fd501b364db6b5db4bb4be34c83f Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Wed, 4 Mar 2026 09:47:37 +0100 Subject: [PATCH 09/20] Add citation to docstring --- baybe/surrogates/gaussian_process/components/kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/surrogates/gaussian_process/components/kernel.py b/baybe/surrogates/gaussian_process/components/kernel.py index 337459f2aa..1f038345ff 100644 --- a/baybe/surrogates/gaussian_process/components/kernel.py +++ b/baybe/surrogates/gaussian_process/components/kernel.py @@ -64,7 +64,7 @@ def __attrs_post_init__(self): class ICMKernelFactory(KernelFactoryProtocol): """A kernel factory that constructs an ICM kernel for transfer learning. - ICM: Intrinsic model of coregionalization + ICM: Intrinsic Coregionalization Model :cite:p:`NIPS2007_66368270` """ base_kernel_factory: KernelFactoryProtocol = field(alias="base_kernel_or_factory") From 91bd65464eb649b5a6df7b8d25643f440e0641e1 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Wed, 4 Mar 2026 11:39:55 +0100 Subject: [PATCH 10/20] Fix typing --- baybe/kernels/base.py | 27 +++++++++++++++++------ baybe/parameters/selector.py | 3 ++- baybe/surrogates/gaussian_process/core.py | 2 +- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/baybe/kernels/base.py b/baybe/kernels/base.py index c29abea813..4b807150f5 100644 --- a/baybe/kernels/base.py +++ b/baybe/kernels/base.py @@ -18,7 +18,7 @@ from baybe.searchspace.core import SearchSpace from baybe.serialization.mixin import SerialMixin from baybe.settings import active_settings -from baybe.utils.basic import classproperty, get_baseclasses, match_attributes +from baybe.utils.basic import classproperty, get_baseclasses, match_attributes, to_tuple if TYPE_CHECKING: import torch @@ -44,7 +44,9 @@ def to_factory(self) -> PlainKernelFactory: return PlainKernelFactory(self) @abstractmethod - def _get_dimensions(self, searchspace: SearchSpace) -> tuple[tuple[int, ...], int]: + def 
_get_dimensions( + self, searchspace: SearchSpace + ) -> tuple[tuple[int, ...] | None, int | None]: """Get the active dimensions and the number of ARD dimensions.""" def to_gpytorch( @@ -138,8 +140,13 @@ class BasicKernel(Kernel, ABC): parameter_names: tuple[str, ...] | None = field( default=None, - converter=optional_c(tuple), - validator=optional_v(deep_iterable(member_validator=instance_of(str))), + converter=optional_c(to_tuple), + validator=optional_v( + deep_iterable( + iterable_validator=instance_of(tuple), + member_validator=instance_of(str), + ) + ), kw_only=True, ) """An optional set of names specifiying the parameters the kernel should act on.""" @@ -149,11 +156,14 @@ class BasicKernel(Kernel, ABC): def _whitelisted_attributes(cls) -> frozenset[str]: return frozenset({"parameter_names"}) - def _get_dimensions(self, searchspace): + @override + def _get_dimensions( + self, searchspace: SearchSpace + ) -> tuple[tuple[int, ...] | None, int | None]: if self.parameter_names is None: active_dims = None else: - active_dims = list( + active_dims = tuple( chain( *[ searchspace.get_comp_rep_parameter_indices(name) @@ -175,7 +185,10 @@ def _get_dimensions(self, searchspace): class CompositeKernel(Kernel, ABC): """Abstract base class for all composite kernels.""" - def _get_dimensions(self, searchspace): + @override + def _get_dimensions( + self, searchspace: SearchSpace + ) -> tuple[tuple[int, ...] 
| None, int | None]: return None, None diff --git a/baybe/parameters/selector.py b/baybe/parameters/selector.py index 4de9e89856..fcafe92b70 100644 --- a/baybe/parameters/selector.py +++ b/baybe/parameters/selector.py @@ -8,6 +8,7 @@ from typing_extensions import override from baybe.parameters.base import Parameter +from baybe.utils.basic import to_tuple class ParameterSelectorProtocol(Protocol): @@ -39,7 +40,7 @@ def __call__(self, parameter: Parameter) -> bool: class TypeSelector(ParameterSelector): """Select parameters by type.""" - parameter_types: tuple[type[Parameter], ...] = field(converter=tuple) + parameter_types: tuple[type[Parameter], ...] = field(converter=to_tuple) """The parameter types to be selected.""" @override diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index 25c3355728..617a4a247c 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -136,7 +136,7 @@ class GaussianProcessSurrogate(Surrogate): kernel_factory: KernelFactoryProtocol = field( alias="kernel_or_factory", converter=pipe( # type: ignore[misc] - Converter(_mark_custom_kernel, takes_self=True), + Converter(_mark_custom_kernel, takes_self=True), # type: ignore[call-overload] partial(to_component_factory, component_type=GPComponentType.KERNEL), ), factory=BayBEKernelFactory, From 02bfc9778979eba26dea37c649106b4c6f377bb6 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Wed, 4 Mar 2026 14:12:01 +0100 Subject: [PATCH 11/20] Add parameter selection to kernel hypothesis strategies --- tests/hypothesis_strategies/kernels.py | 221 +++++++++++++++---------- 1 file changed, 136 insertions(+), 85 deletions(-) diff --git a/tests/hypothesis_strategies/kernels.py b/tests/hypothesis_strategies/kernels.py index 6ee8c1203b..0d002e0d1f 100644 --- a/tests/hypothesis_strategies/kernels.py +++ b/tests/hypothesis_strategies/kernels.py @@ -1,5 +1,6 @@ """Hypothesis strategies for kernels.""" +from 
collections.abc import Sequence from enum import Enum import hypothesis.strategies as st @@ -18,6 +19,7 @@ ) from baybe.kernels.composite import AdditiveKernel, ProductKernel, ScaleKernel from tests.hypothesis_strategies.basic import positive_finite_floats +from tests.hypothesis_strategies.parameters import parameter_names from tests.hypothesis_strategies.priors import priors @@ -29,99 +31,143 @@ class KernelType(Enum): PRODUCT = "PRODUCT" -linear_kernels = st.builds( - LinearKernel, - variance_prior=st.one_of(st.none(), priors()), - variance_initial_value=st.one_of(st.none(), positive_finite_floats()), -) -"""A strategy that generates linear kernels.""" - -matern_kernels = st.builds( - MaternKernel, - nu=st.sampled_from((0.5, 1.5, 2.5)), - lengthscale_prior=st.one_of(st.none(), priors()), - lengthscale_initial_value=st.one_of(st.none(), positive_finite_floats()), -) -"""A strategy that generates Matern kernels.""" - -periodic_kernels = st.builds( - PeriodicKernel, - lengthscale_prior=st.one_of(st.none(), priors()), - lengthscale_initial_value=st.one_of(st.none(), positive_finite_floats()), - period_length_prior=st.one_of(st.none(), priors()), - period_length_initial_value=st.one_of(st.none(), positive_finite_floats()), -) -"""A strategy that generates periodic kernels.""" - -piecewise_polynomial_kernels = st.builds( - PiecewisePolynomialKernel, - q=st.integers(min_value=0, max_value=3), - lengthscale_prior=st.one_of(st.none(), priors()), - lengthscale_initial_value=st.one_of(st.none(), positive_finite_floats()), -) -"""A strategy that generates piecewise polynomial kernels.""" - -polynomial_kernels = st.builds( - PolynomialKernel, - power=st.integers(min_value=0), - offset_prior=st.one_of(st.none(), priors()), - offset_initial_value=st.one_of(st.none(), positive_finite_floats()), -) -"""A strategy that generates polynomial kernels.""" - -rbf_kernels = st.builds( - RBFKernel, - lengthscale_prior=st.one_of(st.none(), priors()), - 
lengthscale_initial_value=st.one_of(st.none(), positive_finite_floats()), -) -"""A strategy that generates radial basis function (RBF) kernels.""" - -rff_kernels = st.builds( - RFFKernel, - num_samples=st.integers(min_value=1), - lengthscale_prior=st.one_of(st.none(), priors()), - lengthscale_initial_value=st.one_of(st.none(), positive_finite_floats()), -) -"""A strategy that generates random Fourier features (RFF) kernels.""" - -rq_kernels = st.builds( - RQKernel, - lengthscale_prior=st.one_of(st.none(), priors()), - lengthscale_initial_value=st.one_of(st.none(), positive_finite_floats()), -) -"""A strategy that generates rational quadratic (RQ) kernels.""" +def active_parameter_names(names: Sequence[str] | None = None): + """A strategy generating optional parameter names for kernels to operate on.""" + if names is None: + return st.one_of( + st.none(), st.lists(parameter_names, min_size=1, max_size=5, unique=True) + ) + return st.just(names) + + +def linear_kernels(parameter_names: Sequence[str] | None = None): + """A strategy that generates linear kernels.""" + return st.builds( + LinearKernel, + parameter_names=active_parameter_names(parameter_names), + variance_prior=st.one_of(st.none(), priors()), + variance_initial_value=st.one_of(st.none(), positive_finite_floats()), + ) + + +def matern_kernels(parameter_names: Sequence[str] | None = None): + """A strategy that generates Matern kernels.""" + return st.builds( + MaternKernel, + parameter_names=active_parameter_names(parameter_names), + nu=st.sampled_from((0.5, 1.5, 2.5)), + lengthscale_prior=st.one_of(st.none(), priors()), + lengthscale_initial_value=st.one_of(st.none(), positive_finite_floats()), + ) + + +def periodic_kernels(parameter_names: Sequence[str] | None = None): + """A strategy that generates periodic kernels.""" + return st.builds( + PeriodicKernel, + parameter_names=active_parameter_names(parameter_names), + lengthscale_prior=st.one_of(st.none(), priors()), + 
lengthscale_initial_value=st.one_of(st.none(), positive_finite_floats()), + period_length_prior=st.one_of(st.none(), priors()), + period_length_initial_value=st.one_of(st.none(), positive_finite_floats()), + ) + + +def piecewise_polynomial_kernels(parameter_names: Sequence[str] | None = None): + """A strategy that generates piecewise polynomial kernels.""" + return st.builds( + PiecewisePolynomialKernel, + parameter_names=active_parameter_names(parameter_names), + q=st.integers(min_value=0, max_value=3), + lengthscale_prior=st.one_of(st.none(), priors()), + lengthscale_initial_value=st.one_of(st.none(), positive_finite_floats()), + ) + + +def polynomial_kernels(parameter_names: Sequence[str] | None = None): + """A strategy that generates polynomial kernels.""" + return st.builds( + PolynomialKernel, + parameter_names=active_parameter_names(parameter_names), + power=st.integers(min_value=0), + offset_prior=st.one_of(st.none(), priors()), + offset_initial_value=st.one_of(st.none(), positive_finite_floats()), + ) + + +def rbf_kernels(parameter_names: Sequence[str] | None = None): + """A strategy that generates radial basis function (RBF) kernels.""" + return st.builds( + RBFKernel, + parameter_names=active_parameter_names(parameter_names), + lengthscale_prior=st.one_of(st.none(), priors()), + lengthscale_initial_value=st.one_of(st.none(), positive_finite_floats()), + ) + + +def rff_kernels(parameter_names: Sequence[str] | None = None): + """A strategy that generates random Fourier features (RFF) kernels.""" + return st.builds( + RFFKernel, + parameter_names=active_parameter_names(parameter_names), + num_samples=st.integers(min_value=1), + lengthscale_prior=st.one_of(st.none(), priors()), + lengthscale_initial_value=st.one_of(st.none(), positive_finite_floats()), + ) + + +def rq_kernels(parameter_names: Sequence[str] | None = None): + """A strategy that generates rational quadratic (RQ) kernels.""" + return st.builds( + RQKernel, + 
parameter_names=active_parameter_names(parameter_names), + lengthscale_prior=st.one_of(st.none(), priors()), + lengthscale_initial_value=st.one_of(st.none(), positive_finite_floats()), + ) @st.composite -def index_kernels(draw: st.DrawFn): +def index_kernels( + draw: st.DrawFn, + parameter_names: Sequence[str] | None = None, +): """A strategy that generates index kernels.""" num_tasks = draw(st.integers(min_value=2, max_value=5)) rank = draw(st.integers(min_value=1, max_value=num_tasks)) + names = draw(active_parameter_names(parameter_names)) if draw(st.booleans()): - return PositiveIndexKernel(num_tasks=num_tasks, rank=rank) - return IndexKernel(num_tasks=num_tasks, rank=rank) - - -base_kernels = st.one_of( - [ - matern_kernels, # on top because it is the default for many use cases - linear_kernels, - rbf_kernels, - rq_kernels, - rff_kernels, - index_kernels(), - piecewise_polynomial_kernels, - polynomial_kernels, - periodic_kernels, - ] -) -"""A strategy that generates base kernels to be used within more complex kernels.""" + return PositiveIndexKernel( + parameter_names=names, + num_tasks=num_tasks, + rank=rank, + ) + return IndexKernel(parameter_names=names, num_tasks=num_tasks, rank=rank) + + +def base_kernels(parameter_names: Sequence[str] | None = None): + """A strategy that generates base kernels to be used within more complex kernels.""" + return st.one_of( + [ + matern_kernels(parameter_names), # on top because it is the default + linear_kernels(parameter_names), + rbf_kernels(parameter_names), + rq_kernels(parameter_names), + rff_kernels(parameter_names), + index_kernels(parameter_names=parameter_names), + piecewise_polynomial_kernels(parameter_names), + polynomial_kernels(parameter_names), + periodic_kernels(parameter_names), + ] + ) @st.composite -def single_kernels(draw: st.DrawFn): +def single_kernels( + draw: st.DrawFn, + parameter_names: Sequence[str] | None = None, +): """Generate single kernels (i.e., without kernel arithmetic, except 
scaling).""" - base_kernel = draw(base_kernels) + base_kernel = draw(base_kernels(parameter_names)) add_scale = draw(st.booleans()) if add_scale: return ScaleKernel( @@ -136,14 +182,19 @@ def single_kernels(draw: st.DrawFn): @st.composite -def kernels(draw: st.DrawFn): +def kernels( + draw: st.DrawFn, + parameter_names: Sequence[str] | None = None, +): """Generate :class:`baybe.kernels.base.Kernel`.""" kernel_type = draw(st.sampled_from(KernelType)) if kernel_type is KernelType.SINGLE: - return draw(single_kernels()) + return draw(single_kernels(parameter_names=parameter_names)) - base_kernels = draw(st.lists(single_kernels(), min_size=2)) + base_kernels = draw( + st.lists(single_kernels(parameter_names=parameter_names), min_size=2) + ) if kernel_type is KernelType.ADDITIVE: return AdditiveKernel(base_kernels) if kernel_type is KernelType.PRODUCT: From ca37ecf8612a9d20595e527799379ae72178c123 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Wed, 1 Apr 2026 12:57:49 +0200 Subject: [PATCH 12/20] Drop batch_shape argument from Kernel.to_gpytorch --- baybe/kernels/base.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/baybe/kernels/base.py b/baybe/kernels/base.py index 4b807150f5..407b876dc1 100644 --- a/baybe/kernels/base.py +++ b/baybe/kernels/base.py @@ -21,8 +21,6 @@ from baybe.utils.basic import classproperty, get_baseclasses, match_attributes, to_tuple if TYPE_CHECKING: - import torch - from baybe.surrogates.gaussian_process.components.kernel import PlainKernelFactory @@ -49,12 +47,7 @@ def _get_dimensions( ) -> tuple[tuple[int, ...] | None, int | None]: """Get the active dimensions and the number of ARD dimensions.""" - def to_gpytorch( - self, - searchspace: SearchSpace, - *, - batch_shape: torch.Size | None = None, - ): + def to_gpytorch(self, searchspace: SearchSpace): """Create the gpytorch representation of the kernel.""" import gpytorch.kernels @@ -64,7 +57,7 @@ def to_gpytorch( # makes use of kwargs, i.e. 
differentiates if certain keywords are explicitly # passed or not. For instance, `ard_num_dims = kwargs.get("ard_num_dims", 1)` # fails if we explicitly pass `ard_num_dims=None`. - kw: dict[str, Any] = dict(batch_shape=batch_shape, active_dims=active_dims) + kw: dict[str, Any] = dict(active_dims=active_dims) kw = {k: v for k, v in kw.items() if v is not None} # Get corresponding gpytorch kernel class and its base classes From cec1e5d187f25faaa1626ba170b813084d020c40 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Wed, 1 Apr 2026 13:58:34 +0200 Subject: [PATCH 13/20] Update kernel assembly test --- tests/test_kernels.py | 76 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 60 insertions(+), 16 deletions(-) diff --git a/tests/test_kernels.py b/tests/test_kernels.py index 7d7fc4cf68..b64f19fc68 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -3,12 +3,13 @@ from typing import Any import numpy as np -import torch from attrs import asdict, has from hypothesis import given from baybe.kernels.base import BasicKernel, Kernel from baybe.kernels.basic import IndexKernel +from baybe.parameters import NumericalContinuousParameter +from baybe.searchspace.core import SearchSpace from tests.hypothesis_strategies.kernels import kernels # TODO: Consider deprecating these attribute names to avoid inconsistencies @@ -18,18 +19,55 @@ """Dictionary for resolving name differences between BayBE and GPyTorch attributes.""" -def validate_gpytorch_kernel_components(obj: Any, mapped: Any, **kwargs) -> None: +def _collect_parameter_names(kernel: Kernel) -> set[str]: + """Collect all parameter names involved in a kernel structure. + + Args: + kernel: A BayBE kernel (basic or composite). + + Returns: + A set of all parameter names found in the kernel structure. 
+ """ + parameter_names = set() + + # If it's a BasicKernel, add its parameter_names + if isinstance(kernel, BasicKernel) and kernel.parameter_names is not None: + parameter_names.update(kernel.parameter_names) + + # Recursively collect from composite kernels + kernel_dict = asdict(kernel, recurse=False) + for value in kernel_dict.values(): + # Handle single nested kernel (e.g., ScaleKernel.base_kernel) + if isinstance(value, Kernel): + parameter_names.update(_collect_parameter_names(value)) + # Handle tuple of kernels (e.g., AdditiveKernel.base_kernels) + elif isinstance(value, tuple) and all(isinstance(k, Kernel) for k in value): + for k in value: + parameter_names.update(_collect_parameter_names(k)) + + return parameter_names + + +def validate_gpytorch_kernel_components( + obj: Any, mapped: Any, searchspace: SearchSpace +) -> None: """Validate that all kernel components are correctly translated to GPyTorch. Args: obj: An object occurring as part of a BayBE kernel. mapped: The corresponding object in the translated GPyTorch kernel. - **kwargs: Optional kernel arguments that were passed to the GPyTorch kernel. + searchspace: The search space used for the translation. 
""" # Assert that the kernel kwargs are correctly mapped if isinstance(obj, BasicKernel): - for k, v in kwargs.items(): - assert torch.tensor(getattr(mapped, k)).equal(torch.tensor(v)) + active_dims, ard_num_dims = obj._get_dimensions(searchspace) + + assert mapped.ard_num_dims == ard_num_dims + assert active_dims == ( + tuple(mapped.active_dims.tolist()) + if mapped.active_dims is not None + else None + ) # Compare attribute by attribute for name, component in asdict(obj, recurse=False).items(): @@ -38,6 +76,10 @@ def validate_gpytorch_kernel_components(obj: Any, mapped: Any, **kwargs) -> None if component is None: continue + # Skip BayBE-only attributes that have no GPyTorch counterpart + if isinstance(obj, Kernel) and name in obj._whitelisted_attributes: + continue + # Resolve attribute naming differences mapped_name = _RENAME_DICT.get(name, name) @@ -71,12 +113,14 @@ def validate_gpytorch_kernel_components(obj: Any, mapped: Any, **kwargs) -> None # If the component is itself another attrs object, recurse elif has(component): - validate_gpytorch_kernel_components(component, mapped_component, **kwargs) + validate_gpytorch_kernel_components( + component, mapped_component, searchspace + ) # Same for collections of BayBE objects (coming from composite kernels) elif isinstance(component, tuple) and all(has(c) for c in component): for c, m in zip(component, mapped_component): - validate_gpytorch_kernel_components(c, m, **kwargs) + validate_gpytorch_kernel_components(c, m, searchspace) # On the lowest component level, simply check for equality else: @@ -87,14 +131,14 @@ def validate_gpytorch_kernel_components(obj: Any, mapped: Any, **kwargs) -> None def test_kernel_assembly(kernel: Kernel): """Turning a BayBE kernel into a GPyTorch kernel raises no errors and all its components are translated correctly.""" # noqa - # Create some arbitrary kernel kwargs to ensure that they are correctly translated - kwargs = dict( - ard_num_dims=np.random.randint(0, 32), - 
batch_shape=torch.Size( - [np.random.randint(0, 4) for _ in range(np.random.randint(0, 4))] - ), - active_dims=[np.random.randint(0, 4) for _ in range(np.random.randint(0, 4))], + + # Create a search space containing parameters referenced by the kernel + parameter_names = _collect_parameter_names(kernel) + if not parameter_names: + parameter_names = ["x"] + searchspace = SearchSpace.from_product( + [NumericalContinuousParameter(name, (0, 1)) for name in parameter_names] ) - k = kernel.to_gpytorch(**kwargs) - validate_gpytorch_kernel_components(kernel, k, **kwargs) + k = kernel.to_gpytorch(searchspace) + validate_gpytorch_kernel_components(kernel, k, searchspace) From daaa8f377a3f5aa413640bb94787f09b850a156d Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Wed, 1 Apr 2026 14:38:05 +0200 Subject: [PATCH 14/20] Refactor handling of constructor-only attributes in test --- tests/test_kernels.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/test_kernels.py b/tests/test_kernels.py index b64f19fc68..6db1ebc0e8 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -48,7 +48,7 @@ def _collect_parameter_names(kernel: Kernel) -> set[str]: return parameter_names -def validate_gpytorch_kernel_components( +def validate_gpytorch_kernel_components( # noqa: DOC501 obj: Any, mapped: Any, searchspace: SearchSpace ) -> None: """Validate that all kernel components are correctly translated to GPyTorch. @@ -83,33 +83,33 @@ def validate_gpytorch_kernel_components( # Resolve attribute naming differences mapped_name = _RENAME_DICT.get(name, name) - # If the attribute does not exist in the GPyTorch version, ... + # If the attribute does not exist in the GPyTorch version, it must have some + # special handling on the GPyTorch side ... if (mapped_component := getattr(mapped, mapped_name, None)) is None: - # ... 
it must have some special handling on the GPyTorch side
-            # >>>>>
-            # TODO: this will be refactored in #748, which requires changes to the
-            # test logic anyways
+            # The number of tasks is reflected by the constructed covariance matrix
             if isinstance(obj, IndexKernel) and name == "num_tasks":
-                # Special case for IndexKernel.num_tasks, which is not an actual
-                # attribute of the GPyTorch kernel
                 assert mapped.covar_factor.shape[-2] == component
                 continue
+
+            # The rank is reflected by the constructed covariance matrix
             elif isinstance(obj, IndexKernel) and name == "rank":
-                # Special case for IndexKernel.rank, which is not an actual attribute of
-                # the GPyTorch kernel
                 assert mapped.covar_factor.shape[-1] == component
                 continue
-            # <<<<<
 
-            # ... or it must be an initial value. Because setting initial values
+            # Initial values are directly applied. Because setting initial values
             # involves going through constraint transformations on GPyTorch side (i.e.,
             # difference between `` and `raw_`), the numerical values will
             # not be exact, so we check only for approximate matches. 
- assert name.endswith("_initial_value") - assert np.allclose( - component, - getattr(mapped, name.removesuffix("_initial_value")).detach().numpy(), - ) + elif name.endswith("_initial_value"): + assert np.allclose( + component, + getattr(mapped, name.removesuffix("_initial_value")) + .detach() + .numpy(), + ) + continue + + raise AssertionError(f"Kernel component not correctly mapped: {name}") # If the component is itself another attrs object, recurse elif has(component): From 90d4641e78635721886b65118596b2603a810256 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Wed, 1 Apr 2026 14:54:12 +0200 Subject: [PATCH 15/20] Fix logic of custom kernel converter helper --- baybe/surrogates/gaussian_process/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index 617a4a247c..fb6a496565 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -98,11 +98,11 @@ def numerical_indices(self) -> list[int]: def _mark_custom_kernel( - value: Kernel | KernelFactoryProtocol | None, self: GaussianProcessSurrogate + value: Kernel | KernelFactoryProtocol, self: GaussianProcessSurrogate ) -> Kernel | KernelFactoryProtocol: """Mark the surrogate as using a custom kernel (for deprecation purposes).""" - if value is None: - return BayBEKernelFactory() + if type(value) is BayBEKernelFactory: + return value self._custom_kernel = True return value From 2dd13164b8c9c747813e3d313714cc13736276b2 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Wed, 1 Apr 2026 16:59:42 +0200 Subject: [PATCH 16/20] Validate that GP component factories are callable --- baybe/surrogates/gaussian_process/core.py | 5 ++++- tests/test_deprecations.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index fb6a496565..7bc2ddb166 100644 --- 
a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -10,7 +10,7 @@ from attrs import Converter, define, field from attrs.converters import pipe -from attrs.validators import instance_of +from attrs.validators import instance_of, is_callable from typing_extensions import Self, override from baybe.exceptions import DeprecationError @@ -140,6 +140,7 @@ class GaussianProcessSurrogate(Surrogate): partial(to_component_factory, component_type=GPComponentType.KERNEL), ), factory=BayBEKernelFactory, + validator=is_callable(), ) """The factory used to create the kernel for the Gaussian process. @@ -153,6 +154,7 @@ class GaussianProcessSurrogate(Surrogate): alias="mean_or_factory", factory=BayBEMeanFactory, converter=partial(to_component_factory, component_type=GPComponentType.MEAN), # type: ignore[misc] + validator=is_callable(), ) """The factory used to create the mean function for the Gaussian process. @@ -167,6 +169,7 @@ class GaussianProcessSurrogate(Surrogate): converter=partial( # type: ignore[misc] to_component_factory, component_type=GPComponentType.LIKELIHOOD ), + validator=is_callable(), ) """The factory used to create the likelihood for the Gaussian process. 
diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py index 20c27366b7..fcb61b4989 100644 --- a/tests/test_deprecations.py +++ b/tests/test_deprecations.py @@ -582,7 +582,7 @@ def test_multitask_kernel_deprecation(monkeypatch, custom: bool, env: bool, task measurements = create_fake_input( searchspace.parameters, objective.targets, n_rows=2 ) - kernel = MaternKernel() if custom else None + args = (MaternKernel(),) if custom else () if env: monkeypatch.setenv("BAYBE_DISABLE_CUSTOM_KERNEL_WARNING", "True") @@ -593,4 +593,4 @@ def test_multitask_kernel_deprecation(monkeypatch, custom: bool, env: bool, task else nullcontext() ) with context: - GaussianProcessSurrogate(kernel).fit(searchspace, objective, measurements) + GaussianProcessSurrogate(*args).fit(searchspace, objective, measurements) From 118e019c1fee69605f4ee85ef85237f3b340856b Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Wed, 4 Mar 2026 09:56:53 +0100 Subject: [PATCH 17/20] Update CHANGELOG.md --- CHANGELOG.md | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c301a7370..e3ba106ff1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,13 +15,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Gaussian process component factories - Support for GPyTorch objects (kernels, means, likelihood) as Gaussian process components, enabling full low-level customization +- Factories for all Gaussian process components - `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate` -- Interpoint constraints for continuous search spaces +- `parameters/selector.py` module enabling convenient parameter subselection +- `parameter_names` attribute to basic kernels for controlling the considered parameters - `IndexKernel` and `PositiveIndexKernel` classes +- Interpoint constraints for continuous search spaces ### Breaking Changes - `ContinuousLinearConstraint.to_botorch` now returns a collection of 
constraint tuples instead of a single tuple (needed for interpoint constraints) +- `Kernel.to_gpytorch` now takes a `SearchSpace` instead of explicit `ard_num_dims`, + `batch_shape` and `active_dims` arguments, as kernels now automatically adjust this + configuration to the given search space +- `GaussianProcessSurrogate` no longer automatically adds a task kernel in multi-task + scenarios. Custom kernel architectures must now explicitly include the task kernel, + e.g. via `ICMKernelFactory` ### Removed - `parallel_runs` argument from `simulate_scenarios`, since parallelization @@ -30,6 +39,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 `GaussianProcessSurrogate.from_preset` ### Deprecations +- Using a custom kernel with `GaussianProcessSurrogate` in a multi-task context now + raises a `DeprecationError` to alert users about the changed kernel logic. This can + be suppressed by setting the `BAYBE_DISABLE_CUSTOM_KERNEL_WARNING` environment + variable to a truthy value - `set_random_seed` and `temporary_seed` utility functions - The environment variables `BAYBE_NUMPY_USE_SINGLE_PRECISION`/`BAYBE_TORCH_USE_SINGLE_PRECISION` have been From b1215a60b39ba29e562811a33b7cbec1e4cb5542 Mon Sep 17 00:00:00 2001 From: emiliencgm Date: Thu, 26 Feb 2026 16:45:06 +0100 Subject: [PATCH 18/20] Add CHENKernelFactory --- .../gaussian_process/presets/chen.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 baybe/surrogates/gaussian_process/presets/chen.py diff --git a/baybe/surrogates/gaussian_process/presets/chen.py b/baybe/surrogates/gaussian_process/presets/chen.py new file mode 100644 index 0000000000..12cb9ff03e --- /dev/null +++ b/baybe/surrogates/gaussian_process/presets/chen.py @@ -0,0 +1,62 @@ +"""Adaptive Prior proposed in the paper: +Guanming Chen, Maximilian Fleck, Thijs Stuyver. 
Leveraging Hidden-Space Representations Effectively in Bayesian Optimization for Experiment Design through Dimension-Aware Hyperpriors. ChemRxiv. 09 February 2026. +DOI: https://doi.org/10.26434/chemrxiv.10001986/v2""" + +from __future__ import annotations + +import gc +from typing import TYPE_CHECKING + +import numpy as np +from attrs import define +from typing_extensions import override + +from baybe.kernels.basic import MaternKernel +from baybe.kernels.composite import ScaleKernel +from baybe.parameters import TaskParameter +from baybe.priors.basic import GammaPrior +from baybe.surrogates.gaussian_process.kernel_factory import KernelFactory + +if TYPE_CHECKING: + from torch import Tensor + + from baybe.kernels.base import Kernel + from baybe.searchspace.core import SearchSpace + +import math + +@define +class CHENKernelFactory(KernelFactory): + """ Surrogate model with an adaptive hyperprior proposed in the paper: + Guanming Chen, Maximilian Fleck, Thijs Stuyver. Leveraging Hidden-Space Representations Effectively in Bayesian Optimization for Experiment Design through Dimension-Aware Hyperpriors. ChemRxiv. 09 February 2026. 
+ DOI: https://doi.org/10.26434/chemrxiv.10001986/v2 + """ + + @override + def __call__( + self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor + ) -> Kernel: + effective_dims = train_x.shape[-1] - len( + [p for p in searchspace.parameters if isinstance(p, TaskParameter)] + ) + + x = math.sqrt(effective_dims) + l_mean = 0.4 * x + 4.0 + + lengthscale_prior = GammaPrior(2.0*l_mean, 2.0) + lengthscale_initial_value = l_mean + outputscale_prior = GammaPrior(1.0*l_mean, 1.0) + outputscale_initial_value = l_mean + + return ScaleKernel( + MaternKernel( + nu=2.5, + lengthscale_prior=lengthscale_prior, + lengthscale_initial_value=lengthscale_initial_value, + ), + outputscale_prior=outputscale_prior, + outputscale_initial_value=outputscale_initial_value, + ) + +# Collect leftover original slotted classes processed by `attrs.define` +gc.collect() From 438ea8c10ef954f5d8403b4fd4ab580250a5b809 Mon Sep 17 00:00:00 2001 From: maxfleck Date: Thu, 26 Feb 2026 21:33:36 +0100 Subject: [PATCH 19/20] Import CHENKernelFactory in __init__.py --- baybe/surrogates/gaussian_process/presets/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/baybe/surrogates/gaussian_process/presets/__init__.py b/baybe/surrogates/gaussian_process/presets/__init__.py index deb7de9e64..434fbf560f 100644 --- a/baybe/surrogates/gaussian_process/presets/__init__.py +++ b/baybe/surrogates/gaussian_process/presets/__init__.py @@ -7,6 +7,9 @@ BayBEMeanFactory, ) +# Chen preset +from baybe.surrogates.gaussian_process.presets.chen import CHENKernelFactory + # Core from baybe.surrogates.gaussian_process.presets.core import GaussianProcessPreset @@ -31,6 +34,8 @@ "BayBEKernelFactory", "BayBELikelihoodFactory", "BayBEMeanFactory", + # Chen preset + "CHENKernelFactory", # EDBO preset "EDBOKernelFactory", "EDBOLikelihoodFactory", From c01b540eb109f34dfed549db3f25bba9c4dc288e Mon Sep 17 00:00:00 2001 From: tstuyver Date: Fri, 27 Feb 2026 00:17:44 +0100 Subject: [PATCH 20/20] Update CONTRIBUTORS.md 
--- CONTRIBUTORS.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 6270ff77c0..f639c03a8a 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -39,4 +39,10 @@ - Kathrin Skubch (Merck KGaA, Darmstadt, Germany):\ Transfer learning regression benchmarks infrastructure - Myra Zmarsly (Merck Life Science KGaA, Darmstadt, Germany):\ - Identification of non-dominated parameter configurations \ No newline at end of file + Identification of non-dominated parameter configurations +- Thijs Stuyver (PSL University, Paris, France):\ + Adaptive hyper-prior tailored for reaction yield optimization tasks +- Maximilian Fleck (PSL University, Paris, France):\ + Adaptive hyper-prior tailored for reaction yield optimization tasks +- Guanming Chen (PSL University, Paris, France):\ + Adaptive hyper-prior tailored for reaction yield optimization tasks