Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
148 commits
Select commit Hold shift + click to select a range
2c1da48
Add transfer learning test
AdrianSosic Feb 10, 2026
cefd5d6
Deduplicate task parameter logic in search space
AdrianSosic Feb 10, 2026
d23945d
Clean up Gaussian process class
AdrianSosic Feb 10, 2026
9092b90
GP Cleanup (#743)
AdrianSosic Feb 10, 2026
5e8c7de
Extract current GP defaults into separate EDBO module
AdrianSosic Feb 10, 2026
6b48716
Add EDBO presets
AdrianSosic Feb 10, 2026
577abd2
Introduce generic component factories
AdrianSosic Feb 10, 2026
19220bd
Make serialization utilities handle generics
AdrianSosic Feb 10, 2026
b0cd172
Correctly (un)structure generic classes
AdrianSosic Feb 10, 2026
50e531e
Add execution path for non-generic classes
AdrianSosic Feb 11, 2026
23a71a7
Add support for GPyTorch kernels
AdrianSosic Feb 11, 2026
f669a21
Block serialization of GPyTorch kernels
AdrianSosic Feb 11, 2026
8234082
Enable configuration of GP mean
AdrianSosic Feb 11, 2026
e15e243
Enable configuration of GP likelihood
AdrianSosic Feb 11, 2026
108e082
Complete the current preset framework
AdrianSosic Feb 11, 2026
50d0a2a
Update CHANGELOG.md
AdrianSosic Feb 11, 2026
aeb5d8a
Update attribute docstrings
AdrianSosic Feb 11, 2026
ffb8702
Reorganize modules into subpackage
AdrianSosic Feb 11, 2026
d908fcb
Add missing components to TypeVar
AdrianSosic Feb 11, 2026
9df998f
Add missing entries to __init__.py
AdrianSosic Feb 11, 2026
7a509d4
Deduplicate code in default preset
AdrianSosic Feb 11, 2026
b1f6215
Add execution path for non-generic classes
AdrianSosic Feb 11, 2026
0056011
Fix typing issues
AdrianSosic Feb 11, 2026
45d0486
Fix assert_never import
AdrianSosic Feb 11, 2026
ee4cde1
Simplify preset loading using name aliases
AdrianSosic Feb 17, 2026
a40f1f8
Generalize converter and tests from kernels to arbitrary components
AdrianSosic Feb 25, 2026
cb847a2
Rename utils.py to factories.py
AdrianSosic Feb 25, 2026
1e913ec
Fix comments
AdrianSosic Feb 25, 2026
0ed6865
Move logic from make_gp_from_preset to from_preset classmethod
AdrianSosic Feb 25, 2026
87496b2
Enable preset selection via string
AdrianSosic Feb 25, 2026
d1789bd
Rename default.py preset to baybe.py
AdrianSosic Feb 25, 2026
22de289
Enable preset default overrides
AdrianSosic Feb 25, 2026
eb3d8d8
Add GP preset test
AdrianSosic Feb 25, 2026
c89d156
Fix imports in __init__.py
AdrianSosic Feb 25, 2026
b4d21d2
Fix typos in named imports
AdrianSosic Feb 25, 2026
763efe3
Fix preset name in changelog
AdrianSosic Feb 25, 2026
70481a6
Update import paths in docstrings
AdrianSosic Feb 25, 2026
438a196
Silence sphinx warnings
AdrianSosic Mar 2, 2026
e1867d0
Fix outdated reference in docstring
AdrianSosic Mar 2, 2026
e43d885
Assert component overrides in preset test
AdrianSosic Mar 3, 2026
aaae262
Test preset loading also without overrides
AdrianSosic Mar 3, 2026
3dba696
Add missing aliases to BAYBE preset module
AdrianSosic Mar 3, 2026
c9e5c2b
Fix mypy errors by using direct imports
AdrianSosic Mar 3, 2026
a407a87
Update CHANGELOG.md
AdrianSosic Mar 4, 2026
8a7b465
Add missing type annotations
AdrianSosic Mar 4, 2026
ff2ef6a
Test GP fitting for overridden presets
AdrianSosic Mar 4, 2026
697d87e
Fix outdated test name
AdrianSosic Mar 4, 2026
1eab7ce
Add GP to component factory names
AdrianSosic Mar 5, 2026
6b49c9f
Move LazyConstantMeanFactory to components/mean.py
AdrianSosic Mar 5, 2026
7bd2c2c
Add mean and likelihood to GP string representation
AdrianSosic Mar 5, 2026
27788de
Add BayBE and GPyTorch component aliases
AdrianSosic Mar 5, 2026
b091d75
Simplify unstructuring logic
AdrianSosic Mar 5, 2026
462cb9c
Fix typo in comment
AdrianSosic Mar 5, 2026
13831b1
Add validation of GP components
AdrianSosic Mar 6, 2026
b9bca21
Silence mypy warnings
AdrianSosic Mar 6, 2026
37b4f44
Add enum member docstrings
AdrianSosic Mar 6, 2026
aef007b
Rename default preset factories to BayBE factories
AdrianSosic Mar 6, 2026
57a02e5
Add more detailed comments on unstructuring with generics
AdrianSosic Mar 6, 2026
8732a0f
Add generic GP components (#746)
AdrianSosic Mar 6, 2026
5684d3e
Add IndexKernel class
AdrianSosic Feb 11, 2026
bd4390d
Add hypothesis strategy for IndexKernel
AdrianSosic Feb 11, 2026
2da7d3f
Add PositiveIndexKernel class
AdrianSosic Mar 3, 2026
82db96d
Add positive index kernel path to hypothesis strategy
AdrianSosic Mar 3, 2026
5b8a1a6
Fix active_dims argument
AdrianSosic Mar 3, 2026
1455365
Move import statement to except block
AdrianSosic Mar 4, 2026
59be817
Fix inequality in error message
AdrianSosic Mar 4, 2026
6840dc3
Add temporary test workaround for IndexKernel classes
AdrianSosic Mar 6, 2026
af70d31
Index Kernel (#747)
AdrianSosic Mar 6, 2026
b84af1d
Absorb index kernel construction into ICMKernelFactory
AdrianSosic Feb 12, 2026
e365862
Add parameter selectors
AdrianSosic Feb 12, 2026
dc9ee90
Rename protocols
AdrianSosic Feb 12, 2026
63e64fa
Implement active kernel dimension control
AdrianSosic Feb 12, 2026
da8ffc0
Move parameter_names attribute down to BasicKernel subclass
AdrianSosic Feb 12, 2026
fb56269
Rename AdditiveKernel to SumKernel with deprecation shim
Scienfitz Mar 11, 2026
3436518
Add outputscale_trainable flag to ScaleKernel
Scienfitz Mar 11, 2026
c4be330
Add kernel arithmetic operators without flattening
Scienfitz Mar 11, 2026
7dc03ee
Add flattening to kernel arithmetic operators
Scienfitz Mar 11, 2026
4c6a3d9
Add tests for kernel arithmetic operators and deprecation
Scienfitz Mar 11, 2026
f74e79a
Rewrite test kernel compositions using arithmetic operators
Scienfitz Mar 11, 2026
b28e256
Fix deprecation serialization
Scienfitz Mar 12, 2026
305cb04
Update CHANGELOG
Scienfitz Mar 11, 2026
4766696
Simplify test
Scienfitz Mar 12, 2026
7bcc328
Add test to assure frozen output scale is not trained
Scienfitz Mar 12, 2026
d18d87c
Update tests/test_kernels.py
Scienfitz Mar 12, 2026
cfd4362
Enable builtin `sum` with kernels
Scienfitz Mar 18, 2026
082f0c3
Make docstring more precise
Scienfitz Mar 30, 2026
18c07b0
Improve docstring
Scienfitz Mar 30, 2026
65c7d75
Revert SumKernel rename
Scienfitz Mar 30, 2026
7907ca8
Enable builtin math.prod() with kernels
Scienfitz Mar 30, 2026
808a14e
Expand kernel operator tests with sum/prod checks
Scienfitz Mar 30, 2026
8947afc
Remove unimportant tests
Scienfitz Mar 30, 2026
92d53c6
Avoid ScaleKernel wrapping for multiplication with 1
AdrianSosic Mar 31, 2026
86e110b
Fix condition in BayBEKernelFactory
AdrianSosic Feb 12, 2026
3714c5f
Add deprecation mechanism for breaking change in kernel logic
AdrianSosic Feb 12, 2026
66b3dfb
Import KernelFactory to components/__init__.py
AdrianSosic Mar 3, 2026
d91623f
Add citation to docstring
AdrianSosic Mar 4, 2026
91bd654
Fix typing
AdrianSosic Mar 4, 2026
02bfc97
Add parameter selection to kernel hypothesis strategies
AdrianSosic Mar 4, 2026
ca37ecf
Drop batch_shape argument from Kernel.to_gpytorch
AdrianSosic Apr 1, 2026
cec1e5d
Update kernel assembly test
AdrianSosic Apr 1, 2026
daaa8f3
Refactor handling of constructor-only attributes in test
AdrianSosic Apr 1, 2026
90d4641
Fix logic of custom kernel converter helper
AdrianSosic Apr 1, 2026
2dd1316
Validate that GP component factories are callable
AdrianSosic Apr 1, 2026
118e019
Update CHANGELOG.md
AdrianSosic Mar 4, 2026
a769a97
Drop unnecessary condition in __rmul__
AdrianSosic Apr 2, 2026
aa07e2a
Simplify converter logic
AdrianSosic Apr 10, 2026
70a4a3e
Make ParameterSelector class abstract
AdrianSosic Apr 10, 2026
eee1504
Fix test parametrization
AdrianSosic Apr 10, 2026
d486c84
Rename selector.py to selectors.py
AdrianSosic Apr 10, 2026
8480607
Fix variable reference in kernel translation test
AdrianSosic Apr 10, 2026
7cd2321
Use chain.from_iterable instead of unpacking
AdrianSosic Apr 10, 2026
9cc4647
Replace hard-coded parameter type in deprecation error message
AdrianSosic Apr 10, 2026
9c5cf18
Add comment explaining the role of active_dims=None
AdrianSosic Apr 10, 2026
d13e365
Fix keyword logic in Kernel.to_gpytorch
AdrianSosic Apr 10, 2026
1707e3f
Add NameSelector class
AdrianSosic Apr 10, 2026
c859d85
Refine NameSelector.parameter_names field specification
AdrianSosic Apr 10, 2026
3a1c421
Make KernelFactory class abstract
AdrianSosic Apr 10, 2026
66b6f06
Remove unnecessary keyword assignment
AdrianSosic Apr 10, 2026
4edf552
Add converters to ICMKernelFactory
AdrianSosic Apr 10, 2026
a1bc442
Revise comment on kernel attribute matching
AdrianSosic Apr 10, 2026
4811dae
Fix converter specification
AdrianSosic Apr 10, 2026
4347c76
Add convenience converter for parameter selectors
AdrianSosic Apr 10, 2026
88468a1
Turn KernelFactory into mixin class for parameter selection
AdrianSosic Apr 10, 2026
2d414a2
Override default selectors for existing factories
AdrianSosic Apr 10, 2026
1b248a5
Turn assert statement into proper AssertionError
AdrianSosic Apr 10, 2026
fc88c0a
Add Convenience Kernel Arithmetic (#763)
AdrianSosic Apr 13, 2026
f18cd32
Use regex by default for NameSelector
AdrianSosic Apr 13, 2026
a6d611e
Rename parameter_types attribute of TypeSelector to types
AdrianSosic Apr 13, 2026
55582fc
Rephrase changelog entry
AdrianSosic Apr 13, 2026
54a5674
Merge branch 'dev/gp' into refactor/multi_task
AdrianSosic Apr 13, 2026
20f106b
Fix typing issues
AdrianSosic Apr 13, 2026
3c38d34
Add ignore to docs/conf.py
AdrianSosic Apr 13, 2026
f943eeb
Add missing searchspace argument
AdrianSosic Apr 13, 2026
9ef2604
Kernel dimension control (#748)
AdrianSosic Apr 13, 2026
59e9697
Move _ParameterSelectorMixin to components/kernel.py
AdrianSosic Apr 15, 2026
bcc51bb
Turn mixin class back into ABC
AdrianSosic Apr 15, 2026
9c6051c
Add ParameterKind Flag enum and Parameter.kind property
AdrianSosic Apr 2, 2026
447cf6a
Wire parameter kind validation into kernel factories via template method
AdrianSosic Apr 15, 2026
5845f0f
Remove dead TaskParameter filtering in kernel factories
AdrianSosic Apr 2, 2026
741997b
Update CHANGELOG.md
AdrianSosic Apr 2, 2026
762b8b9
Add parameter support test
AdrianSosic Apr 16, 2026
d6f3128
Make ICMKernelFactory gpytorch-compatible
AdrianSosic Apr 16, 2026
6ee7b70
Move __attrs_post_init__ to the right place
AdrianSosic Apr 22, 2026
fad7525
Rename _KernelFactory to _PureKernelFactory
AdrianSosic Apr 24, 2026
5813085
Add intermediate _MetaKernelFactory base class
AdrianSosic Apr 24, 2026
6161097
Use regex for private classes in conf.py nitpick ignore
AdrianSosic Apr 24, 2026
c92d3c7
Make ParameterKind machinery private
AdrianSosic Apr 24, 2026
07d5d16
Parameter validation for kernel factories (#776)
AdrianSosic Apr 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,43 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Settings option for random seed control
- `identify_non_dominated_configurations` method to `Campaign` and `Objective`
for determining the Pareto front
- Gaussian process component factories
- Support for GPyTorch objects (kernels, means, likelihood) as Gaussian process
components, enabling full low-level customization
- Factories for all Gaussian process components
- `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate`
- `TypeSelector` and `NameSelector` classes for parameter selection in kernel factories
- `parameter_names` attribute to basic kernels for controlling the considered parameters
- `ParameterKind` flag enum for classifying parameters by their role and automatic
parameter kind validation in kernel factories
- `IndexKernel` and `PositiveIndexKernel` classes
- Interpoint constraints for continuous search spaces
- Addition and multiplication operators for kernel objects, enabling kernel
composition via `+` (sum) and `*` (product), as well as `constant * kernel`
for creating a `ScaleKernel` with a fixed output scale

### Breaking Changes
- `ContinuousLinearConstraint.to_botorch` now returns a collection of constraint tuples
instead of a single tuple (needed for interpoint constraints)
- `Kernel.to_gpytorch` now takes a `SearchSpace` instead of explicit `ard_num_dims`,
`batch_shape` and `active_dims` arguments, as kernels now automatically adjust this
configuration to the given search space
- `GaussianProcessSurrogate` no longer automatically adds a task kernel in multi-task
scenarios. Custom kernel architectures must now explicitly include the task kernel,
e.g. via `ICMKernelFactory`

### Removed
- `parallel_runs` argument from `simulate_scenarios`, since parallelization
can now be conveniently controlled via the new `Settings` mechanism
- `make_gp_from_preset` utility function, since the same functionality is offered by
`GaussianProcessSurrogate.from_preset`

### Deprecations
- Using a custom kernel with `GaussianProcessSurrogate` in a multi-task context now
raises a `DeprecationError` to alert users about the changed kernel logic. This can
be suppressed by setting the `BAYBE_DISABLE_CUSTOM_KERNEL_WARNING` environment
variable to a truthy value
- `set_random_seed` and `temporary_seed` utility functions
- The environment variables
`BAYBE_NUMPY_USE_SINGLE_PRECISION`/`BAYBE_TORCH_USE_SINGLE_PRECISION` have been
Expand Down
186 changes: 155 additions & 31 deletions baybe/kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,54 +3,123 @@
from __future__ import annotations

import gc
from abc import ABC
from abc import ABC, abstractmethod
from itertools import chain
from typing import TYPE_CHECKING, Any

from attrs import define
from attrs import define, field
from attrs.converters import optional as optional_c
from attrs.validators import deep_iterable, instance_of
from attrs.validators import optional as optional_v
from typing_extensions import override

from baybe.exceptions import UnmatchedAttributeError
from baybe.priors.base import Prior
from baybe.searchspace.core import SearchSpace
from baybe.serialization.mixin import SerialMixin
from baybe.settings import active_settings
from baybe.utils.basic import get_baseclasses, match_attributes
from baybe.utils.basic import classproperty, get_baseclasses, match_attributes, to_tuple

if TYPE_CHECKING:
import torch

from baybe.surrogates.gaussian_process.kernel_factory import PlainKernelFactory
from baybe.surrogates.gaussian_process.components.kernel import PlainKernelFactory


@define(frozen=True)
class Kernel(ABC, SerialMixin):
"""Abstract base class for all kernels."""

def __add__(self, other: Any) -> Kernel:
    """Create a sum kernel from two kernels.

    Nested sums are flattened: ``(a + b) + c`` produces
    ``AdditiveKernel([a, b, c])`` rather than an ``AdditiveKernel``
    nested inside another ``AdditiveKernel``.
    """
    if not isinstance(other, Kernel):
        return NotImplemented

    from baybe.kernels.composite import AdditiveKernel

    # Unwrap existing sums on either side so the result stays flat.
    def summands(kernel: Kernel) -> tuple[Kernel, ...]:
        if isinstance(kernel, AdditiveKernel):
            return tuple(kernel.base_kernels)
        return (kernel,)

    return AdditiveKernel([*summands(self), *summands(other)])

def __radd__(self, other: Any) -> Kernel:
    """Support right-hand addition for kernel objects."""
    # The built-in sum() starts its accumulation with 0, so treat
    # ``0 + kernel`` as the kernel itself.
    if other == 0:
        return self
    return self.__add__(other) if isinstance(other, Kernel) else NotImplemented

def __mul__(self, other: Any) -> Kernel:
    """Create a product kernel or scale kernel.

    Multiplying two kernels yields a product kernel, with nested
    products flattened so that ``(a * b) * c`` produces
    ``ProductKernel([a, b, c])``. Multiplying by a numeric constant
    yields a scale kernel whose output scale is fixed (non-trainable)
    at that constant.
    """
    if isinstance(other, Kernel):
        from baybe.kernels.composite import ProductKernel

        # Unwrap existing products on either side so the result stays flat.
        factors: list[Kernel] = []
        for kernel in (self, other):
            if isinstance(kernel, ProductKernel):
                factors.extend(kernel.base_kernels)
            else:
                factors.append(kernel)
        return ProductKernel(factors)

    if isinstance(other, (int, float)):
        # Multiplication by 1 is a no-op, so avoid a redundant wrapper.
        if other == 1:
            return self

        from baybe.kernels.composite import ScaleKernel

        return ScaleKernel(
            base_kernel=self,
            outputscale_initial_value=float(other),
            outputscale_trainable=False,
        )

    return NotImplemented

def __rmul__(self, other: Any) -> Kernel:
    """Support right-hand multiplication, enabling ``constant * kernel``."""
    # Delegates to __mul__, which already handles numeric constants.
    # This also enables use with math.prod(), which starts its
    # accumulation with 1 * first_element (a no-op in __mul__).
    return self.__mul__(other)

@classproperty
def _whitelisted_attributes(cls) -> frozenset[str]:
    """Attribute names to exclude from gpytorch attribute matching.

    Subclasses override this to exempt BayBE-only attributes (e.g.
    ``parameter_names``) that have no counterpart in the gpytorch
    kernel constructor signature.
    """
    return frozenset()

def to_factory(self) -> PlainKernelFactory:
"""Wrap the kernel in a :class:`baybe.surrogates.gaussian_process.kernel_factory.PlainKernelFactory`.""" # noqa: E501
from baybe.surrogates.gaussian_process.kernel_factory import PlainKernelFactory
"""Wrap the kernel in a :class:`baybe.surrogates.gaussian_process.components.PlainKernelFactory`.""" # noqa: E501
from baybe.surrogates.gaussian_process.components.kernel import (
PlainKernelFactory,
)

return PlainKernelFactory(self)

def to_gpytorch(
self,
*,
ard_num_dims: int | None = None,
batch_shape: torch.Size | None = None,
active_dims: tuple[int, ...] | None = None,
):
@abstractmethod
def _get_dimensions(
    self, searchspace: SearchSpace
) -> tuple[tuple[int, ...] | None, int | None]:
    """Get the active dimensions and the number of ARD dimensions.

    Args:
        searchspace: The search space the kernel is applied to.

    Returns:
        A ``(active_dims, ard_num_dims)`` tuple. ``None`` entries leave
        the corresponding gpytorch constructor default in place.
    """

def to_gpytorch(self, searchspace: SearchSpace):
"""Create the gpytorch representation of the kernel."""
import gpytorch.kernels

# Extract keywords with non-default values. This is required since gpytorch
# makes use of kwargs, i.e. differentiates if certain keywords are explicitly
# passed or not. For instance, `ard_num_dims = kwargs.get("ard_num_dims", 1)`
# fails if we explicitly pass `ard_num_dims=None`.
kw: dict[str, Any] = dict(
ard_num_dims=ard_num_dims, batch_shape=batch_shape, active_dims=active_dims
)
kw = {k: v for k, v in kw.items() if v is not None}
active_dims, ard_num_dims = self._get_dimensions(searchspace)

# Get corresponding gpytorch kernel class and its base classes
kernel_cls = getattr(gpytorch.kernels, self.__class__.__name__)
try:
kernel_cls = getattr(gpytorch.kernels, self.__class__.__name__)
except AttributeError:
import botorch.models.kernels.positive_index

kernel_cls = getattr(
botorch.models.kernels.positive_index, self.__class__.__name__
)
base_classes = get_baseclasses(kernel_cls, abstract=True)

# Fetch the necessary gpytorch constructor parameters of the kernel.
Expand All @@ -66,11 +135,17 @@ def to_gpytorch(
unmatched_attrs.update(unmatched)

# Sanity check: all attributes of the BayBE kernel need a corresponding match
# in the gpytorch kernel (otherwise, the BayBE kernel class is misconfigured).
# Exception: initial values are not used during construction but are set
# on the created object (see code at the end of the method).
missing = set(unmatched) - set(kernel_attrs)
if leftover := {m for m in missing if not m.endswith("_initial_value")}:
# with the gpytorch kernel signature (otherwise, the BayBE kernel class is
# misconfigured). Exceptions: initial values and trainability flags are not used
# during construction but are set on the created object after construction.
missing = (
set(unmatched_attrs) - set(kernel_attrs) - self._whitelisted_attributes
)
if leftover := {
m
for m in missing
if not m.endswith("_initial_value") and not m.endswith("_trainable")
}:
raise UnmatchedAttributeError(leftover)

# Convert specified priors to gpytorch, if provided
Expand All @@ -82,15 +157,17 @@ def to_gpytorch(

# Convert specified inner kernels to gpytorch, if provided
kernel_dict = {
key: value.to_gpytorch(**kw)
key: value.to_gpytorch(searchspace)
for key, value in kernel_attrs.items()
if isinstance(value, Kernel)
}

# Create the kernel with all its inner gpytorch objects
kernel_attrs.update(kernel_dict)
kernel_attrs.update(prior_dict)
gpytorch_kernel = kernel_cls(**kernel_attrs, **kw)
gpytorch_kernel = kernel_cls(
**kernel_attrs, ard_num_dims=ard_num_dims, active_dims=active_dims
)

# If the kernel has a lengthscale, set its initial value
if kernel_cls.has_lengthscale:
Expand All @@ -113,11 +190,58 @@ def to_gpytorch(
class BasicKernel(Kernel, ABC):
    """Abstract base class for all basic kernels."""

    parameter_names: tuple[str, ...] | None = field(
        default=None,
        converter=optional_c(to_tuple),
        validator=optional_v(
            deep_iterable(
                iterable_validator=instance_of(tuple),
                member_validator=instance_of(str),
            )
        ),
        kw_only=True,
    )
    """An optional collection of names specifying the parameters the kernel should
    act on. If ``None``, the kernel acts on all parameters of the search space."""

    @override
    @classproperty
    def _whitelisted_attributes(cls) -> frozenset[str]:
        # `parameter_names` is a BayBE-only concept with no counterpart in the
        # gpytorch kernel signature, so exclude it from attribute matching.
        return frozenset({"parameter_names"})

    @override
    def _get_dimensions(
        self, searchspace: SearchSpace
    ) -> tuple[tuple[int, ...] | None, int | None]:
        if self.parameter_names is None:
            # `None` is gpytorch's default indicating that all dimensions are active
            active_dims = None
        else:
            # Map each selected parameter name to the indices of its columns in
            # the computational representation of the search space.
            active_dims = tuple(
                chain.from_iterable(
                    searchspace.get_comp_rep_parameter_indices(name)
                    for name in self.parameter_names
                )
            )

        # We use automatic relevance determination for all kernels
        ard_num_dims = (
            len(active_dims)
            if active_dims is not None
            else len(searchspace.comp_rep_columns)
        )
        return active_dims, ard_num_dims


@define(frozen=True)
class CompositeKernel(Kernel, ABC):
    """Abstract base class for all composite kernels."""

    @override
    def _get_dimensions(
        self, searchspace: SearchSpace
    ) -> tuple[tuple[int, ...] | None, int | None]:
        # Composite kernels impose no dimension restrictions of their own: the
        # `None` values leave gpytorch's defaults in place, and dimension
        # control is handled by the inner kernels when they are converted.
        return None, None


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
24 changes: 24 additions & 0 deletions baybe/kernels/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,5 +215,29 @@ class RQKernel(BasicKernel):
"""An optional initial value for the kernel lengthscale."""


@define(frozen=True)
class IndexKernel(BasicKernel):
    """An index kernel for transfer learning across tasks."""

    num_tasks: int = field(validator=[instance_of(int), ge(2)])
    """The number of tasks."""

    rank: int = field(validator=[instance_of(int), ge(1)])
    """The rank of the task covariance matrix."""

    @rank.validator
    def _validate_rank(self, _, rank: int):
        # The rank may equal the number of tasks (full-rank covariance) but
        # must not exceed it, hence the strict comparison below. The error
        # message is phrased accordingly ("not exceed" rather than "smaller
        # than", which would wrongly suggest rank == num_tasks is invalid).
        if rank > self.num_tasks:
            raise ValueError(
                f"The rank of the task covariance matrix must not exceed "
                f"the number of tasks. Got rank {rank} > {self.num_tasks} tasks."
            )


@define(frozen=True)
class PositiveIndexKernel(IndexKernel):
    """A positive index kernel for transfer learning across tasks.

    NOTE(review): appears to map to botorch's ``positive_index`` kernel module
    during gpytorch conversion (see the fallback lookup in
    ``Kernel.to_gpytorch``) — confirm the exact target class.
    """


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
8 changes: 8 additions & 0 deletions baybe/kernels/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ class ScaleKernel(CompositeKernel):
)
"""An optional initial value for the output scale."""

outputscale_trainable: bool = field(default=True, validator=instance_of(bool))
"""Boolean flag indicating whether the output scale is trainable.

If ``False``, the output scale is frozen at its initial value and excluded from
optimization."""

@override
def to_gpytorch(self, *args, **kwargs):
import torch
Expand All @@ -45,6 +51,8 @@ def to_gpytorch(self, *args, **kwargs):
gpytorch_kernel.outputscale = torch.tensor(
initial_value, dtype=active_settings.DTypeFloatTorch
)
if not self.outputscale_trainable:
gpytorch_kernel.raw_outputscale.requires_grad_(False)
return gpytorch_kernel


Expand Down
8 changes: 8 additions & 0 deletions baybe/parameters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from baybe.utils.metadata import MeasurableMetadata, to_metadata

if TYPE_CHECKING:
from baybe.parameters.enum import _ParameterKind
from baybe.searchspace.continuous import SubspaceContinuous
from baybe.searchspace.core import SearchSpace
from baybe.searchspace.discrete import SubspaceDiscrete
Expand Down Expand Up @@ -77,6 +78,13 @@ def is_discrete(self) -> bool:
"""Boolean indicating if this is a discrete parameter."""
return isinstance(self, DiscreteParameter)

@property
def _kind(self) -> _ParameterKind:
    """The kind of the parameter."""
    # Local import mirrors the TYPE_CHECKING-only import at module level,
    # presumably to avoid a circular import at load time — verify.
    from baybe.parameters.enum import _ParameterKind

    return _ParameterKind.from_parameter(self)

@property
@abstractmethod
def comp_rep_columns(self) -> tuple[str, ...]:
Expand Down
Loading
Loading