Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Gaussian process component factories
- Support for GPyTorch objects (kernels, means, likelihood) as Gaussian process
components, enabling full low-level customization
- Factories for all Gaussian process components
- `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate`
- Interpoint constraints for continuous search spaces
- `parameters/selector.py` module enabling convenient parameter subselection
- `parameter_names` attribute to basic kernels for controlling the considered parameters
- `IndexKernel` and `PositiveIndexKernel` classes

### Breaking Changes
- `ContinuousLinearConstraint.to_botorch` now returns a collection of constraint tuples
instead of a single tuple (needed for interpoint constraints)
- `Kernel.to_gpytorch` now takes a `SearchSpace` instead of explicit `ard_num_dims`,
`batch_shape` and `active_dims` arguments, as kernels now automatically adjust this
configuration to the given search space
- `GaussianProcessSurrogate` no longer automatically adds a task kernel in multi-task
scenarios. Custom kernel architectures must now explicitly include the task kernel,
e.g. via `ICMKernelFactory`

### Removed
- `parallel_runs` argument from `simulate_scenarios`, since parallelization
Expand All @@ -30,6 +39,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`GaussianProcessSurrogate.from_preset`

### Deprecations
- Using a custom kernel with `GaussianProcessSurrogate` in a multi-task context now
raises a `DeprecationError` to alert users about the changed kernel logic. This can
be suppressed by setting the `BAYBE_DISABLE_CUSTOM_KERNEL_WARNING` environment
variable to a truthy value
- `set_random_seed` and `temporary_seed` utility functions
- The environment variables
`BAYBE_NUMPY_USE_SINGLE_PRECISION`/`BAYBE_TORCH_USE_SINGLE_PRECISION` have been
Expand Down
8 changes: 7 additions & 1 deletion CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,10 @@
- Kathrin Skubch (Merck KGaA, Darmstadt, Germany):\
Transfer learning regression benchmarks infrastructure
- Myra Zmarsly (Merck Life Science KGaA, Darmstadt, Germany):\
Identification of non-dominated parameter configurations
Identification of non-dominated parameter configurations
- Thijs Stuyver (PSL University, Paris, France):\
Adaptive hyper-prior tailored for reaction yield optimization tasks
- Maximilian Fleck (PSL University, Paris, France):\
Adaptive hyper-prior tailored for reaction yield optimization tasks
- Guanming Chen (PSL University, Paris, France):\
Adaptive hyper-prior tailored for reaction yield optimization tasks
94 changes: 75 additions & 19 deletions baybe/kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,36 @@
from __future__ import annotations

import gc
from abc import ABC
from collections.abc import Sequence
from abc import ABC, abstractmethod
from itertools import chain
from typing import TYPE_CHECKING, Any

from attrs import define
from attrs import define, field
from attrs.converters import optional as optional_c
from attrs.validators import deep_iterable, instance_of
from attrs.validators import optional as optional_v
from typing_extensions import override

from baybe.exceptions import UnmatchedAttributeError
from baybe.priors.base import Prior
from baybe.searchspace.core import SearchSpace
from baybe.serialization.mixin import SerialMixin
from baybe.settings import active_settings
from baybe.utils.basic import get_baseclasses, match_attributes
from baybe.utils.basic import classproperty, get_baseclasses, match_attributes, to_tuple

if TYPE_CHECKING:
import torch

from baybe.surrogates.gaussian_process.components.kernel import PlainKernelFactory


@define(frozen=True)
class Kernel(ABC, SerialMixin):
"""Abstract base class for all kernels."""

    @classproperty
    def _whitelisted_attributes(cls) -> frozenset[str]:
        """Attribute names to exclude from gpytorch matching.

        Subclasses override this to declare BayBE-only attributes (e.g. ones
        that configure dimension handling) that have no counterpart in the
        corresponding gpytorch kernel class and hence must not be reported
        as unmatched during conversion.
        """
        # Base case: no attributes are excluded.
        return frozenset()

def to_factory(self) -> PlainKernelFactory:
"""Wrap the kernel in a :class:`baybe.surrogates.gaussian_process.components.PlainKernelFactory`.""" # noqa: E501
from baybe.surrogates.gaussian_process.components.kernel import (
Expand All @@ -33,23 +41,23 @@ def to_factory(self) -> PlainKernelFactory:

return PlainKernelFactory(self)

def to_gpytorch(
self,
*,
ard_num_dims: int | None = None,
batch_shape: torch.Size | None = None,
active_dims: Sequence[int] | None = None,
):
    @abstractmethod
    def _get_dimensions(
        self, searchspace: SearchSpace
    ) -> tuple[tuple[int, ...] | None, int | None]:
        """Get the active dimensions and the number of ARD dimensions.

        Args:
            searchspace: The search space the kernel is applied to.

        Returns:
            A tuple ``(active_dims, ard_num_dims)`` where either entry may be
            ``None``, indicating that the corresponding setting should not be
            passed explicitly to the gpytorch kernel constructor.
        """

def to_gpytorch(self, searchspace: SearchSpace):
"""Create the gpytorch representation of the kernel."""
import gpytorch.kernels

active_dims, ard_num_dims = self._get_dimensions(searchspace)

# Extract keywords with non-default values. This is required since gpytorch
# makes use of kwargs, i.e. differentiates if certain keywords are explicitly
# passed or not. For instance, `ard_num_dims = kwargs.get("ard_num_dims", 1)`
# fails if we explicitly pass `ard_num_dims=None`.
kw: dict[str, Any] = dict(
ard_num_dims=ard_num_dims, batch_shape=batch_shape, active_dims=active_dims
)
kw: dict[str, Any] = dict(active_dims=active_dims)
kw = {k: v for k, v in kw.items() if v is not None}

# Get corresponding gpytorch kernel class and its base classes
Expand Down Expand Up @@ -79,7 +87,7 @@ def to_gpytorch(
# in the gpytorch kernel (otherwise, the BayBE kernel class is misconfigured).
# Exception: initial values are not used during construction but are set
# on the created object (see code at the end of the method).
missing = set(unmatched) - set(kernel_attrs)
missing = set(unmatched) - set(kernel_attrs) - self._whitelisted_attributes
if leftover := {m for m in missing if not m.endswith("_initial_value")}:
raise UnmatchedAttributeError(leftover)

Expand All @@ -92,15 +100,15 @@ def to_gpytorch(

# Convert specified inner kernels to gpytorch, if provided
kernel_dict = {
key: value.to_gpytorch(**kw)
key: value.to_gpytorch(searchspace, **kw)
for key, value in kernel_attrs.items()
if isinstance(value, Kernel)
}

# Create the kernel with all its inner gpytorch objects
kernel_attrs.update(kernel_dict)
kernel_attrs.update(prior_dict)
gpytorch_kernel = kernel_cls(**kernel_attrs, **kw)
gpytorch_kernel = kernel_cls(**kernel_attrs, ard_num_dims=ard_num_dims, **kw)

# If the kernel has a lengthscale, set its initial value
if kernel_cls.has_lengthscale:
Expand All @@ -123,11 +131,59 @@ def to_gpytorch(
class BasicKernel(Kernel, ABC):
    """Abstract base class for all basic kernels."""

    parameter_names: tuple[str, ...] | None = field(
        default=None,
        converter=optional_c(to_tuple),
        validator=optional_v(
            deep_iterable(
                iterable_validator=instance_of(tuple),
                member_validator=instance_of(str),
            )
        ),
        kw_only=True,
    )
    """An optional set of names specifying the parameters the kernel should act on."""

    @override
    @classproperty
    def _whitelisted_attributes(cls) -> frozenset[str]:
        # `parameter_names` is a BayBE-only concept with no gpytorch
        # counterpart, so it must be excluded from attribute matching.
        return frozenset({"parameter_names"})

    @override
    def _get_dimensions(
        self, searchspace: SearchSpace
    ) -> tuple[tuple[int, ...] | None, int | None]:
        if self.parameter_names is None:
            # No restriction requested: let gpytorch act on all dimensions.
            active_dims = None
        else:
            # Map each selected parameter name to its computational
            # representation column indices and flatten into a single tuple.
            active_dims = tuple(
                chain(
                    *[
                        searchspace.get_comp_rep_parameter_indices(name)
                        for name in self.parameter_names
                    ]
                )
            )

        # We use automatic relevance determination for all kernels
        ard_num_dims = (
            len(active_dims)
            if active_dims is not None
            else len(searchspace.comp_rep_columns)
        )
        return active_dims, ard_num_dims


@define(frozen=True)
class CompositeKernel(Kernel, ABC):
    """Abstract base class for all composite kernels."""

    @override
    def _get_dimensions(
        self, searchspace: SearchSpace
    ) -> tuple[tuple[int, ...] | None, int | None]:
        # Composite kernels impose no dimension configuration themselves:
        # their inner kernels are converted individually and determine their
        # own active/ARD dimensions from the search space.
        return None, None


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
48 changes: 48 additions & 0 deletions baybe/parameters/selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Parameter selectors."""

from abc import abstractmethod
from typing import Protocol

from attrs import define, field
from attrs.validators import instance_of
from typing_extensions import override

from baybe.parameters.base import Parameter
from baybe.utils.basic import to_tuple


class ParameterSelectorProtocol(Protocol):
    """Structural type describing the callable interface of parameter selectors.

    Any callable mapping a :class:`~baybe.parameters.base.Parameter` to a
    boolean inclusion decision satisfies this protocol.
    """

    def __call__(self, parameter: Parameter) -> bool:
        """Indicate whether the given parameter belongs to the selection."""
        ...


@define
class ParameterSelector(ParameterSelectorProtocol):
    """Base class for parameter selectors."""

    exclude: bool = field(default=False, validator=instance_of(bool), kw_only=True)
    """Boolean flag indicating whether to invert the selection criterion."""

    @abstractmethod
    def _is_match(self, parameter: Parameter) -> bool:
        """Determine if a parameter meets the selection criterion."""

    @override
    def __call__(self, parameter: Parameter) -> bool:
        """Determine if a parameter should be included in the selection."""
        matched = self._is_match(parameter)
        # The `exclude` flag turns the selection into its complement.
        if self.exclude:
            return not matched
        return matched


@define
class TypeSelector(ParameterSelector):
    """Select parameters based on their class."""

    parameter_types: tuple[type[Parameter], ...] = field(converter=to_tuple)
    """The parameter types to be selected."""

    @override
    def _is_match(self, parameter: Parameter) -> bool:
        """Check whether the parameter is an instance of any configured type."""
        # `isinstance` accepts a tuple of types directly.
        return isinstance(parameter, self.parameter_types)
12 changes: 11 additions & 1 deletion baybe/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from baybe._optional.info import FPSAMPLE_INSTALLED, POLARS_INSTALLED
from baybe.exceptions import NotAllowedError, OptionalImportError
from baybe.utils.basic import classproperty
from baybe.utils.boolean import AutoBool, to_bool
from baybe.utils.boolean import AutoBool, strtobool, to_bool
from baybe.utils.random import _RandomState

if TYPE_CHECKING:
Expand Down Expand Up @@ -50,6 +50,16 @@ def _validate_whitelist_env_vars(vars: dict[str, str], /) -> None:
f"Allowed values for 'BAYBE_TEST_ENV' are "
f"'CORETEST', 'FULLTEST', and 'GPUTEST'. Given: '{value}'"
)

if (value := vars.pop("BAYBE_DISABLE_CUSTOM_KERNEL_WARNING", None)) is not None:
try:
strtobool(value)
except ValueError as ex:
raise ValueError(
f"Invalid value for 'BAYBE_DISABLE_CUSTOM_KERNEL_WARNING'. "
f"Expected a truthy value to disable the error, but got '{value}'."
) from ex

if vars:
raise RuntimeError(f"Unknown 'BAYBE_*' environment variables: {set(vars)}")

Expand Down
10 changes: 6 additions & 4 deletions baybe/surrogates/gaussian_process/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,29 @@

from baybe.surrogates.gaussian_process.components.kernel import (
KernelFactory,
KernelFactoryProtocol,
PlainKernelFactory,
)
from baybe.surrogates.gaussian_process.components.likelihood import (
LikelihoodFactory,
LikelihoodFactoryProtocol,
PlainLikelihoodFactory,
)
from baybe.surrogates.gaussian_process.components.mean import (
LazyConstantMeanFactory,
MeanFactory,
MeanFactoryProtocol,
PlainMeanFactory,
)

__all__ = [
# Kernel
"KernelFactory",
"KernelFactoryProtocol",
"PlainKernelFactory",
# Likelihood
"LikelihoodFactory",
"LikelihoodFactoryProtocol",
"PlainLikelihoodFactory",
# Mean
"LazyConstantMeanFactory",
"MeanFactory",
"MeanFactoryProtocol",
"PlainMeanFactory",
]
8 changes: 4 additions & 4 deletions baybe/surrogates/gaussian_process/components/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _validate_component(instance: Any, attribute: Attribute, value: Any) -> None
)


class GPComponentFactory(Protocol, Generic[_T_co]):
class GPComponentFactoryProtocol(Protocol, Generic[_T_co]):
"""A protocol defining the interface expected for GP component factories."""

def __call__(
Expand All @@ -105,7 +105,7 @@ def __call__(


@define(frozen=True)
class PlainGPComponentFactory(GPComponentFactory[_T_co], SerialMixin):
class PlainGPComponentFactory(GPComponentFactoryProtocol[_T_co], SerialMixin):
"""A trivial factory that returns a fixed pre-defined component upon request."""

component: _T_co = field(validator=_validate_component)
Expand All @@ -119,11 +119,11 @@ def __call__(


def to_component_factory(
obj: GPComponent | GPComponentFactory,
obj: GPComponent | GPComponentFactoryProtocol,
/,
*,
component_type: GPComponentType | None = None,
) -> GPComponentFactory:
) -> GPComponentFactoryProtocol:
"""Wrap a component into a plain component factory (with factory passthrough).

Args:
Expand Down
Loading