-
Notifications
You must be signed in to change notification settings - Fork 77
Transfer Learning Decorator #790
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev/gp
Are you sure you want to change the base?
Changes from all commits
0611ff6
43d1e8e
542029b
e41a78d
ded217f
05b89b2
21ed537
86ff984
41a1fc7
f535cde
4833819
4d75d5a
1ccad46
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import functools | ||
| from abc import ABC, abstractmethod | ||
| from collections.abc import Iterable | ||
| from functools import partial | ||
|
|
@@ -22,6 +23,7 @@ | |
| to_parameter_selector, | ||
| ) | ||
| from baybe.searchspace.core import SearchSpace | ||
| from baybe.serialization.mixin import SerialMixin | ||
| from baybe.surrogates.gaussian_process.components.generic import ( | ||
| GPComponentFactoryProtocol, | ||
| GPComponentType, | ||
|
|
@@ -44,7 +46,7 @@ | |
|
|
||
|
|
||
| @define | ||
| class _PureKernelFactory(KernelFactoryProtocol, ABC): | ||
| class _PureKernelFactory(KernelFactoryProtocol, SerialMixin, ABC): | ||
| """Base class for pure kernel factories.""" | ||
|
|
||
| # For internal use only: sanity check mechanism to remind developers of new | ||
|
|
@@ -75,6 +77,14 @@ def get_parameter_names(self, searchspace: SearchSpace) -> tuple[str, ...]: | |
| selector = self.parameter_selector or (lambda _: True) | ||
| return tuple(p.name for p in searchspace.parameters if selector(p)) | ||
|
|
||
| def _get_effective_dimensionality(self, searchspace: SearchSpace) -> int: | ||
| """Get the number of computational columns for the selected parameters.""" | ||
| return len( | ||
| searchspace.get_comp_rep_parameter_indices( | ||
| self.parameter_selector or (lambda _: True) | ||
| ) | ||
| ) | ||
|
|
||
| def _validate_parameter_kinds(self, parameters: Iterable[Parameter]) -> None: | ||
| """Validate that the given parameters are supported by the factory. | ||
|
|
||
|
|
@@ -115,6 +125,94 @@ def _make( | |
| """Construct the kernel.""" | ||
|
|
||
|
|
||
| def _enable_transfer_learning( | ||
| cls: type[_PureKernelFactory], name: str | None = None, / | ||
| ) -> type[_PureKernelFactory]: | ||
| """Class decorator enabling BayBE's default transfer learning mechanism. | ||
|
|
||
| When the search space contains a task parameter, the decorated factory | ||
| automatically composes its kernel with BayBE's default task kernel. | ||
| Otherwise, the factory behaves unchanged. | ||
|
|
||
| When used as a decorator (without ``name``), the class is modified in-place. | ||
| When called with a ``name`` argument, a new subclass is created so that the | ||
| original class remains unmodified. The latter form is intended for cases where | ||
| the original class is reused independently elsewhere. | ||
|
|
||
| Args: | ||
| cls: The kernel factory class to decorate. | ||
| name: Optional name for the created class. If provided, a new subclass is | ||
| created instead of modifying ``cls`` in-place. | ||
|
|
||
| Raises: | ||
| TypeError: If the factory already supports task parameters. | ||
|
|
||
| Returns: | ||
| The decorated kernel factory class with transfer learning enabled. | ||
| """ | ||
| if cls._supported_parameter_kinds & _ParameterKind.TASK: | ||
| raise TypeError(f"'{cls.__name__}' already supports task parameters.") | ||
|
|
||
| # This distinction is important for serialization so that the classes can be | ||
| # correctly identified by their names in the subclass registry | ||
| if name is not None: | ||
| # Create a sibling class so the original class remains unmodified. | ||
| # We use cls.__bases__ (not (cls,)) because the new class is conceptually | ||
| # an equivalent variant, not a specialization. Concrete (non-dunder) | ||
| # attributes are copied so the sibling has the same behavior. | ||
| # __module__ must be set explicitly because the Protocol metaclass | ||
| # would otherwise default it to "abc". | ||
| ns = { | ||
| k: v | ||
| for k, v in cls.__dict__.items() | ||
| if not (k.startswith("__") and k.endswith("__")) | ||
| } | ||
| ns["__doc__"] = cls.__doc__ | ||
| ns["__module__"] = cls.__module__ | ||
| target_cls = type(name, cls.__bases__, ns) | ||
| else: | ||
| # Modify the class in-place (avoids name collision in subclass registry) | ||
| target_cls = cls | ||
|
|
||
| original_call = cls.__call__ | ||
| original_supported_kinds = cls._supported_parameter_kinds | ||
| _task_exclude_selector = TypeSelector((TaskParameter,), exclude=True) | ||
|
|
||
| @functools.wraps(original_call) | ||
| def __call__(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor): | ||
| # Temporarily narrow the supported parameter kinds to those of the original | ||
| # class. If the decorator logic is correct, the original factory should never | ||
| # see the extended scope, but this acts as a sanity check to prevent regressions | ||
| broadened_kinds = target_cls._supported_parameter_kinds # type: ignore[attr-defined] | ||
| target_cls._supported_parameter_kinds = original_supported_kinds # type: ignore[attr-defined] | ||
|
|
||
| # Split off the task parameters | ||
| original_selector = self.parameter_selector | ||
| if original_selector is None: | ||
| self.parameter_selector = _task_exclude_selector | ||
| else: | ||
| self.parameter_selector = lambda p: ( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm, maybe I'm looking at this from the wrong perspective, but isn't that what we want to happen? Example: My search space contains 3 parameters (2 regular and 1 task):
Or am I overlooking something? If you see a problem, a minimum example would be helpful |
||
| _task_exclude_selector(p) and original_selector(p) | ||
| ) | ||
|
|
||
| try: | ||
|
AdrianSosic marked this conversation as resolved.
|
||
| base_kernel = original_call(self, searchspace, train_x, train_y) | ||
| finally: | ||
| target_cls._supported_parameter_kinds = broadened_kinds # type: ignore[attr-defined] | ||
| self.parameter_selector = original_selector | ||
|
|
||
| if searchspace.task_idx is not None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can't oversee 100% where we would put the logic for TL modes that affect other components than the kernel later. But, the current version will definetely work well with the dispatching between
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In my current naive vision, the decorator would then not just be applied to the kernel factor but to all components that require changes. It would process all of them and apply the necessary changes, e.g.
However, this is exactly what I meant with Any action items at this point I need to take care of, or can I close? |
||
| icm = ICMKernelFactory(base_kernel_or_factory=base_kernel) | ||
| return icm(searchspace, train_x, train_y) | ||
| return base_kernel | ||
|
|
||
| target_cls.__call__ = __call__ # type: ignore[method-assign] | ||
| target_cls._supported_parameter_kinds = ( # type: ignore[attr-defined] | ||
| cls._supported_parameter_kinds | _ParameterKind.TASK | ||
| ) | ||
| return target_cls | ||
|
|
||
|
|
||
| @define | ||
| class _MetaKernelFactory(KernelFactoryProtocol, ABC): | ||
| """Base class for meta kernel factories that orchestrate other kernel factories.""" | ||
|
|
@@ -150,18 +248,25 @@ class ICMKernelFactory(_MetaKernelFactory): | |
| @base_kernel_factory.default | ||
| def _default_base_kernel_factory(self) -> KernelFactoryProtocol: | ||
| from baybe.surrogates.gaussian_process.presets.baybe import ( | ||
| BayBENumericalKernelFactory, | ||
| _BayBENumericalKernelFactory, | ||
| ) | ||
|
|
||
| return BayBENumericalKernelFactory(TypeSelector((TaskParameter,), exclude=True)) | ||
| assert ( | ||
| _BayBENumericalKernelFactory._supported_parameter_kinds | ||
| is _ParameterKind.REGULAR | ||
| ) | ||
| return _BayBENumericalKernelFactory( | ||
| TypeSelector((TaskParameter,), exclude=True) | ||
| ) | ||
|
|
||
| @task_kernel_factory.default | ||
| def _default_task_kernel_factory(self) -> KernelFactoryProtocol: | ||
| from baybe.surrogates.gaussian_process.presets.baybe import ( | ||
| BayBETaskKernelFactory, | ||
| _BayBETaskKernelFactory, | ||
| ) | ||
|
|
||
| return BayBETaskKernelFactory(TypeSelector((TaskParameter,))) | ||
| assert _BayBETaskKernelFactory._supported_parameter_kinds is _ParameterKind.TASK | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are these new asserts in the defaults for mypy or for another purpose?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, not for mypy, but to guarantee that the defaults that are injected here actually are compatible with the contract that the ICM mechanism requires. However, this is probably from the earlier drafting days and should actually become a proper validation step. Would it be fine for you if I turned this into actual validators?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was commenting due to this simple principle: asserts for anything other than pytest or mypy type-narrowing should be avoided, in the latter case ideally they'd get a comment so we dont have to always redo this kind of thread here |
||
| return _BayBETaskKernelFactory() | ||
|
|
||
| @override | ||
| def __call__( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,39 +26,24 @@ | |
| from baybe.surrogates.gaussian_process.presets.edbo_smoothed import ( | ||
| SmoothedEDBOKernelFactory, | ||
| SmoothedEDBOLikelihoodFactory, | ||
| _SmoothedEDBONumericalKernelFactory, | ||
| ) | ||
|
|
||
| if TYPE_CHECKING: | ||
| from torch import Tensor | ||
|
|
||
|
|
||
| @define | ||
| class BayBEKernelFactory(_PureKernelFactory): | ||
| """The default kernel factory for Gaussian process surrogates.""" | ||
|
|
||
| _supported_parameter_kinds: ClassVar[_ParameterKind] = ( | ||
| _ParameterKind.REGULAR | _ParameterKind.TASK | ||
| ) | ||
| # See base class. | ||
|
|
||
| @override | ||
| def _make( | ||
| self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor | ||
| ) -> Kernel: | ||
| from baybe.surrogates.gaussian_process.components.kernel import ICMKernelFactory | ||
| class _BayBENumericalKernelFactory(_SmoothedEDBONumericalKernelFactory): | ||
| """The default numerical kernel factory for GP surrogates.""" | ||
|
|
||
| is_multitask = searchspace.task_idx is not None | ||
| factory = ICMKernelFactory if is_multitask else BayBENumericalKernelFactory | ||
| return factory()(searchspace, train_x, train_y) | ||
|
|
||
|
|
||
| BayBENumericalKernelFactory = SmoothedEDBOKernelFactory | ||
| """The factory providing the default numerical kernel for Gaussian process surrogates.""" # noqa: E501 | ||
| class BayBEKernelFactory(SmoothedEDBOKernelFactory): # type: ignore[valid-type, misc] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I roughly get the need for this (horrendously feeling) But could you sum tis up in 1-2 sentences in the PR description? It seems the commits also have explanation but it seemingly changed during the course of them and they don't really get the point across very nicely |
||
| """The default kernel factory for GP surrogates.""" | ||
|
|
||
|
|
||
| @define | ||
| class BayBETaskKernelFactory(_PureKernelFactory): | ||
| """The factory providing the default task kernel for Gaussian process surrogates.""" | ||
| class _BayBETaskKernelFactory(_PureKernelFactory): | ||
| """The default task kernel factory for GP surrogates.""" | ||
|
|
||
| _uses_parameter_names: ClassVar[bool] = True | ||
| # See base class. | ||
|
|
@@ -83,11 +68,13 @@ def _make( | |
| ) | ||
|
|
||
|
|
||
| BayBEMeanFactory = LazyConstantMeanFactory | ||
| """The factory providing the default mean function for Gaussian process surrogates.""" | ||
| class BayBEMeanFactory(LazyConstantMeanFactory): | ||
| """The default mean factory for GP surrogates.""" | ||
|
|
||
|
|
||
| class BayBELikelihoodFactory(SmoothedEDBOLikelihoodFactory): | ||
| """The default likelihood factory for GP surrogates.""" | ||
|
|
||
| BayBELikelihoodFactory = SmoothedEDBOLikelihoodFactory | ||
| """The factory providing the default likelihood for Gaussian process surrogates.""" | ||
|
|
||
|
|
||
| @define | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Scienfitz: in principle ready and working. However, I have to admit that this was significantly more painful than anticipated, with many footguns along the way. So I'm open to a very harsh review and a complete change of direction, if you prefer and have an alternative/simpler idea.
But I hope that you get my intent for this: I think we need some mechanism that lets us say
fill this preset with our default approach for a certain aspect that the preset does not specify, and the filling should be very much done without copying code since the BayBE defaults are expected to move. So we need something like a single source of truth. That said: maybe you have some smarter idea.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
only annoying ting is the
_NAME/NAMEthing but I dont know a better alternative