Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate`
- Interpoint constraints for continuous search spaces
- `IndexKernel` and `PositiveIndexKernel` classes
- Addition and multiplication operators for kernel objects, enabling kernel
composition via `+` (sum) and `*` (product), as well as `constant * kernel`
for creating a `ScaleKernel` with a fixed output scale

### Breaking Changes
- `ContinuousLinearConstraint.to_botorch` now returns a collection of constraint tuples
Expand Down
68 changes: 65 additions & 3 deletions baybe/kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,64 @@
class Kernel(ABC, SerialMixin):
"""Abstract base class for all kernels."""

def __add__(self, other: Any) -> Kernel:
    """Create a sum kernel from two kernels.

    Flattens nested sums so that ``(a + b) + c`` yields
    ``AdditiveKernel([a, b, c])`` instead of
    ``AdditiveKernel([AdditiveKernel([a, b]), c])``.
    """
    if not isinstance(other, Kernel):
        return NotImplemented

    from baybe.kernels.composite import AdditiveKernel

    def _summands(kernel: Kernel) -> tuple[Kernel, ...]:
        # Unpack an existing sum into its summands; wrap anything else.
        if isinstance(kernel, AdditiveKernel):
            return tuple(kernel.base_kernels)
        return (kernel,)

    return AdditiveKernel([*_summands(self), *_summands(other)])

def __radd__(self, other: Any) -> Kernel:
    """Support right-hand addition for kernel objects.

    Handles two cases:

    * ``0 + kernel`` returns the kernel unchanged, enabling use with the
      built-in :func:`sum`, which starts its accumulation at ``0``.
    * ``other + kernel`` where the left operand is itself a kernel. The
      original operand order is preserved so that flattened sums keep
      their left-to-right order.
    """
    # Enable use with built-in sum(), which starts with 0 + first_element.
    if other == 0:
        return self
    if isinstance(other, Kernel):
        # __radd__ represents ``other + self``, so delegate with ``other``
        # as the left operand to preserve the summand order.
        return other.__add__(self)
    return NotImplemented

def __mul__(self, other: Any) -> Kernel:
    """Create a product kernel or scale kernel.

    When multiplied with another kernel, a product kernel is created. Nested
    products are flattened so that ``(a * b) * c`` yields
    ``ProductKernel([a, b, c])``. When multiplied with a numeric constant, a scale
    kernel with a fixed (non-trainable) output scale is created.
    """
    if isinstance(other, Kernel):
        from baybe.kernels.composite import ProductKernel

        def _factors(kernel: Kernel) -> tuple[Kernel, ...]:
            # Unpack an existing product into its factors; wrap anything else.
            if isinstance(kernel, ProductKernel):
                return tuple(kernel.base_kernels)
            return (kernel,)

        return ProductKernel([*_factors(self), *_factors(other)])

    if not isinstance(other, (int, float)):
        return NotImplemented

    # Multiplying by one is a no-op, so skip the ScaleKernel wrapper entirely.
    if other == 1:
        return self

    from baybe.kernels.composite import ScaleKernel

    return ScaleKernel(
        base_kernel=self,
        outputscale_initial_value=float(other),
        outputscale_trainable=False,
    )

def __rmul__(self, other: Any) -> Kernel:
    """Support right-hand multiplication, enabling ``constant * kernel``."""
    # Kernel-constant multiplication is commutative, and math.prod() starts
    # its accumulation with 1 * first_element, so delegating straight to
    # __mul__ covers both the ``constant * kernel`` and prod() use cases.
    return self.__mul__(other)

def to_factory(self) -> PlainKernelFactory:
"""Wrap the kernel in a :class:`baybe.surrogates.gaussian_process.components.PlainKernelFactory`.""" # noqa: E501
from baybe.surrogates.gaussian_process.components.kernel import (
Expand Down Expand Up @@ -77,10 +135,14 @@ def to_gpytorch(

# Sanity check: all attributes of the BayBE kernel need a corresponding match
# in the gpytorch kernel (otherwise, the BayBE kernel class is misconfigured).
# Exception: initial values are not used during construction but are set
# on the created object (see code at the end of the method).
# Exceptions: initial values and trainability flags are not used during
# construction but are set on the created object after construction.
missing = set(unmatched) - set(kernel_attrs)
if leftover := {m for m in missing if not m.endswith("_initial_value")}:
if leftover := {
m
for m in missing
if not m.endswith("_initial_value") and not m.endswith("_trainable")
}:
raise UnmatchedAttributeError(leftover)

# Convert specified priors to gpytorch, if provided
Expand Down
8 changes: 8 additions & 0 deletions baybe/kernels/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ class ScaleKernel(CompositeKernel):
)
"""An optional initial value for the output scale."""

outputscale_trainable: bool = field(default=True, validator=instance_of(bool))
"""Boolean flag indicating whether the output scale is trainable.

If ``False``, the output scale is frozen at its initial value and excluded from
optimization."""

@override
def to_gpytorch(self, *args, **kwargs):
import torch
Expand All @@ -45,6 +51,8 @@ def to_gpytorch(self, *args, **kwargs):
gpytorch_kernel.outputscale = torch.tensor(
initial_value, dtype=active_settings.DTypeFloatTorch
)
if not self.outputscale_trainable:
gpytorch_kernel.raw_outputscale.requires_grad_(False)
return gpytorch_kernel


Expand Down
4 changes: 2 additions & 2 deletions tests/test_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pytest import param

from baybe.kernels.basic import MaternKernel, RBFKernel
from baybe.kernels.composite import AdditiveKernel, ScaleKernel
from baybe.kernels.composite import ScaleKernel
from baybe.parameters.numerical import NumericalContinuousParameter
from baybe.surrogates.gaussian_process.components.generic import PlainGPComponentFactory
from baybe.surrogates.gaussian_process.core import GaussianProcessSurrogate
Expand All @@ -24,7 +24,7 @@
objective = NumericalTarget("t").to_objective()
measurements = create_fake_input(searchspace.parameters, objective.targets, n_rows=100)

baybe_kernel = ScaleKernel(AdditiveKernel([MaternKernel(), RBFKernel()]))
baybe_kernel = ScaleKernel(MaternKernel() + RBFKernel())
gpytorch_kernel = GPyTorchScaleKernel(GPyTorchMaternKernel() + GPyTorchRBFKernel())


Expand Down
21 changes: 8 additions & 13 deletions tests/test_iterations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
RFFKernel,
RQKernel,
)
from baybe.kernels.composite import AdditiveKernel, ProductKernel, ScaleKernel
from baybe.kernels.composite import ScaleKernel
from baybe.objectives.pareto import ParetoObjective
from baybe.priors import (
GammaPrior,
Expand Down Expand Up @@ -214,18 +214,13 @@ def ongoing_campaign(campaign):
]

valid_composite_kernels = [
AdditiveKernel([MaternKernel(1.5), MaternKernel(2.5)]),
AdditiveKernel([PolynomialKernel(1), PolynomialKernel(2), PolynomialKernel(3)]),
AdditiveKernel([RBFKernel(), RQKernel(), PolynomialKernel(1)]),
ProductKernel([MaternKernel(1.5), MaternKernel(2.5)]),
ProductKernel([RBFKernel(), RQKernel(), PolynomialKernel(1)]),
ProductKernel([PolynomialKernel(1), PolynomialKernel(2), PolynomialKernel(3)]),
AdditiveKernel(
[
ProductKernel([MaternKernel(1.5), MaternKernel(2.5)]),
AdditiveKernel([MaternKernel(1.5), MaternKernel(2.5)]),
]
),
MaternKernel(1.5) + MaternKernel(2.5),
PolynomialKernel(1) + PolynomialKernel(2) + PolynomialKernel(3),
RBFKernel() + RQKernel() + PolynomialKernel(1),
MaternKernel(1.5) * MaternKernel(2.5),
RBFKernel() * RQKernel() * PolynomialKernel(1),
PolynomialKernel(1) * PolynomialKernel(2) * PolynomialKernel(3),
(MaternKernel(1.5) * MaternKernel(2.5)) + (MaternKernel(1.5) + MaternKernel(2.5)),
]

valid_kernels = valid_base_kernels + valid_scale_kernels + valid_composite_kernels
Expand Down
75 changes: 74 additions & 1 deletion tests/test_kernels.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
"""Kernel tests."""

import math
from typing import Any

import numpy as np
import pytest
import torch
from attrs import asdict, has
from hypothesis import given
from pytest import param

from baybe.kernels.base import BasicKernel, Kernel
from baybe.kernels.basic import IndexKernel
from baybe.kernels.basic import IndexKernel, MaternKernel, RBFKernel
from baybe.kernels.composite import AdditiveKernel, ProductKernel, ScaleKernel
from tests.hypothesis_strategies.kernels import kernels

# TODO: Consider deprecating these attribute names to avoid inconsistencies
Expand Down Expand Up @@ -59,6 +63,11 @@ def validate_gpytorch_kernel_components(obj: Any, mapped: Any, **kwargs) -> None
continue
# <<<<<

# ... or it must be a trainability flag, which is a BayBE-only concept
# applied after GPyTorch kernel construction.
if name.endswith("_trainable"):
continue

# ... or it must be an initial value. Because setting initial values
# involves going through constraint transformations on GPyTorch side (i.e.,
# difference between `<attr>` and `raw_<attr>`), the numerical values will
Expand Down Expand Up @@ -98,3 +107,67 @@ def test_kernel_assembly(kernel: Kernel):

k = kernel.to_gpytorch(**kwargs)
validate_gpytorch_kernel_components(kernel, k, **kwargs)


def test_add_produces_additive_kernel():
    """Adding two kernels produces an AdditiveKernel."""
    summed = MaternKernel() + RBFKernel()

    assert isinstance(summed, AdditiveKernel)
    components = summed.base_kernels
    assert len(components) == 2
    assert isinstance(components[0], MaternKernel)
    assert isinstance(components[1], RBFKernel)
    # sum() must produce the same kernel thanks to __radd__ handling 0.
    assert sum([MaternKernel(), RBFKernel()]) == summed


def test_add_chain_flattens():
    """Chaining additions flattens into a single AdditiveKernel."""
    chained = MaternKernel() + RBFKernel() + MaternKernel(0.5)

    # Three summands end up in one flat AdditiveKernel, not a nested one.
    assert isinstance(chained, AdditiveKernel)
    assert len(chained.base_kernels) == 3


def test_mul_produces_product_kernel():
    """Multiplying two kernels produces a ProductKernel."""
    product = MaternKernel() * RBFKernel()

    assert isinstance(product, ProductKernel)
    factors = product.base_kernels
    assert len(factors) == 2
    assert isinstance(factors[0], MaternKernel)
    assert isinstance(factors[1], RBFKernel)
    # math.prod() must produce the same kernel thanks to __rmul__ handling 1.
    assert math.prod([MaternKernel(), RBFKernel()]) == product


def test_mul_chain_flattens():
    """Chaining multiplications flattens into a single ProductKernel."""
    chained = MaternKernel() * RBFKernel() * MaternKernel(0.5)

    # Three factors end up in one flat ProductKernel, not a nested one.
    assert isinstance(chained, ProductKernel)
    assert len(chained.base_kernels) == 3


@pytest.mark.parametrize(
    ("left", "right"),
    [
        param(MaternKernel(), 3.0, id="kernel_times_float"),
        param(MaternKernel(), 5, id="kernel_times_int"),
        param(3.0, MaternKernel(), id="float_times_kernel"),
    ],
)
def test_mul_constant_produces_constant_scale_kernel(left, right):
    """Multiplying a kernel with a numeric constant produces a fixed ScaleKernel."""
    result = left * right
    gpytorch_kernel = result.to_gpytorch()
    initial_outputscale = gpytorch_kernel.outputscale.item()

    assert isinstance(result, ScaleKernel)
    assert result.outputscale_trainable is False
    expected_outputscale = left if isinstance(right, Kernel) else right
    assert initial_outputscale == expected_outputscale
    # Check the already created GPyTorch kernel instead of converting a second
    # time, so the assertion targets the exact object trained below.
    assert not gpytorch_kernel.raw_outputscale.requires_grad

    # Create a dummy input and compute a loss through the kernel to assert training
    # does not affect the output scale
    x = torch.randn(5, 1)
    loss = gpytorch_kernel(x).evaluate().sum()
    loss.backward()
    optimizer = torch.optim.SGD(gpytorch_kernel.parameters(), lr=0.1)
    optimizer.step()
    assert gpytorch_kernel.outputscale.item() == initial_outputscale
Loading