Merged
44 commits
b84af1d
Absorb index kernel construction into ICMKernelFactory
AdrianSosic Feb 12, 2026
e365862
Add parameter selectors
AdrianSosic Feb 12, 2026
dc9ee90
Rename protocols
AdrianSosic Feb 12, 2026
63e64fa
Implement active kernel dimension control
AdrianSosic Feb 12, 2026
da8ffc0
Move parameter_names attribute down to BasicKernel subclass
AdrianSosic Feb 12, 2026
86e110b
Fix condition in BayBEKernelFactory
AdrianSosic Feb 12, 2026
3714c5f
Add deprecation mechanism for breaking change in kernel logic
AdrianSosic Feb 12, 2026
66b3dfb
Import KernelFactory to components/__init__.py
AdrianSosic Mar 3, 2026
d91623f
Add citation to docstring
AdrianSosic Mar 4, 2026
91bd654
Fix typing
AdrianSosic Mar 4, 2026
02bfc97
Add parameter selection to kernel hypothesis strategies
AdrianSosic Mar 4, 2026
ca37ecf
Drop batch_shape argument from Kernel.to_gpytorch
AdrianSosic Apr 1, 2026
cec1e5d
Update kernel assembly test
AdrianSosic Apr 1, 2026
daaa8f3
Refactor handling of constructor-only attributes in test
AdrianSosic Apr 1, 2026
90d4641
Fix logic of custom kernel converter helper
AdrianSosic Apr 1, 2026
2dd1316
Validate that GP component factories are callable
AdrianSosic Apr 1, 2026
118e019
Update CHANGELOG.md
AdrianSosic Mar 4, 2026
aa07e2a
Simplify converter logic
AdrianSosic Apr 10, 2026
70a4a3e
Make ParameterSelector class abstract
AdrianSosic Apr 10, 2026
eee1504
Fix test parametrization
AdrianSosic Apr 10, 2026
d486c84
Rename selector.py to selectors.py
AdrianSosic Apr 10, 2026
8480607
Fix variable reference in kernel translation test
AdrianSosic Apr 10, 2026
7cd2321
Use chain.from_iterable instead of unpacking
AdrianSosic Apr 10, 2026
9cc4647
Replace hard-coded parameter type in deprecation error message
AdrianSosic Apr 10, 2026
9c5cf18
Add comment explaining the role of active_dims=None
AdrianSosic Apr 10, 2026
d13e365
Fix keyword logic in Kernel.to_gpytorch
AdrianSosic Apr 10, 2026
1707e3f
Add NameSelector class
AdrianSosic Apr 10, 2026
c859d85
Refine NameSelector.parameter_names field specification
AdrianSosic Apr 10, 2026
3a1c421
Make KernelFactory class abstract
AdrianSosic Apr 10, 2026
66b6f06
Remove unnecessary keyword assignment
AdrianSosic Apr 10, 2026
4edf552
Add converters to ICMKernelFactory
AdrianSosic Apr 10, 2026
a1bc442
Revise comment on kernel attribute matching
AdrianSosic Apr 10, 2026
4811dae
Fix converter specification
AdrianSosic Apr 10, 2026
4347c76
Add convenience converter for parameter selectors
AdrianSosic Apr 10, 2026
88468a1
Turn KernelFactory into mixin class for parameter selection
AdrianSosic Apr 10, 2026
2d414a2
Override default selectors for existing factories
AdrianSosic Apr 10, 2026
1b248a5
Turn assert statement into proper AssertionError
AdrianSosic Apr 10, 2026
f18cd32
Use regex by default for NameSelector
AdrianSosic Apr 13, 2026
a6d611e
Rename parameter_types attribute of TypeSelector to types
AdrianSosic Apr 13, 2026
55582fc
Rephrase changelog entry
AdrianSosic Apr 13, 2026
54a5674
Merge branch 'dev/gp' into refactor/multi_task
AdrianSosic Apr 13, 2026
20f106b
Fix typing issues
AdrianSosic Apr 13, 2026
3c38d34
Add ignore to docs/conf.py
AdrianSosic Apr 13, 2026
f943eeb
Add missing searchspace argument
AdrianSosic Apr 13, 2026
14 changes: 14 additions & 0 deletions CHANGELOG.md
@@ -15,7 +15,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Gaussian process component factories
- Support for GPyTorch objects (kernels, means, likelihood) as Gaussian process
components, enabling full low-level customization
- Factories for all Gaussian process components
- `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate`
- `TypeSelector` and `NameSelector` classes for parameter selection in kernel factories
- `parameter_names` attribute to basic kernels for controlling the considered parameters
- `IndexKernel` and `PositiveIndexKernel` classes
- Interpoint constraints for continuous search spaces
- Addition and multiplication operators for kernel objects, enabling kernel
@@ -25,6 +29,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Breaking Changes
- `ContinuousLinearConstraint.to_botorch` now returns a collection of constraint tuples
instead of a single tuple (needed for interpoint constraints)
- `Kernel.to_gpytorch` now takes a `SearchSpace` instead of explicit `ard_num_dims`,
`batch_shape` and `active_dims` arguments, as kernels now automatically adjust this
configuration to the given search space
- `GaussianProcessSurrogate` no longer automatically adds a task kernel in multi-task
scenarios. Custom kernel architectures must now explicitly include the task kernel,
e.g. via `ICMKernelFactory`

### Removed
- `parallel_runs` argument from `simulate_scenarios`, since parallelization
@@ -33,6 +43,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`GaussianProcessSurrogate.from_preset`

### Deprecations
- Using a custom kernel with `GaussianProcessSurrogate` in a multi-task context now
raises a `DeprecationError` to alert users about the changed kernel logic. This can
be suppressed by setting the `BAYBE_DISABLE_CUSTOM_KERNEL_WARNING` environment
variable to a truthy value
- `set_random_seed` and `temporary_seed` utility functions
- The environment variables
`BAYBE_NUMPY_USE_SINGLE_PRECISION`/`BAYBE_TORCH_USE_SINGLE_PRECISION` have been
106 changes: 79 additions & 27 deletions baybe/kernels/base.py
@@ -3,21 +3,24 @@
from __future__ import annotations

import gc
from abc import ABC
from collections.abc import Sequence
from abc import ABC, abstractmethod
from itertools import chain
from typing import TYPE_CHECKING, Any

from attrs import define
from attrs import define, field
from attrs.converters import optional as optional_c
from attrs.validators import deep_iterable, instance_of
from attrs.validators import optional as optional_v
from typing_extensions import override

from baybe.exceptions import UnmatchedAttributeError
from baybe.priors.base import Prior
from baybe.searchspace.core import SearchSpace
from baybe.serialization.mixin import SerialMixin
from baybe.settings import active_settings
from baybe.utils.basic import get_baseclasses, match_attributes
from baybe.utils.basic import classproperty, get_baseclasses, match_attributes, to_tuple

if TYPE_CHECKING:
import torch

from baybe.surrogates.gaussian_process.components.kernel import PlainKernelFactory


@@ -83,6 +86,11 @@ def __rmul__(self, other: Any) -> Kernel:
# Enable use with math.prod(), which starts with 1 * first_element.
return self.__mul__(other)

@classproperty
def _whitelisted_attributes(cls) -> frozenset[str]:
"""Attribute names to exclude from gpytorch matching."""
return frozenset()
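The `_whitelisted_attributes` hook above is a class-level property. A minimal standalone sketch of such a descriptor, with the same override pattern used by `Kernel`/`BasicKernel` (the real `baybe.utils.basic.classproperty` helper is assumed to behave similarly):

```python
class classproperty:
    # Minimal read-only class-level property descriptor. Unlike a regular
    # property, it is evaluated on the class itself, not on instances.
    def __init__(self, fget):
        self.fget = fget

    def __get__(self, obj, owner=None):
        cls = owner if owner is not None else type(obj)
        return self.fget(cls)


class Base:
    @classproperty
    def _whitelisted_attributes(cls):
        return frozenset()


class Sub(Base):
    @classproperty
    def _whitelisted_attributes(cls):
        # Subclasses widen the whitelist, mirroring BasicKernel's override.
        return frozenset({"parameter_names"})
```

Accessing `Sub._whitelisted_attributes` then yields the subclass value without instantiation.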

def to_factory(self) -> PlainKernelFactory:
"""Wrap the kernel in a :class:`baybe.surrogates.gaussian_process.components.PlainKernelFactory`.""" # noqa: E501
from baybe.surrogates.gaussian_process.components.kernel import (
@@ -91,24 +99,17 @@ def to_factory(self) -> PlainKernelFactory:

return PlainKernelFactory(self)

def to_gpytorch(
self,
*,
ard_num_dims: int | None = None,
batch_shape: torch.Size | None = None,
active_dims: Sequence[int] | None = None,
):
@abstractmethod
def _get_dimensions(
self, searchspace: SearchSpace
) -> tuple[tuple[int, ...] | None, int | None]:
"""Get the active dimensions and the number of ARD dimensions."""

def to_gpytorch(self, searchspace: SearchSpace):
"""Create the gpytorch representation of the kernel."""
import gpytorch.kernels

# Extract keywords with non-default values. This is required since gpytorch
# makes use of kwargs, i.e. differentiates if certain keywords are explicitly
# passed or not. For instance, `ard_num_dims = kwargs.get("ard_num_dims", 1)`
# fails if we explicitly pass `ard_num_dims=None`.
kw: dict[str, Any] = dict(
ard_num_dims=ard_num_dims, batch_shape=batch_shape, active_dims=active_dims
)
kw = {k: v for k, v in kw.items() if v is not None}
active_dims, ard_num_dims = self._get_dimensions(searchspace)

# Get corresponding gpytorch kernel class and its base classes
try:
@@ -134,10 +135,12 @@ def to_gpytorch(
unmatched_attrs.update(unmatched)

# Sanity check: all attributes of the BayBE kernel need a corresponding match
# in the gpytorch kernel (otherwise, the BayBE kernel class is misconfigured).
# Exceptions: initial values and trainability flags are not used during
# construction but are set on the created object after construction.
missing = set(unmatched) - set(kernel_attrs)
# with the gpytorch kernel signature (otherwise, the BayBE kernel class is
# misconfigured). Exceptions: initial values and trainability flags are not used
# during construction but are set on the created object after construction.
missing = (
set(unmatched_attrs) - set(kernel_attrs) - self._whitelisted_attributes
)
if leftover := {
m
for m in missing
@@ -154,15 +157,17 @@

# Convert specified inner kernels to gpytorch, if provided
kernel_dict = {
key: value.to_gpytorch(**kw)
key: value.to_gpytorch(searchspace)
for key, value in kernel_attrs.items()
if isinstance(value, Kernel)
}

# Create the kernel with all its inner gpytorch objects
kernel_attrs.update(kernel_dict)
kernel_attrs.update(prior_dict)
gpytorch_kernel = kernel_cls(**kernel_attrs, **kw)
gpytorch_kernel = kernel_cls(
**kernel_attrs, ard_num_dims=ard_num_dims, active_dims=active_dims
)

# If the kernel has a lengthscale, set its initial value
if kernel_cls.has_lengthscale:
@@ -185,11 +190,58 @@
class BasicKernel(Kernel, ABC):
"""Abstract base class for all basic kernels."""

parameter_names: tuple[str, ...] | None = field(
default=None,
converter=optional_c(to_tuple),
validator=optional_v(
deep_iterable(
iterable_validator=instance_of(tuple),
member_validator=instance_of(str),
)
),
kw_only=True,
)
"""An optional tuple of names specifying the parameters the kernel should act on."""

@override
@classproperty
def _whitelisted_attributes(cls) -> frozenset[str]:
return frozenset({"parameter_names"})

@override
def _get_dimensions(
self, searchspace: SearchSpace
) -> tuple[tuple[int, ...] | None, int | None]:
if self.parameter_names is None:
# `None` is gpytorch's default indicating that all dimensions are active
active_dims = None
else:
active_dims = tuple(
chain.from_iterable(
searchspace.get_comp_rep_parameter_indices(name)
for name in self.parameter_names
)
)

# We use automatic relevance determination for all kernels
ard_num_dims = (
len(active_dims)
if active_dims is not None
else len(searchspace.comp_rep_columns)
)
return active_dims, ard_num_dims
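The dimension resolution above can be sketched standalone. `COMP_REP_INDICES` is a hypothetical stand-in for `SearchSpace.get_comp_rep_parameter_indices`, mapping each parameter name to the column indices it occupies in the computational representation:

```python
from itertools import chain

# Hypothetical stand-in for SearchSpace.get_comp_rep_parameter_indices:
# a single parameter may span several computational-representation columns.
COMP_REP_INDICES = {
    "temperature": (0,),
    "solvent": (1, 2, 3),  # e.g. an encoded categorical parameter
    "pressure": (4,),
}


def get_dimensions(parameter_names):
    # None mirrors gpytorch's default: all dimensions are active,
    # and ARD spans every computational-representation column.
    if parameter_names is None:
        return None, sum(len(idx) for idx in COMP_REP_INDICES.values())
    # Flatten the per-parameter index tuples into one active_dims tuple.
    active_dims = tuple(
        chain.from_iterable(COMP_REP_INDICES[name] for name in parameter_names)
    )
    return active_dims, len(active_dims)
```

Selecting `("solvent", "pressure")` would thus activate columns 1 through 4 with four ARD lengthscales.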


@define(frozen=True)
class CompositeKernel(Kernel, ABC):
"""Abstract base class for all composite kernels."""

@override
def _get_dimensions(
self, searchspace: SearchSpace
) -> tuple[tuple[int, ...] | None, int | None]:
return None, None


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
167 changes: 167 additions & 0 deletions baybe/parameters/selectors.py
@@ -0,0 +1,167 @@
"""Parameter selectors."""

import re
from abc import ABC, abstractmethod
from collections.abc import Collection
from typing import ClassVar, Protocol

from attrs import Converter, define, field
from attrs.converters import optional
from attrs.validators import deep_iterable, instance_of, min_len
from typing_extensions import override

from baybe.parameters.base import Parameter
from baybe.searchspace.core import SearchSpace
from baybe.utils.basic import to_tuple
from baybe.utils.conversion import nonstring_to_tuple


class ParameterSelectorProtocol(Protocol):
"""Type protocol specifying the interface parameter selectors need to implement."""

def __call__(self, parameter: Parameter) -> bool:
"""Determine if a parameter should be included in the selection."""


@define
class ParameterSelector(ParameterSelectorProtocol, ABC):
"""Base class for parameter selectors."""

exclude: bool = field(default=False, validator=instance_of(bool), kw_only=True)
"""Boolean flag indicating whether to invert the selection criterion."""

@abstractmethod
def _is_match(self, parameter: Parameter) -> bool:
"""Determine if a parameter meets the selection criterion."""

@override
def __call__(self, parameter: Parameter) -> bool:
"""Determine if a parameter should be included in the selection."""
result = self._is_match(parameter)
return not result if self.exclude else result


@define
class TypeSelector(ParameterSelector):
"""Select parameters by type."""

types: tuple[type[Parameter], ...] = field(converter=to_tuple)
"""The parameter types to be selected."""

@override
def _is_match(self, parameter: Parameter) -> bool:
return isinstance(parameter, self.types)


@define
class NameSelector(ParameterSelector):
"""Select parameters by name patterns."""

patterns: tuple[str, ...] = field(
converter=Converter( # type: ignore
nonstring_to_tuple, takes_self=True, takes_field=True
),
validator=[
min_len(1),
deep_iterable(member_validator=instance_of(str)),
],
)
"""The patterns to be matched against."""

regex: bool = field(default=True, validator=instance_of(bool), kw_only=True)
"""If ``False``, the provided patterns are interpreted as literal strings."""

@override
def _is_match(self, parameter: Parameter) -> bool:
if self.regex:
return any(re.fullmatch(p, parameter.name) for p in self.patterns)
return parameter.name in self.patterns
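The two matching modes of `NameSelector._is_match` can be sketched as a standalone function (`name_matches` is illustrative, not a library API):

```python
import re


def name_matches(name, patterns, regex=True):
    # In regex mode, fullmatch anchors each pattern to the entire name,
    # so the pattern "x" does not match "x1".
    if regex:
        return any(re.fullmatch(p, name) is not None for p in patterns)
    # Literal mode falls back to plain membership.
    return name in patterns
```

The `fullmatch` anchoring is the reason the regex default is safe: a bare literal name used as a pattern cannot accidentally match longer names unless it contains regex metacharacters.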


def to_parameter_selector(
x: (
str
| type[Parameter]
| Collection[str]
| Collection[type[Parameter]]
| ParameterSelectorProtocol
),
/,
) -> ParameterSelectorProtocol:
"""Convert shorthand notations to parameter selectors.

Convenience converter that allows users to specify parameter selectors using
simpler types:

* A callable (i.e., an existing selector or any object satisfying
:class:`ParameterSelectorProtocol`) is passed through unchanged.
* A single string is interpreted as a parameter name and wrapped into a
:class:`NameSelector`.
* A single :class:`~baybe.parameters.base.Parameter` subclass is wrapped into a
:class:`TypeSelector`.
* A collection of strings is converted to a :class:`NameSelector`.
* A collection of :class:`~baybe.parameters.base.Parameter` subclasses is converted
to a :class:`TypeSelector`.

Args:
x: The object to convert.

Returns:
The corresponding parameter selector.

Raises:
TypeError: If the input cannot be converted to a parameter selector.
"""
if isinstance(x, str):
return NameSelector([x])

if isinstance(x, type) and issubclass(x, Parameter):
return TypeSelector([x])

if callable(x):
return x

# At this point, x should be a collection of strings or parameter types
items = tuple(x)

if all(isinstance(item, str) for item in items):
return NameSelector(items)

if all(isinstance(item, type) and issubclass(item, Parameter) for item in items):
return TypeSelector(items)

raise TypeError(f"Cannot convert {x!r} to a parameter selector.")
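The dispatch order of the converter can be sketched standalone with plain callables (`to_selector` is illustrative; the real converter returns `NameSelector`/`TypeSelector` instances):

```python
def to_selector(x):
    # Order matters: classes are themselves callable, so the type check
    # must precede the generic callable pass-through.
    if isinstance(x, str):
        return lambda p: p == x                 # single name
    if isinstance(x, type):
        return lambda p: isinstance(p, x)       # single type
    if callable(x):
        return x                                # already a selector
    items = tuple(x)
    if all(isinstance(i, str) for i in items):
        return lambda p: p in items             # collection of names
    if all(isinstance(i, type) for i in items):
        return lambda p: isinstance(p, items)   # collection of types
    raise TypeError(f"Cannot convert {x!r} to a selector.")
```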


@define
class _ParameterSelectorMixin:
"""A mixin class to enable parameter selection."""

# For internal use only: sanity check mechanism to remind developers of new
# subclasses to actually use the parameter selector when it is provided
# TODO: Perhaps we can find a more elegant way to enforce this by design
_uses_parameter_names: ClassVar[bool] = False

parameter_selector: ParameterSelectorProtocol | None = field(
default=None, converter=optional(to_parameter_selector), kw_only=True
)
"""An optional selector to specify which parameters are to be considered."""

def get_parameter_names(self, searchspace: SearchSpace) -> tuple[str, ...] | None:
"""Get the names of the parameters to be considered."""
if self.parameter_selector is None:
return None

return tuple(
p.name for p in searchspace.parameters if self.parameter_selector(p)
)

def __attrs_post_init__(self):
if self.parameter_selector is not None and not self._uses_parameter_names:
raise AssertionError(
f"A `parameter_selector` was provided to "
f"`{type(self).__name__}`, but the class does not set "
f"`_uses_parameter_names = True`. Subclasses that accept a "
f"parameter selector must explicitly set this flag to confirm "
f"they actually use the selected parameter names."
)
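The opt-in guard of `_ParameterSelectorMixin` can be sketched without attrs (`SelectorMixin` and `GoodFactory` are illustrative stand-ins):

```python
class SelectorMixin:
    # Subclasses that accept a selector must opt in via this flag,
    # so a provided selector cannot be silently ignored.
    _uses_parameter_names = False

    def __init__(self, parameter_selector=None):
        self.parameter_selector = parameter_selector
        if parameter_selector is not None and not self._uses_parameter_names:
            raise AssertionError(
                f"`{type(self).__name__}` received a parameter selector "
                f"but does not declare that it uses it."
            )


class GoodFactory(SelectorMixin):
    _uses_parameter_names = True
```

The check fires at construction time, turning a forgotten flag in a new subclass into an immediate, loud failure instead of a silently inactive selector.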