Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 27 additions & 21 deletions baybe/searchspace/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import gc
from collections.abc import Iterable, Sequence
from enum import Enum
from typing import cast
from typing import TYPE_CHECKING, cast

import pandas as pd
from attrs import define, field
Expand All @@ -28,6 +28,9 @@
from baybe.serialization import SerialMixin, converter, select_constructor_hook
from baybe.utils.conversion import to_string

if TYPE_CHECKING:
from baybe.parameters.selectors import ParameterSelectorProtocol


class SearchSpaceType(Enum):
"""Enum class for different types of search spaces and respective compatibility."""
Expand Down Expand Up @@ -287,35 +290,38 @@ def n_tasks(self) -> int:
return 1
return len(task_param.values)

def get_comp_rep_parameter_indices(self, name: str, /) -> tuple[int, ...]:
"""Find a parameter's column indices in the computational representation.
def get_comp_rep_parameter_indices(
self,
name_or_selector: str | ParameterSelectorProtocol,
/,
) -> tuple[int, ...]:
"""Find comp-rep column indices for a parameter selection.

When called with a parameter name, returns the indices for that single
parameter. When called with a
:class:`~baybe.parameters.selectors.ParameterSelectorProtocol`,
returns the combined indices for all matching parameters.

Args:
name: The name of the parameter whose columns indices are to be retrieved.

Raises:
ValueError: If no parameter with the provided name exists.
ValueError: If more than one parameter with the provided name exists.
name_or_selector: Either the name of a single parameter or a selector
that filters parameters to be included.

Returns:
A tuple containing the integer indices of the columns in the computational
representation associated with the parameter. When the parameter is not part
of the computational representation, an empty tuple is returned.
representation associated with the selected parameter(s). When a selected
parameter is not part of the computational representation, it contributes
no indices.
"""
params = self.get_parameters_by_name([name])
if len(params) < 1:
raise ValueError(
f"There exists no parameter named '{name}' in the search space."
)
if len(params) > 1:
raise ValueError(
f"There exist multiple parameter matches for '{name}' in the search "
f"space."
)
p = params[0]
if isinstance(name_or_selector, str):
params: list[Parameter] = [
p for p in self.parameters if p.name == name_or_selector
]
else:
params = [p for p in self.parameters if name_or_selector(p)]

return tuple(
i
for p in params
for i, col in enumerate(self.comp_rep_columns)
if col in p.comp_rep_columns
)
Expand Down
115 changes: 110 additions & 5 deletions baybe/surrogates/gaussian_process/components/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import functools
from abc import ABC, abstractmethod
from collections.abc import Iterable
from functools import partial
Expand All @@ -22,6 +23,7 @@
to_parameter_selector,
)
from baybe.searchspace.core import SearchSpace
from baybe.serialization.mixin import SerialMixin
from baybe.surrogates.gaussian_process.components.generic import (
GPComponentFactoryProtocol,
GPComponentType,
Expand All @@ -44,7 +46,7 @@


@define
class _PureKernelFactory(KernelFactoryProtocol, ABC):
class _PureKernelFactory(KernelFactoryProtocol, SerialMixin, ABC):
"""Base class for pure kernel factories."""

# For internal use only: sanity check mechanism to remind developers of new
Expand Down Expand Up @@ -75,6 +77,14 @@ def get_parameter_names(self, searchspace: SearchSpace) -> tuple[str, ...]:
selector = self.parameter_selector or (lambda _: True)
return tuple(p.name for p in searchspace.parameters if selector(p))

def _get_effective_dimensionality(self, searchspace: SearchSpace) -> int:
"""Get the number of computational columns for the selected parameters."""
# When no parameter selector is configured, fall back to an accept-all
# predicate so that the columns of every parameter are counted.
return len(
searchspace.get_comp_rep_parameter_indices(
self.parameter_selector or (lambda _: True)
)
)

def _validate_parameter_kinds(self, parameters: Iterable[Parameter]) -> None:
"""Validate that the given parameters are supported by the factory.

Expand Down Expand Up @@ -115,6 +125,94 @@ def _make(
"""Construct the kernel."""


def _enable_transfer_learning(
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Scienfitz: in principle ready and working. However, I have to admit that this was significantly more painful than anticipated, with many footguns along the way. So I'm open to a very harsh review and a complete change of direction, if you prefer and have an alternative/simpler idea.

But I hope that you get my intent for this: I think we need some mechanism that lets us say "fill this preset with our default approach for a certain aspect that the preset does not specify", and the filling should be done without copying any code, since the BayBE defaults are expected to move. So we need something like a single source of truth. That said: maybe you have some smarter idea.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only annoying thing is the _NAME / NAME thing, but I don't know a better alternative.

cls: type[_PureKernelFactory], name: str | None = None, /
) -> type[_PureKernelFactory]:
"""Class decorator enabling BayBE's default transfer learning mechanism.

When the search space contains a task parameter, the decorated factory
automatically composes its kernel with BayBE's default task kernel.
Otherwise, the factory behaves unchanged.

When used as a decorator (without ``name``), the class is modified in-place.
When called with a ``name`` argument, a new subclass is created so that the
original class remains unmodified. The latter form is intended for cases where
the original class is reused independently elsewhere.

Args:
cls: The kernel factory class to decorate.
name: Optional name for the created class. If provided, a new subclass is
created instead of modifying ``cls`` in-place.

Raises:
TypeError: If the factory already supports task parameters.

Returns:
The decorated kernel factory class with transfer learning enabled.
"""
if cls._supported_parameter_kinds & _ParameterKind.TASK:
raise TypeError(f"'{cls.__name__}' already supports task parameters.")

# This distinction is important for serialization so that the classes can be
# correctly identified by their names in the subclass registry
if name is not None:
# Create a sibling class so the original class remains unmodified.
# We use cls.__bases__ (not (cls,)) because the new class is conceptually
# an equivalent variant, not a specialization. Concrete (non-dunder)
# attributes are copied so the sibling has the same behavior.
# __module__ must be set explicitly because the Protocol metaclass
# would otherwise default it to "abc".
ns = {
k: v
for k, v in cls.__dict__.items()
if not (k.startswith("__") and k.endswith("__"))
}
ns["__doc__"] = cls.__doc__
ns["__module__"] = cls.__module__
target_cls = type(name, cls.__bases__, ns)
else:
# Modify the class in-place (avoids name collision in subclass registry)
target_cls = cls

original_call = cls.__call__
original_supported_kinds = cls._supported_parameter_kinds
_task_exclude_selector = TypeSelector((TaskParameter,), exclude=True)

@functools.wraps(original_call)
def __call__(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor):
# Temporarily narrow the supported parameter kinds to those of the original
# class. If the decorator logic is correct, the original factory should never
# see the extended scope, but this acts as a sanity check to prevent regressions
broadened_kinds = target_cls._supported_parameter_kinds # type: ignore[attr-defined]
target_cls._supported_parameter_kinds = original_supported_kinds # type: ignore[attr-defined]

# Split off the task parameters
original_selector = self.parameter_selector
if original_selector is None:
self.parameter_selector = _task_exclude_selector
else:
self.parameter_selector = lambda p: (
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TaskParameter will be excluded here if the original parameter_selector contained it. Should we maybe raise a warning?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, maybe I'm looking at this from the wrong perspective, but isn't that what we want to happen?

Example: My search space contains 3 parameters (2 regular and 1 task): (p1, p2, p_task). If I now call a factory that is TL-compatible (which the factories that get decorated de facto are), then two things can happen:

  1. I pass a selector that excludes the task parameter, e.g. filters down to just p1. In this case, I want that the inner logic gets executed, i.e. the execution internally dispatches to the original non-decorated factory.
  2. I pass a selector that includes the task parameter, e.g. selects p1 and task. In this case, task should be split off, the regular factory should be called on p1, and then we assemble everything via the ICM mechanism.

Or am I overlooking something? If you see a problem, a minimal example would be helpful.

_task_exclude_selector(p) and original_selector(p)
)

try:
Comment thread
AdrianSosic marked this conversation as resolved.
base_kernel = original_call(self, searchspace, train_x, train_y)
finally:
target_cls._supported_parameter_kinds = broadened_kinds # type: ignore[attr-defined]
self.parameter_selector = original_selector

if searchspace.task_idx is not None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't foresee 100% where we would later put the logic for TL modes that affect components other than the kernel. But the current version will definitely work well with the dispatching between IndexKernel and PositiveIndexKernel, which will happen within the ICMKernelFactory.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my current naive vision, the decorator would then not just be applied to the kernel factory but to all components that require changes. It would process all of them and apply the necessary changes, e.g.

  • for kernel-TL, it would do exactly what we do now, i.e. only modify the kernel factory
  • for mean-injection-TL, it would not affect the kernel but alter the mean-factory to do the prior-mean stuff
  • ...

However, this is exactly what I meant with "I can't yet foresee exactly", so let's keep the decorator private for now :D But yeah, potentially the decorator logic might become quite a beast, since it not only needs to capture TL stuff across all components, but maybe also multi-fidelity stuff etc. in the end. My current fear is that the complexity might get out of hand...

Any action items at this point I need to take care of, or can I close?

icm = ICMKernelFactory(base_kernel_or_factory=base_kernel)
return icm(searchspace, train_x, train_y)
return base_kernel

target_cls.__call__ = __call__ # type: ignore[method-assign]
target_cls._supported_parameter_kinds = ( # type: ignore[attr-defined]
cls._supported_parameter_kinds | _ParameterKind.TASK
)
return target_cls


@define
class _MetaKernelFactory(KernelFactoryProtocol, ABC):
"""Base class for meta kernel factories that orchestrate other kernel factories."""
Expand Down Expand Up @@ -150,18 +248,25 @@ class ICMKernelFactory(_MetaKernelFactory):
@base_kernel_factory.default
def _default_base_kernel_factory(self) -> KernelFactoryProtocol:
from baybe.surrogates.gaussian_process.presets.baybe import (
BayBENumericalKernelFactory,
_BayBENumericalKernelFactory,
)

return BayBENumericalKernelFactory(TypeSelector((TaskParameter,), exclude=True))
assert (
_BayBENumericalKernelFactory._supported_parameter_kinds
is _ParameterKind.REGULAR
)
return _BayBENumericalKernelFactory(
TypeSelector((TaskParameter,), exclude=True)
)

@task_kernel_factory.default
def _default_task_kernel_factory(self) -> KernelFactoryProtocol:
from baybe.surrogates.gaussian_process.presets.baybe import (
BayBETaskKernelFactory,
_BayBETaskKernelFactory,
)

return BayBETaskKernelFactory(TypeSelector((TaskParameter,)))
assert _BayBETaskKernelFactory._supported_parameter_kinds is _ParameterKind.TASK
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these new asserts in the defaults for mypy or for another purpose?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, not for mypy, but to guarantee that the defaults that are injected here actually are compatible with the contract that the ICM mechanism requires. However, this is probably from the earlier drafting days and should actually become a proper validation step. Would it be fine for you if I turned this into actual validators?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was commenting due to this simple principle: asserts for anything other than pytest or mypy type-narrowing should be avoided. In the latter case, they'd ideally get a comment so we don't have to always redo this kind of thread here.

return _BayBETaskKernelFactory()

@override
def __call__(
Expand Down
39 changes: 13 additions & 26 deletions baybe/surrogates/gaussian_process/presets/baybe.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,39 +26,24 @@
from baybe.surrogates.gaussian_process.presets.edbo_smoothed import (
SmoothedEDBOKernelFactory,
SmoothedEDBOLikelihoodFactory,
_SmoothedEDBONumericalKernelFactory,
)

if TYPE_CHECKING:
from torch import Tensor


@define
class BayBEKernelFactory(_PureKernelFactory):
"""The default kernel factory for Gaussian process surrogates."""

_supported_parameter_kinds: ClassVar[_ParameterKind] = (
_ParameterKind.REGULAR | _ParameterKind.TASK
)
# See base class.

@override
def _make(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> Kernel:
from baybe.surrogates.gaussian_process.components.kernel import ICMKernelFactory
class _BayBENumericalKernelFactory(_SmoothedEDBONumericalKernelFactory):
"""The default numerical kernel factory for GP surrogates."""

is_multitask = searchspace.task_idx is not None
factory = ICMKernelFactory if is_multitask else BayBENumericalKernelFactory
return factory()(searchspace, train_x, train_y)


BayBENumericalKernelFactory = SmoothedEDBOKernelFactory
"""The factory providing the default numerical kernel for Gaussian process surrogates.""" # noqa: E501
class BayBEKernelFactory(SmoothedEDBOKernelFactory): # type: ignore[valid-type, misc]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I roughly get the need for this (horrendous-feeling) _NAME / NAME split for some of the objects (for which I couldn't come up with a better idea either).

But could you sum this up in 1-2 sentences in the PR description? It seems the commits also contain explanations, but these apparently changed over the course of the commits and they don't really get the point across very nicely.

"""The default kernel factory for GP surrogates."""


@define
class BayBETaskKernelFactory(_PureKernelFactory):
"""The factory providing the default task kernel for Gaussian process surrogates."""
class _BayBETaskKernelFactory(_PureKernelFactory):
"""The default task kernel factory for GP surrogates."""

_uses_parameter_names: ClassVar[bool] = True
# See base class.
Expand All @@ -83,11 +68,13 @@ def _make(
)


BayBEMeanFactory = LazyConstantMeanFactory
"""The factory providing the default mean function for Gaussian process surrogates."""
class BayBEMeanFactory(LazyConstantMeanFactory):
"""The default mean factory for GP surrogates."""


class BayBELikelihoodFactory(SmoothedEDBOLikelihoodFactory):
"""The default likelihood factory for GP surrogates."""

BayBELikelihoodFactory = SmoothedEDBOLikelihoodFactory
"""The factory providing the default likelihood for Gaussian process surrogates."""


@define
Expand Down
19 changes: 5 additions & 14 deletions baybe/surrogates/gaussian_process/presets/chen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,17 @@
import math
from typing import TYPE_CHECKING, ClassVar

from attrs import define, field
from attrs import define
from typing_extensions import override

from baybe.kernels.basic import MaternKernel
from baybe.kernels.composite import ScaleKernel
from baybe.parameters.categorical import TaskParameter
from baybe.parameters.selectors import (
ParameterSelectorProtocol,
TypeSelector,
to_parameter_selector,
)
from baybe.priors.basic import GammaPrior
from baybe.surrogates.gaussian_process.components.fit_criterion import (
_MLLForNonTLFitCriterionFactory,
)
from baybe.surrogates.gaussian_process.components.kernel import (
_enable_transfer_learning,
_PureKernelFactory,
)
from baybe.surrogates.gaussian_process.components.likelihood import (
Expand All @@ -36,24 +31,20 @@
from baybe.searchspace.core import SearchSpace


@_enable_transfer_learning
@define
class CHENKernelFactory(_PureKernelFactory):
"""A factory providing adaptive hyperprior kernels as proposed by :cite:p:`Chen2026`.""" # noqa: E501

_uses_parameter_names: ClassVar[bool] = True
# See base class.
Comment thread
AdrianSosic marked this conversation as resolved.

parameter_selector: ParameterSelectorProtocol | None = field(
factory=lambda: TypeSelector([TaskParameter], exclude=True),
converter=to_parameter_selector,
)
# TODO: Reuse base attribute (https://github.com/python-attrs/attrs/pull/1429)

@override
def _make(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> Kernel:
lengthscale = 0.4 * math.sqrt(train_x.shape[-1]) + 4.0
n_dimensions = self._get_effective_dimensionality(searchspace)
lengthscale = 0.4 * math.sqrt(n_dimensions) + 4.0
lengthscale_prior = GammaPrior(2.0 * lengthscale, 2.0)
lengthscale_initial_value = lengthscale
outputscale_prior = GammaPrior(1.0 * lengthscale, 1.0)
Expand Down
Loading
Loading