Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
aba1641
Add IndexKernel class
AdrianSosic Feb 11, 2026
e54676d
Add hypothesis strategy for IndexKernel
AdrianSosic Feb 11, 2026
be48c5e
Add PositiveIndexKernel class
AdrianSosic Mar 3, 2026
9680f8d
Add positive index kernel path to hypothesis strategy
AdrianSosic Mar 3, 2026
1df0e58
Absorb index kernel construction into ICMKernelFactory
AdrianSosic Feb 12, 2026
738c526
Add parameter selectors
AdrianSosic Feb 12, 2026
8e9d66d
Rename protocols
AdrianSosic Feb 12, 2026
7c67cd9
Implement active kernel dimension control
AdrianSosic Feb 12, 2026
93ea6d0
Move parameter_names attribute down to BasicKernel subclass
AdrianSosic Feb 12, 2026
9784b8b
Fix condition in DefaultKernelFactory
AdrianSosic Feb 12, 2026
152c60d
Add deprecation mechanism for breaking change in kernel logic
AdrianSosic Feb 12, 2026
f6e8566
Import KernelFactory to components/__init__.py
AdrianSosic Mar 3, 2026
aad110e
Enable multitask mode for surrogate streamlit
AdrianSosic Mar 2, 2026
cefaefa
Add BOTORCH preset
AdrianSosic Mar 2, 2026
d26024d
Extend BoTorch preset test to multitask case
AdrianSosic Mar 2, 2026
3554c54
Enable automatic kernel translation in ICMKernelFactory
AdrianSosic Mar 2, 2026
215ebfb
Add custom GPyTorch components to replicate BoTorch logic
AdrianSosic Mar 2, 2026
ed9cc55
Extend BoTorch factories to multitask case
AdrianSosic Mar 2, 2026
ce97de1
Add missing input normalization to test
AdrianSosic Mar 2, 2026
5df253d
Use regular MLL for multitask GP fitting
AdrianSosic Mar 2, 2026
5f4ee2f
Add kernel active dimension validation to ICMKernelFactory
AdrianSosic Mar 2, 2026
c3885f6
Update CHANGELOG.md
AdrianSosic Mar 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
for determining the Pareto front
- Support for GPyTorch objects (kernels, means, likelihood) as Gaussian process
components, enabling full low-level customization
- `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate`
- `BOTORCH`, `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate`
- Interpoint constraints for continuous search spaces
- `IndexKernel` and `PositiveIndexKernel` classes

### Changed
- Gaussian processes no longer invoke leave-one-out training for multitask scenarios but
can now rely on improved model priors for good generalization

### Breaking Changes
- `ContinuousLinearConstraint.to_botorch` now returns a collection of constraint tuples
Expand Down
82 changes: 69 additions & 13 deletions baybe/kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@
from __future__ import annotations

import gc
from abc import ABC
from collections.abc import Sequence
from abc import ABC, abstractmethod
from itertools import chain
from typing import TYPE_CHECKING, Any

from attrs import define
from attrs import define, field
from attrs.converters import optional as optional_c
from attrs.validators import deep_iterable, instance_of
from attrs.validators import optional as optional_v
from typing_extensions import override

from baybe.exceptions import UnmatchedAttributeError
from baybe.priors.base import Prior
from baybe.searchspace.core import SearchSpace
from baybe.serialization.mixin import SerialMixin
from baybe.settings import active_settings
from baybe.utils.basic import get_baseclasses, match_attributes
from baybe.utils.basic import classproperty, get_baseclasses, match_attributes

if TYPE_CHECKING:
import torch
Expand All @@ -25,6 +30,11 @@
class Kernel(ABC, SerialMixin):
"""Abstract base class for all kernels."""

@classproperty
def _whitelisted_attributes(cls) -> frozenset[str]:
"""Attribute names to exclude from gpytorch matching."""
return frozenset()

def to_factory(self) -> PlainKernelFactory:
"""Wrap the kernel in a :class:`baybe.surrogates.gaussian_process.components.PlainKernelFactory`.""" # noqa: E501
from baybe.surrogates.gaussian_process.components.kernel import (
Expand All @@ -33,27 +43,36 @@ def to_factory(self) -> PlainKernelFactory:

return PlainKernelFactory(self)

@abstractmethod
def _get_dimensions(self, searchspace: SearchSpace) -> tuple[tuple[int, ...], int]:
"""Get the active dimensions and the number of ARD dimensions."""

def to_gpytorch(
self,
searchspace: SearchSpace,
*,
ard_num_dims: int | None = None,
batch_shape: torch.Size | None = None,
active_dims: Sequence[int] | None = None,
):
"""Create the gpytorch representation of the kernel."""
import botorch.models.kernels.positive_index
import gpytorch.kernels

active_dims, ard_num_dims = self._get_dimensions(searchspace)

# Extract keywords with non-default values. This is required since gpytorch
# makes use of kwargs, i.e. differentiates if certain keywords are explicitly
# passed or not. For instance, `ard_num_dims = kwargs.get("ard_num_dims", 1)`
# fails if we explicitly pass `ard_num_dims=None`.
kw: dict[str, Any] = dict(
ard_num_dims=ard_num_dims, batch_shape=batch_shape, active_dims=active_dims
)
kw: dict[str, Any] = dict(batch_shape=batch_shape, active_dims=active_dims)
kw = {k: v for k, v in kw.items() if v is not None}

# Get corresponding gpytorch kernel class and its base classes
kernel_cls = getattr(gpytorch.kernels, self.__class__.__name__)
try:
kernel_cls = getattr(gpytorch.kernels, self.__class__.__name__)
except AttributeError:
kernel_cls = getattr(
botorch.models.kernels.positive_index, self.__class__.__name__
)
base_classes = get_baseclasses(kernel_cls, abstract=True)

# Fetch the necessary gpytorch constructor parameters of the kernel.
Expand All @@ -72,7 +91,7 @@ def to_gpytorch(
# in the gpytorch kernel (otherwise, the BayBE kernel class is misconfigured).
# Exception: initial values are not used during construction but are set
# on the created object (see code at the end of the method).
missing = set(unmatched) - set(kernel_attrs)
missing = set(unmatched) - set(kernel_attrs) - self._whitelisted_attributes
if leftover := {m for m in missing if not m.endswith("_initial_value")}:
raise UnmatchedAttributeError(leftover)

Expand All @@ -85,15 +104,15 @@ def to_gpytorch(

# Convert specified inner kernels to gpytorch, if provided
kernel_dict = {
key: value.to_gpytorch(**kw)
key: value.to_gpytorch(searchspace, **kw)
for key, value in kernel_attrs.items()
if isinstance(value, Kernel)
}

# Create the kernel with all its inner gpytorch objects
kernel_attrs.update(kernel_dict)
kernel_attrs.update(prior_dict)
gpytorch_kernel = kernel_cls(**kernel_attrs, **kw)
gpytorch_kernel = kernel_cls(**kernel_attrs, ard_num_dims=ard_num_dims, **kw)

# If the kernel has a lengthscale, set its initial value
if kernel_cls.has_lengthscale:
Expand All @@ -116,11 +135,48 @@ def to_gpytorch(
class BasicKernel(Kernel, ABC):
    """Abstract base class for all basic kernels."""

    parameter_names: tuple[str, ...] | None = field(
        default=None,
        converter=optional_c(tuple),
        validator=optional_v(deep_iterable(member_validator=instance_of(str))),
        kw_only=True,
    )
    """An optional set of names specifying the parameters the kernel should act on."""

    @override
    @classproperty
    def _whitelisted_attributes(cls) -> frozenset[str]:
        # `parameter_names` is a BayBE-only concept with no gpytorch counterpart,
        # so it must be excluded from the gpytorch attribute matching
        return frozenset({"parameter_names"})

    def _get_dimensions(self, searchspace):
        """Resolve the active dimensions and the number of ARD dimensions.

        All basic kernels use automatic relevance determination, so the ARD
        dimension count equals the number of active dimensions (or the full
        width of the computational representation when no parameter subset
        is specified).
        """
        # Without an explicit parameter selection, the kernel acts on all columns
        if self.parameter_names is None:
            return None, len(searchspace.comp_rep_columns)

        # Collect the computational-representation indices of the selected parameters
        active_dims = [
            index
            for name in self.parameter_names
            for index in searchspace.get_comp_rep_parameter_indices(name)
        ]
        return active_dims, len(active_dims)


@define(frozen=True)
class CompositeKernel(Kernel, ABC):
    """Abstract base class for all composite kernels."""

    def _get_dimensions(self, searchspace):
        """Return no dimension settings at the composite level.

        Composite kernels carry no active-dimension or ARD configuration of
        their own; such settings live on the wrapped inner kernels.
        """
        return (None, None)


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
24 changes: 24 additions & 0 deletions baybe/kernels/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,5 +215,29 @@ class RQKernel(BasicKernel):
"""An optional initial value for the kernel lengthscale."""


@define(frozen=True)
class IndexKernel(BasicKernel):
    """An index kernel for transfer learning across tasks."""

    num_tasks: int = field(validator=[instance_of(int), ge(2)])
    """The number of tasks."""

    rank: int = field(validator=[instance_of(int), ge(1)])
    """The rank of the task covariance matrix."""

    @rank.validator
    def _validate_rank(self, _, rank: int):
        # A valid low-rank factorization requires rank <= num_tasks; full rank
        # (rank == num_tasks) is explicitly allowed, hence the strict comparison.
        # NOTE: The message now matches the actual condition — the previous text
        # claimed the rank "must be smaller than" the task count and printed ">=",
        # contradicting the `>` check above.
        if rank > self.num_tasks:
            raise ValueError(
                f"The rank of the task covariance matrix must not exceed "
                f"the number of tasks. Got rank {rank} > {self.num_tasks} tasks."
            )


@define(frozen=True)
class PositiveIndexKernel(IndexKernel):
    """A positive index kernel for transfer learning across tasks.

    Variant of :class:`IndexKernel` whose gpytorch representation is resolved
    against ``botorch.models.kernels.positive_index`` (see the fallback lookup
    in :meth:`Kernel.to_gpytorch`) rather than ``gpytorch.kernels``.
    """

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
47 changes: 47 additions & 0 deletions baybe/parameters/selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Parameter selectors."""

from abc import abstractmethod
from typing import Protocol

from attrs import define, field
from attrs.validators import instance_of
from typing_extensions import override

from baybe.parameters.base import Parameter


class ParameterSelectorProtocol(Protocol):
    """Type protocol specifying the interface parameter selectors need to implement."""

    def __call__(self, parameter: Parameter) -> bool:
        """Determine if a parameter should be included in the selection.

        Args:
            parameter: The parameter to be evaluated.

        Returns:
            ``True`` if the parameter should be included, ``False`` otherwise.
        """


@define
class ParameterSelector(ParameterSelectorProtocol):
    """Base class for parameter selectors."""

    exclude: bool = field(default=False, validator=instance_of(bool), kw_only=True)
    """Boolean flag indicating whether to invert the selection criterion."""

    @abstractmethod
    def _is_match(self, parameter: Parameter) -> bool:
        """Determine if a parameter meets the selection criterion."""

    @override
    def __call__(self, parameter: Parameter) -> bool:
        """Determine if a parameter should be included in the selection."""
        matched = self._is_match(parameter)
        # Inverted selection: exclude everything the criterion matches
        if self.exclude:
            return not matched
        return matched


@define
class TypeSelector(ParameterSelector):
    """Select parameters by type."""

    parameter_types: tuple[type[Parameter], ...] = field(converter=tuple)
    """The parameter types to be selected."""

    @override
    def _is_match(self, parameter: Parameter) -> bool:
        """Check whether the parameter is an instance of any of the selected types."""
        return any(isinstance(parameter, t) for t in self.parameter_types)
12 changes: 11 additions & 1 deletion baybe/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from baybe._optional.info import FPSAMPLE_INSTALLED, POLARS_INSTALLED
from baybe.exceptions import NotAllowedError, OptionalImportError
from baybe.utils.basic import classproperty
from baybe.utils.boolean import AutoBool, to_bool
from baybe.utils.boolean import AutoBool, strtobool, to_bool
from baybe.utils.random import _RandomState

if TYPE_CHECKING:
Expand Down Expand Up @@ -50,6 +50,16 @@ def _validate_whitelist_env_vars(vars: dict[str, str], /) -> None:
f"Allowed values for 'BAYBE_TEST_ENV' are "
f"'CORETEST', 'FULLTEST', and 'GPUTEST'. Given: '{value}'"
)

if (value := vars.pop("BAYBE_DISABLE_CUSTOM_KERNEL_WARNING", None)) is not None:
try:
strtobool(value)
except ValueError as ex:
raise ValueError(
f"Invalid value for 'BAYBE_DISABLE_CUSTOM_KERNEL_WARNING'. "
f"Expected a truthy value to disable the error, but got '{value}'."
) from ex

if vars:
raise RuntimeError(f"Unknown 'BAYBE_*' environment variables: {set(vars)}")

Expand Down
10 changes: 6 additions & 4 deletions baybe/surrogates/gaussian_process/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,27 @@

from baybe.surrogates.gaussian_process.components.kernel import (
KernelFactory,
KernelFactoryProtocol,
PlainKernelFactory,
)
from baybe.surrogates.gaussian_process.components.likelihood import (
LikelihoodFactory,
LikelihoodFactoryProtocol,
PlainLikelihoodFactory,
)
from baybe.surrogates.gaussian_process.components.mean import (
MeanFactory,
MeanFactoryProtocol,
PlainMeanFactory,
)

__all__ = [
# Kernel
"KernelFactory",
"KernelFactoryProtocol",
"PlainKernelFactory",
# Likelihood
"LikelihoodFactory",
"LikelihoodFactoryProtocol",
"PlainLikelihoodFactory",
# Mean
"MeanFactory",
"MeanFactoryProtocol",
"PlainMeanFactory",
]
57 changes: 57 additions & 0 deletions baybe/surrogates/gaussian_process/components/_gpytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Custom GPyTorch components."""

import torch
from botorch.models.multitask import _compute_multitask_mean
from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL
from gpytorch.constraints import GreaterThan
from gpytorch.likelihoods.hadamard_gaussian_likelihood import HadamardGaussianLikelihood
from gpytorch.means.multitask_mean import Mean, MultitaskMean
from gpytorch.priors import LogNormalPrior
from torch import Tensor
from torch.nn import Module


class HadamardConstantMean(Mean):
    """A GPyTorch mean function implementing BoTorch's multitask mean logic.

    Analogous to GPyTorch's
    https://github.com/cornellius-gp/gpytorch/blob/main/gpytorch/likelihoods/hadamard_gaussian_likelihood.py
    but where the logic is applied to the mean function, i.e. we learn a different
    (constant) mean for each task.
    """

    def __init__(self, mean_module: Module, num_tasks: int, task_feature: int):
        super().__init__()
        # Replicate the base mean once per task via GPyTorch's multitask wrapper
        self.multitask_mean = MultitaskMean(mean_module, num_tasks=num_tasks)
        self.task_feature = task_feature

    def forward(self, x: Tensor) -> Tensor:
        # Normalize a potentially negative task feature index to a positive one
        column = self.task_feature % x.shape[-1]

        # Split the input into the task index column and the surrounding features
        features_left = x[..., :column]
        task_indices = x[..., column : column + 1]
        features_right = x[..., column + 1 :]

        # Delegate the per-task mean computation to BoTorch's helper
        return _compute_multitask_mean(
            self.multitask_mean, features_left, task_indices, features_right
        )


def make_botorch_multitask_likelihood(
    num_tasks: int, task_feature: int
) -> HadamardGaussianLikelihood:
    """Adapted from :class:`botorch.models.multitask.MultiTaskGP`."""
    # Same log-normal noise prior that BoTorch uses for its multitask GPs
    prior = LogNormalPrior(loc=-4.0, scale=1.0)

    # Keep the inferred noise bounded away from zero, starting at the prior mode
    constraint = GreaterThan(
        MIN_INFERRED_NOISE_LEVEL,
        transform=None,
        initial_value=prior.mode,
    )

    return HadamardGaussianLikelihood(
        num_tasks=num_tasks,
        batch_shape=torch.Size(),
        noise_prior=prior,
        noise_constraint=constraint,
        task_feature_index=task_feature,
    )
8 changes: 5 additions & 3 deletions baybe/surrogates/gaussian_process/components/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _validate_component(instance, attribute: Attribute, value: Any):
)


class ComponentFactory(Protocol, Generic[_T_co]):
class ComponentFactoryProtocol(Protocol, Generic[_T_co]):
"""A protocol defining the interface expected for GP component factories."""

def __call__(
Expand All @@ -61,7 +61,7 @@ def __call__(


@define(frozen=True)
class PlainComponentFactory(ComponentFactory[_T_co], SerialMixin):
class PlainComponentFactory(ComponentFactoryProtocol[_T_co], SerialMixin):
"""A trivial factory that returns a fixed pre-defined component upon request."""

component: _T_co = field(validator=_validate_component)
Expand All @@ -74,7 +74,9 @@ def __call__(
return self.component


def to_component_factory(x: Component | ComponentFactory, /) -> ComponentFactory:
def to_component_factory(
x: Component | ComponentFactoryProtocol, /
) -> ComponentFactoryProtocol:
"""Wrap a component into a plain component factory (with factory passthrough)."""
if isinstance(x, Component) or _is_gpytorch_component_class(type(x)):
return PlainComponentFactory(x)
Expand Down
Loading
Loading