Transfer Learning Decorator#790
Conversation
There was a problem hiding this comment.
Pull request overview
This PR introduces a class-decorator-based mechanism to enable BayBE’s default transfer learning behavior for Gaussian Process kernel factories (automatically composing a base kernel with the default task kernel when a task parameter is present), and refactors several kernel presets to use it.
Changes:
- Added
_enable_transfer_learningdecorator in the GP kernel factory component layer. - Updated preset kernel factories (EDBO, CHEN, Smoothed EDBO, BayBE) to rely on the decorator rather than per-preset task-exclusion selectors / manual orchestration.
- Adjusted kernel-factory tests and internal factory naming to reflect the new split between numerical-only and task-aware factories.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
baybe/surrogates/gaussian_process/components/kernel.py |
Adds _enable_transfer_learning and updates ICM defaults to use internal numerical/task factories. |
baybe/surrogates/gaussian_process/presets/edbo.py |
Applies transfer-learning decorator and removes the preset’s default task-excluding selector. |
baybe/surrogates/gaussian_process/presets/edbo_smoothed.py |
Splits numerical kernel factory from transfer-learning-enabled alias. |
baybe/surrogates/gaussian_process/presets/chen.py |
Applies transfer-learning decorator and removes the preset’s default task-excluding selector. |
baybe/surrogates/gaussian_process/presets/baybe.py |
Refactors default kernel factory wiring to reuse Smoothed EDBO and internal task factory. |
tests/test_kernel_factories.py |
Updates imports/usages to the new internal factory names. |
Comments suppressed due to low confidence (1)
baybe/surrogates/gaussian_process/presets/edbo_smoothed.py:49
_SmoothedEDBONumericalKernelFactory._makeuseseffective_dims = train_x.shape[-1]and passesparameter_names=self.get_parameter_names(searchspace)(which isNoneunless a selector is provided). Under transfer learning/multitask search spaces, this will count the task index dimension ineffective_dimsand make the base kernel operate on the task column as well, altering the smoothed prior interpolation. The numerical kernel used as the ICM base kernel should excludeTaskParameterby default, and dimension logic should be computed from the selected active dims/parameter names (not the rawtrain_xwidth).
@define
class _SmoothedEDBONumericalKernelFactory(_PureKernelFactory):
"""A factory providing the core numerical kernel for the smoothed EDBO preset."""
_uses_parameter_names: ClassVar[bool] = True
# See base class.
@override
def _make(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> Kernel:
effective_dims = train_x.shape[-1]
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
6b3abe7 to
a713bfe
Compare
|
@copilot: please review |
I've completed a comprehensive code review of the transfer learning decorator implementation. Here are my findings: ✅ Strengths
🔍 Observations
✨ Code QualityThe implementation is well-structured, properly documented, and follows BayBE's patterns. The decorator successfully abstracts transfer learning while maintaining backward compatibility and correct serialization behavior. No issues found. The implementation is production-ready. |
| """Construct the kernel.""" | ||
|
|
||
|
|
||
| def _enable_transfer_learning( |
There was a problem hiding this comment.
@Scienfitz: in principle ready and working. However, I have to admit that this was significantly more painful than anticipated, with many footguns along the way. So I'm open to a very harsh review and a complete change of direction, if you prefer and have an alternative/simpler idea.
But I hope that you get my intent for this: I think we need some mechanism that lets us say fill this preset with our default approach for a certain aspect that the preset does not specify, and the filling should be very much done without copying code since the BayBE defaults are expected to move. So we need something like a single source of truth. That said: maybe you have some smarter idea.
There was a problem hiding this comment.
only annoying ting is the _NAME / NAME thing but I dont know a better alternative
|
@AdrianSosic after our chat you were not sure whether this PR is still relevant, now seeing it still there I assume its still relevant and should be reviewed? |
| if original_selector is None: | ||
| self.parameter_selector = _task_exclude_selector | ||
| else: | ||
| self.parameter_selector = lambda p: ( |
There was a problem hiding this comment.
The TaskParameter will be excluded here if the orgininal parameter_selector contained it. Should we maybe raise a warning?
There was a problem hiding this comment.
Hm, maybe I'm looking at this from the wrong perspective, but isn't that what we want to happen?
Example: My search space contains 3 parameters (2 regular and 1 task): (p1, p2, p_task). If I now call a factory that is TL-compatible (which the factories that get decorated de facto are), then two things can happen:
- I pass a selector that excludes the task parameter, e.g. filters down to just
p1. In this case, I want that the inner logic gets executed, i.e. the execution internally dispatches to the original non-decorated factory. - I pass a selector that includes the task parameter, e.g. selects
p1andtask. In this case,taskshould be split off, the regular factory should be called onp1, and then we assemble everything via the ICM mechanism.
Or am I overlooking something? If you see a problem, a minimum example would be helpful
| target_cls._supported_parameter_kinds = broadened_kinds # type: ignore[attr-defined] | ||
| self.parameter_selector = original_selector | ||
|
|
||
| if searchspace.task_idx is not None: |
There was a problem hiding this comment.
I can't oversee 100% where we would put the logic for TL modes that affect other components than the kernel later. But, the current version will definetely work well with the dispatching between IndexKernel and PositiveIndexKernel, which will happen within the ICMKernelFactory.
There was a problem hiding this comment.
In my current naive vision, the decorator would then not just be applied to the kernel factor but to all components that require changes. It would process all of them and apply the necessary changes, e.g.
- for kernel-TL, it would do exactly what we do no, i.e. only modify the kernel factory
- for mean-injection-TL, it would not affect the kernel but alter the mean-factory to do the prior-mean stuff
- ...
However, this is exactly what I meant with I can't yet foresee exactly, so let's keep the decorator private for now :D But yeah, potentially the decorator logic might become quite a beast, since it not only needs to capture TL stuff across all components, but maybe also multi-fidelity stuff etc in the end. My current fear is that the complexity might get out of hands...
Any action items at this point I need to take care of, or can I close?
|
|
||
| BayBENumericalKernelFactory = SmoothedEDBOKernelFactory | ||
| """The factory providing the default numerical kernel for Gaussian process surrogates.""" # noqa: E501 | ||
| class BayBEKernelFactory(SmoothedEDBOKernelFactory): # type: ignore[valid-type, misc] |
There was a problem hiding this comment.
I think I roughly get the need for this (horrendously feeling) _NAME / NAME split for some of the objects (to which I couldnt come up witha better idea either)
But could you sum tis up in 1-2 sentences in the PR description? It seems the commits also have explanation but it seemingly changed during the course of them and they don't really get the point across very nicely
* Provides a single source of truth for defining the TL logic * Enables TL for non-TL presets by applying the decorator
`_enable_transfer_learning` now accepts an optional `name` parameter so that the dynamically created class can have the correct `__name__` when the function is called directly (rather than used as a decorator). This fixes serialization for `SmoothedEDBOKernelFactory`, which was previously serialized as `_SmoothedEDBONumericalKernelFactory`.
Simple aliases like `BayBEKernelFactory = SmoothedEDBOKernelFactory` cause the serialized type name to be that of the underlying class, which means the identity is lost on deserialization. Using thin subclasses ensures each factory has its own stable `__name__`.
When used as a decorator (@_enable_transfer_learning), modify the class in-place instead of creating a subclass with the same __name__. The previous approach left two concrete classes with identical names in the subclass registry, causing find_subclass to resolve to the @Define- processed intermediate (without the TL wrapper) during deserialization. When called with an explicit name argument (for cases like SmoothedEDBOKernelFactory where the original class is reused elsewhere), the subclass approach is preserved since the distinct name avoids any collision.
The Protocol metaclass (_ProtocolMeta) defaults __module__ to 'abc' when creating classes via 3-arg type(). Set it explicitly from the parent class so that SmoothedEDBOKernelFactory correctly reports its module as baybe.surrogates.gaussian_process.presets.edbo_smoothed.
5229f7f to
41a1fc7
Compare
| # __module__ must be set explicitly because the Protocol metaclass | ||
| # would otherwise default it to "abc". | ||
| target_cls = type( | ||
| name, (cls,), {"__doc__": cls.__doc__, "__module__": cls.__module__} |
There was a problem hiding this comment.
isnt it strange that it is made a subclass? in a sense it should be an equivalent copy, not a subclass
There was a problem hiding this comment.
Fully agree. It's just that the subclassing mechanism is a bit less works since it automatically populates the attributes and stuff. But if you prefer the conceptually cleaner approach: 1ccad46
Shall I keep it?
There was a problem hiding this comment.
cool, if it works I'd prefer it 👍
| ) | ||
|
|
||
| return BayBETaskKernelFactory(TypeSelector((TaskParameter,))) | ||
| assert _BayBETaskKernelFactory._supported_parameter_kinds is _ParameterKind.TASK |
There was a problem hiding this comment.
are these new asserts in the defaults for mypy or for another purpose?
There was a problem hiding this comment.
No, not for mypy, but to guarantee that the defaults that are injected here actually are compatible with the contract that the ICM mechanism requires. However, this is probably from the earlier drafting days and should actually become a proper validation step. Would it be fine for you if I turned this into actual validators?
There was a problem hiding this comment.
I was commenting due to this simple principle: asserts for anything other than pytest or mypy type-narrowing should be avoided, in the latter case ideally they'd get a comment so we dont have to always redo this kind of thread here
…orProtocol Extends the method to accept either a parameter name (existing behavior) or a selector, returning the combined comp-rep indices of all matching parameters. Uses this to simplify _get_effective_dimensionality in _PureKernelFactory and the inline dimensionality sums in the EDBO likelihood factories.
Parameter name uniqueness is already enforced as a searchspace invariant, making the multi-match case impossible. The no-match case now returns an empty tuple, consistent with the selector path and the existing behavior for parameters absent from the comp-rep.
Co-authored-by: Martin Fitzner <martin.fitzner@merckgroup.com>
DevPR, parent is #745
Last piece to the puzzle:
Presets (i.e. papers, packages, etc) can dictate certain aspects of the GP model while not saying anything about other aspects. For example, both EDBO and CHEN focus on the kernel priors but don't even consider transfer learning at all. This is a general issue, and can also cover other things like multi-fidelity etc.
For these cases, we want to follow the approach
if not defined, use BayBE default mechanism/setting. However, this requires to abstract these settings/mechanism into reusable structures. This PR takes care of this step for transfer learning (which is currently the only mechanism that needs to be ported) in the form of a class decorator. Because other mechanisms will follow in the future and their extent isn't yet fully clear (e.g. multi-fidelity or transfer learning via mean injection), we keep this decorator private for now. A possible future extension of the decorator could have the form@enable_mechanism(transfer_learning=True, multi_fidelity=True)that then accepts any existing GP component and makes the necessary adjustments.