Conversation
Resolve the naming inconsistency between AdditiveKernel and ProductKernel by renaming AdditiveKernel to SumKernel, forming a consistent SumKernel/ProductKernel pair. A deprecation shim is provided so that the old AdditiveKernel name still works but emits a DeprecationWarning.
Add a boolean flag to control whether the output scale parameter is trainable. When set to False, the raw_outputscale parameter is frozen (requires_grad=False) after GPyTorch kernel construction. This enables creating ScaleKernels with fixed output scales, which is needed for the upcoming constant * kernel operator.
Add __add__, __mul__, __radd__, and __rmul__ to the Kernel base class:
- kernel + kernel creates a SumKernel
- kernel * kernel creates a ProductKernel
- kernel * constant / constant * kernel creates a ScaleKernel with a fixed (non-trainable) output scale
Nested sums and products are now flattened: (a + b) + c produces SumKernel([a, b, c]) instead of SumKernel([SumKernel([a, b]), c]), and likewise for ProductKernel. This can be reverted independently while preserving the core operator functionality.
Test suite covering:
- Addition operator producing SumKernel with flattening
- Multiplication operator producing ProductKernel with flattening
- Constant multiplication producing ScaleKernel with frozen outputscale
- TypeError for unsupported operand types
- AdditiveKernel deprecation warning
- Serialization roundtrips for operator-constructed kernels
Replace explicit SumKernel([...]) and ProductKernel([...]) construction in tests with the new + and * operators for more readable kernel composition.
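The operator semantics described above can be sketched in a small, self-contained example. Note this is illustrative only: class names follow the PR description, but the bodies are simplified stand-ins, not the actual BayBE implementation (in BayBE, ScaleKernel wraps a GPyTorch kernel and freezes its raw_outputscale parameter).

```python
# Illustrative sketch of kernel operator overloading with flattening.
# Names mirror the PR description; internals are simplified placeholders.
from numbers import Number


class Kernel:
    def __add__(self, other):
        if isinstance(other, Kernel):
            # Flatten nested sums: (a + b) + c -> SumKernel([a, b, c])
            parts = []
            for k in (self, other):
                parts.extend(k.kernels if isinstance(k, SumKernel) else [k])
            return SumKernel(parts)
        return NotImplemented

    def __radd__(self, other):
        # Needed so that sum([k1, k2]) works: it starts out as 0 + k1.
        if other == 0:
            return self
        return NotImplemented

    def __mul__(self, other):
        if isinstance(other, Kernel):
            # Flatten nested products analogously to sums.
            parts = []
            for k in (self, other):
                parts.extend(k.kernels if isinstance(k, ProductKernel) else [k])
            return ProductKernel(parts)
        if isinstance(other, Number):
            # constant * kernel -> ScaleKernel with a frozen output scale
            return ScaleKernel(self, outputscale=other, outputscale_trainable=False)
        return NotImplemented

    # Reflected multiplication reuses the same logic (constant * kernel).
    __rmul__ = __mul__


class SumKernel(Kernel):
    def __init__(self, kernels):
        self.kernels = list(kernels)


class ProductKernel(Kernel):
    def __init__(self, kernels):
        self.kernels = list(kernels)


class ScaleKernel(Kernel):
    def __init__(self, base, outputscale=1.0, outputscale_trainable=True):
        self.base = base
        self.outputscale = outputscale
        self.outputscale_trainable = outputscale_trainable
```

With this sketch, (a + b) + c yields a flat SumKernel([a, b, c]), and 2.0 * a yields a ScaleKernel with outputscale_trainable=False, matching the behavior the PR describes.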
Force-pushed from f75ad4d to 7bcc328
Pull request overview
This PR improves BayBE kernel ergonomics by adding operator overloading for kernel composition and aligning composite-kernel naming (AdditiveKernel → SumKernel) with a backwards-compatible deprecation path (including deserialization support).
Changes:
- Renames AdditiveKernel to SumKernel, adds a deprecation wrapper and a cattrs structure hook to deserialize legacy "type": "AdditiveKernel" configs.
- Adds Kernel arithmetic via + and *, including flattening of nested SumKernel/ProductKernel compositions.
- Adds constant scaling via constant * kernel / kernel * constant by producing a ScaleKernel with outputscale_trainable=False.
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| baybe/kernels/base.py | Adds __add__/__mul__ (flattening) and a deserialization hook for legacy AdditiveKernel configs; updates gpytorch attribute matching to tolerate *_trainable. |
| baybe/kernels/composite.py | Introduces ScaleKernel.outputscale_trainable and renames AdditiveKernel to SumKernel. |
| baybe/kernels/deprecation.py | Adds deprecated AdditiveKernel factory that warns and returns SumKernel. |
| baybe/kernels/__init__.py | Re-exports SumKernel and routes AdditiveKernel through the deprecation wrapper. |
| tests/test_kernels.py | Adds unit tests for operator behavior, flattening, and constant scaling; updates gpytorch-component validation to ignore *_trainable. |
| tests/test_iterations.py | Updates composite-kernel test inputs to use +/* composition instead of explicit composite kernel constructors. |
| tests/test_gp.py | Updates GP kernel fixture to use MaternKernel() + RBFKernel() composition. |
| tests/test_deprecations.py | Adds coverage for the AdditiveKernel deprecation warning and legacy-name deserialization. |
| tests/hypothesis_strategies/kernels.py | Switches the additive composite strategy to produce SumKernel. |
| tests/validation/kernels/test_composite_kernel_validation.py | Updates composite-kernel validation parametrization from AdditiveKernel to SumKernel. |
| CHANGELOG.md | Documents kernel operator support and the AdditiveKernel → SumKernel replacement. |
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
AdrianSosic
left a comment
Very good PR, thanks for taking care 🥇 Only a few minor comments
```python
if isinstance(other, Kernel):
    return self.__add__(other)
```
Could it be that these two lines are essentially dead/unneeded code?
You know, it could be, or it could not be; but to determine that, I would need your reasoning.
I have an idea now that you mentioned it here. But why guess if the original thinker could just post their reasoning?
Didn't want to post in order not to bias your thoughts but my reasoning is:
asking for isinstance(other, Kernel) does not make any sense, because if other actually were a Kernel, then the __add__ method would have already been called, i.e. __radd__ is only ever called if __add__ yields NotImplemented. So I see no execution path that could ever lead here.
But this is a good example where people get confused quickly, so please let me know if this is wrong.
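The dispatch behavior at the heart of this argument can be checked with a tiny standalone example (plain Python, no BayBE involved):

```python
# Demonstrates Python's binary-operator dispatch: __radd__ is only
# consulted when the left operand's __add__ cannot handle the operation.
class A:
    def __add__(self, other):
        return ("add", type(other).__name__)

    def __radd__(self, other):
        return ("radd", type(other).__name__)


# When both operands are of type A, __add__ wins; __radd__ is never called.
assert A() + A() == ("add", "A")

# __radd__ only runs after the left operand's __add__ yields NotImplemented
# (here: int has no __add__ handling for A).
assert 0 + A() == ("radd", "int")
```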
yes, that was also what I was thinking about after your comment, so strictly these lines seem to be dead. I haven't thought about subclasses; perhaps there is something in there to keep it?
on the other hand I would find it rather weird to have this function not implement the logic for kernels :/
well, it's not weird if you think of it in terms of: NotImplemented means "I have no specialized handling for this case, so leave it for other functions to take over" ... which is indeed what it means
https://docs.python.org/3/reference/datamodel.html#object.__radd__
These functions are only called if the operands are of different types, when the left operand does not support the corresponding operation [3], or the right operand’s class is derived from the left operand’s class.
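The subclass exception quoted from the docs is the one case where __radd__ can fire even though the left operand has an __add__. A small standalone illustration:

```python
# Demonstrates the subclass exception: when the right operand's class is
# derived from the left operand's class and overrides the reflected method,
# the reflected method is tried before the left operand's __add__.
class Base:
    def __add__(self, other):
        return "Base.__add__"


class Sub(Base):
    def __radd__(self, other):
        return "Sub.__radd__"


# Sub.__radd__ takes priority over Base.__add__:
assert Base() + Sub() == "Sub.__radd__"
```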
Ok, but does this actually change anything? If __radd__ is called first in this case and returns NotImplemented, then the regular __add__ is called next, which will take over the job. Thus, again, it would be sufficient to just say: even though I'm asked first, I have no special handling, so let the regular process happen.
fine for me to add that special treatment, but doesn't that make the special treatment of 1 in __rmul__ obsolete again?
Yes, it does; thus I removed the condition from __rmul__ now. But could it be that you replied to the wrong thread? The special-case handling was mentioned here. That said, the content of this thread still requires your approval.
yes sorry, wrong thread
Re "does this change anything": What do you want me to approve? This thread is a request by you.
AVHopp
left a comment
Only real issue is that one test is missing, but I am sure you can add this :)
tests/test_kernels.py (outdated)

```python
    validate_gpytorch_kernel_components(kernel, k, **kwargs)


def test_add_produces_sum_kernel():
```
If I understand it correctly, one of the functions is added such that we can use sum with kernels, correct? If so, then a dedicated test for this is missing, as we only use + in all tests.
good suggestion, I've added a test now
I have also added the __rmul__ functionality, which is required for prod([kernel1, kernel2]) since it starts out as 1 * kernel1, just like 0 + kernel1 in sum([kernel1, kernel2]). This is also tested.
This surfaces an inconsistency, however: floats multiplied with kernels will wrap a scale kernel with frozen output scale, but this is skipped for the prod case (that just doesn't make sense to me and would not be identical to the manual kernel multiplication). In general, we could argue that 0 and 1 should have special meanings in mul and add, which they currently don't (it would remove the need to treat 0 or 1 specially in __rmul__ and __radd__).
@AdrianSosic should also sign off on this
Could we get around this by enabling the use of sum and prod only for sequences of Kernel objects? I guess that this is the main use case anyway; I do not think people want to have sum([1, Kernel1, 23, 45, Kernel2]) or similar.
not sure what you mean, but when you write sum([k1, k2]) it will internally call 0 + k1 first, and prod([k1, k2]) will start out as 1 * k1, so there's no way to avoid it
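This builtin behavior can be made visible by recording which dunder method fires first (a standalone illustration, not BayBE code):

```python
# Shows that sum() begins with 0 + first_item and math.prod() begins with
# 1 * first_item, so the reflected operators see 0 and 1 respectively.
import math

calls = []


class K:
    def __add__(self, other):
        calls.append(("add", other))
        return self

    def __radd__(self, other):
        calls.append(("radd", other))
        return self

    def __mul__(self, other):
        calls.append(("mul", other))
        return self

    def __rmul__(self, other):
        calls.append(("rmul", other))
        return self


k1, k2 = K(), K()

sum([k1, k2])
assert calls[0] == ("radd", 0)  # sum() starts from its default start value 0

calls.clear()
math.prod([k1, k2])
assert calls[0] == ("rmul", 1)  # prod() starts from its default start value 1
```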
Hey @Scienfitz, pls correct me if I'm wrong, but I think this has now become obsolete since adding special treatment is exactly what happened lately, right?
? the test is probably not obsolete
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: AdrianSosic <adrian.sosic@merckgroup.com>
@AdrianSosic just in case you've not seen this: this branch was based off of the gp refactor and merges back into it
Force-pushed from a8ab4b7 to 92d53c6
Case is handled when falling back to __mul__
Closes #760
DevPR, parent is #745
- Rename AdditiveKernel to SumKernel + deprecation + deserialization hook for compatibility
- mul, rmul, add, radd operators for BayBE kernels
- Flattening: SumKernel([SumKernel([a, b]), c]) results in SumKernel([a, b, c])
- ScaleKernel with the new attribute outputscale_trainable=False, which results in it not being affected by training, i.e. it is constant