
Add Convenience Kernel Arithmetic #763

Open
Scienfitz wants to merge 20 commits into dev/gp from feature/convenience_kernel_arithmetic

Conversation

Collaborator

@Scienfitz Scienfitz commented Mar 12, 2026

Closes #760
DevPR, parent is #745

  • Rename AdditiveKernel to SumKernel + deprecation + deserialization hook for compatibility
  • Adds mul, rmul, add, radd operators for BayBE kernels
  • Ensures flattening of add and mul operations, e.g. SumKernel([SumKernel([a, b]), c]) results in SumKernel([a, b, c])
  • Multiplication with a constant will result in a ScaleKernel with the new attribute outputscale_trainable=False, which means it is not affected by training, i.e., it stays constant

@Scienfitz Scienfitz self-assigned this Mar 12, 2026
@Scienfitz Scienfitz added the enhancement Expand / change existing functionality label Mar 12, 2026
Resolve the naming inconsistency between AdditiveKernel and ProductKernel by
renaming AdditiveKernel to SumKernel, forming a consistent SumKernel/ProductKernel
pair. A deprecation shim is provided so that the old AdditiveKernel name still works
but emits a DeprecationWarning.
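A shim of this kind can be sketched in plain Python (hypothetical simplified classes, not the actual BayBE implementation):

```python
import warnings


class SumKernel:
    """Stand-in for the renamed composite kernel."""

    def __init__(self, kernels):
        self.kernels = list(kernels)


def AdditiveKernel(*args, **kwargs):
    """Deprecated alias: emit a warning, then forward to SumKernel."""
    warnings.warn(
        "'AdditiveKernel' has been renamed to 'SumKernel'.",
        DeprecationWarning,
        stacklevel=2,
    )
    return SumKernel(*args, **kwargs)
```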
Add a boolean flag to control whether the output scale parameter is trainable.
When set to False, the raw_outputscale parameter is frozen (requires_grad=False)
after GPyTorch kernel construction. This enables creating ScaleKernels with fixed
output scales, which is needed for the upcoming constant * kernel operator.
Add __add__, __mul__, __radd__, and __rmul__ to the Kernel base class:
- kernel + kernel creates a SumKernel
- kernel * kernel creates a ProductKernel
- kernel * constant / constant * kernel creates a ScaleKernel with a fixed
  (non-trainable) output scale
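A minimal sketch of these operators (simplified stand-ins; the real BayBE kernels are attrs-based classes with validation, and the attribute names here mirror the ones described above):

```python
from numbers import Number


class Kernel:
    def __add__(self, other):
        if isinstance(other, Kernel):
            return SumKernel([self, other])
        return NotImplemented

    def __mul__(self, other):
        if isinstance(other, Kernel):
            return ProductKernel([self, other])
        if isinstance(other, Number):
            # Constant scaling: wrap in a ScaleKernel whose
            # output scale is fixed (not trainable).
            return ScaleKernel(self, outputscale=other, outputscale_trainable=False)
        return NotImplemented

    # constant * kernel behaves like kernel * constant
    __rmul__ = __mul__


class SumKernel(Kernel):
    def __init__(self, kernels):
        self.kernels = list(kernels)


class ProductKernel(Kernel):
    def __init__(self, kernels):
        self.kernels = list(kernels)


class ScaleKernel(Kernel):
    def __init__(self, base, outputscale, outputscale_trainable):
        self.base = base
        self.outputscale = outputscale
        self.outputscale_trainable = outputscale_trainable
```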
Nested sums and products are now flattened: (a + b) + c produces
SumKernel([a, b, c]) instead of SumKernel([SumKernel([a, b]), c]),
and likewise for ProductKernel. This can be reverted independently
while preserving the core operator functionality.
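The flattening step can be illustrated with a small constructor (hypothetical simplified class, assuming a `kernels` attribute as in the sketch above):

```python
class Kernel:
    pass


class SumKernel(Kernel):
    def __init__(self, kernels):
        # Flatten nested sums: SumKernel([SumKernel([a, b]), c])
        # becomes SumKernel([a, b, c]); ProductKernel works analogously.
        flat = []
        for k in kernels:
            if isinstance(k, SumKernel):
                flat.extend(k.kernels)
            else:
                flat.append(k)
        self.kernels = flat
```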
Test suite covering:
- Addition operator producing SumKernel with flattening
- Multiplication operator producing ProductKernel with flattening
- Constant multiplication producing ScaleKernel with frozen outputscale
- TypeError for unsupported operand types
- AdditiveKernel deprecation warning
- Serialization roundtrips for operator-constructed kernels
Replace explicit SumKernel([...]) and ProductKernel([...]) construction in
tests with the new + and * operators for more readable kernel composition.
@Scienfitz Scienfitz force-pushed the feature/convenience_kernel_arithmetic branch from f75ad4d to 7bcc328 Compare March 12, 2026 16:44
@Scienfitz Scienfitz marked this pull request as ready for review March 12, 2026 16:47
Copilot AI review requested due to automatic review settings March 12, 2026 16:47
Contributor

Copilot AI left a comment


Pull request overview

This PR improves BayBE kernel ergonomics by adding operator overloading for kernel composition and aligning composite-kernel naming (AdditiveKernel → SumKernel) with a backwards-compatible deprecation path (including deserialization support).

Changes:

  • Renames AdditiveKernel to SumKernel, adds deprecation wrapper and a cattrs structure hook to deserialize legacy "type": "AdditiveKernel" configs.
  • Adds Kernel arithmetic via + and *, including flattening of nested SumKernel/ProductKernel compositions.
  • Adds constant scaling via constant * kernel / kernel * constant by producing a ScaleKernel with outputscale_trainable=False.

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
baybe/kernels/base.py Adds __add__/__mul__ (flattening) and a deserialization hook for legacy AdditiveKernel configs; updates gpytorch attribute matching to tolerate *_trainable.
baybe/kernels/composite.py Introduces ScaleKernel.outputscale_trainable and renames AdditiveKernel to SumKernel.
baybe/kernels/deprecation.py Adds deprecated AdditiveKernel factory that warns and returns SumKernel.
baybe/kernels/__init__.py Re-exports SumKernel and routes AdditiveKernel through the deprecation wrapper.
tests/test_kernels.py Adds unit tests for operator behavior, flattening, and constant scaling; updates gpytorch-component validation to ignore *_trainable.
tests/test_iterations.py Updates composite-kernel test inputs to use +/* composition instead of explicit composite kernel constructors.
tests/test_gp.py Updates GP kernel fixture to use MaternKernel() + RBFKernel() composition.
tests/test_deprecations.py Adds coverage for AdditiveKernel deprecation warning and legacy-name deserialization.
tests/hypothesis_strategies/kernels.py Switches additive composite strategy to produce SumKernel.
tests/validation/kernels/test_composite_kernel_validation.py Updates composite-kernel validation parametrization from AdditiveKernel to SumKernel.
CHANGELOG.md Documents kernel operator support and the AdditiveKernel → SumKernel replacement.


@Scienfitz Scienfitz linked an issue Mar 18, 2026 that may be closed by this pull request
Collaborator

@AdrianSosic AdrianSosic left a comment


Very good PR, thanks for taking care 🥇 Only a few minor comments

Comment on lines +49 to +50
if isinstance(other, Kernel):
return self.__add__(other)
Collaborator


Could it be that these two lines are essentially dead/unneeded code?

Collaborator Author


you know it could be, it could not be, but to determine that I would need your reasoning.

I have an idea now that you mentioned it here. But why guess if the original thinker could just post their reasoning?

Collaborator


Didn't want to post in order not to bias your thoughts but my reasoning is:

asking for isinstance(other, Kernel) does not make any sense because if other actually were a Kernel, then the __add__ method would have already been called, i.e. __radd__ is only ever called if __add__ yields NotImplemented. So I see no execution path that could ever lead here.

But this is a good example where people get confused quickly, so please let me know if this is wrong.
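This dispatch order is easy to verify: as long as the left operand's `__add__` handles `Kernel` instances, the right operand's `__radd__` is never consulted (minimal demo with hypothetical subclasses, not the actual BayBE code):

```python
calls = []


class Kernel:
    def __add__(self, other):
        calls.append(f"{type(self).__name__}.__add__")
        if isinstance(other, Kernel):
            return (self, other)
        return NotImplemented

    def __radd__(self, other):
        calls.append(f"{type(self).__name__}.__radd__")
        if isinstance(other, Kernel):  # the branch suspected to be dead
            return (other, self)
        return NotImplemented


class Matern(Kernel):
    pass


class RBF(Kernel):
    pass


# RBF is not a subclass of Matern, so Matern.__add__ is tried first;
# it succeeds, and RBF.__radd__ is never reached.
left, right = Matern(), RBF()
result = left + right
```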

Collaborator Author


yes that was also what I was thinking about after your comment, so strictly these lines seem to be dead. I haven't thought about subclasses, perhaps there is something in there to keep it?

on the other hand I would find it rather weird to have this function not implement the logic for kernels :/

Collaborator


well, it's not weird if you think of it in terms of NotImplemented means I have no specialized handling for this case, so leave it for other functions to take over ... which is indeed what it means

Collaborator Author


https://docs.python.org/3/reference/datamodel.html#object.__radd__

These functions are only called if the operands are of different types, when the left operand does not support the corresponding operation [3], or the right operand’s class is derived from the left operand’s class.
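The third clause, the subclass rule, is the one case where the reflected method really does run first: if the right operand's class derives from the left's and overrides the reflected method, Python tries it before the left operand's regular method (demo with hypothetical classes):

```python
class Base:
    def __add__(self, other):
        return "Base.__add__"


class Sub(Base):
    def __radd__(self, other):
        return "Sub.__radd__"


# Sub derives from Base and overrides __radd__, so Python tries
# Sub.__radd__ before Base.__add__ for this mixed expression.
mixed = Base() + Sub()

# With identical types, the left operand's __add__ wins as usual.
plain = Base() + Base()
```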

Collaborator


Ok, but does this actually change anything? So if __radd__ is called first in this case and returns NotImplemented, then the regular __add__ is called next, which will take over the job. Thus, again, it would be sufficient to just say even though I'm asked first, I have no special handling, so let the regular process happen

Collaborator Author


fine for me to add that special treatment, but doesn't that make the special treatment of 1 in rmul obsolete again?

Collaborator


Yes, it does, thus I removed the condition from __rmul__ now. But: could it be that you replied to the wrong thread? The special case handling was mentioned here. That said, the content of this thread still requires your approval

Collaborator Author


yes sorry, wrong thread

Re does this change anything: What do you want me to approve? This thread is a request by you.

Collaborator

@AVHopp AVHopp left a comment


Only real issue is that one test is missing, but I am sure you can add this :)

validate_gpytorch_kernel_components(kernel, k, **kwargs)


def test_add_produces_sum_kernel():
Collaborator


If I understand it correctly, one of the functions is added such that we can use sum with kernels, correct? If so, then a dedicated test for this is missing, as we only use + in all tests.

Collaborator Author


Good suggestion, I've added a test now.

I have also added the rmul functionality, which is required for prod([kernel1, kernel2]): it starts out as 1*kernel1, just like 0+kernel1 in sum([kernel1, kernel2]). This is also tested.

This surfaces an inconsistency however: floats multiplied with kernels will wrap a scale kernel with frozen output scale, but this is skipped for the prod case (it just doesn't make sense to me and would not be identical to the manual kernel multiplication). In general, we could argue that 0 and 1 should have special meanings in mul and add, which they currently don't (it would remove the need to treat 0 or 1 specially in rmul and radd)

@AdrianSosic should also sign off on this

Collaborator


Could we get around this by enabling the use of sum and prod only for sequences of Kernel objects? I guess that this is the main use case anyway, I do not think people want to have a sum([1, Kernel1, 23, 45, Kernel2]) or similar.

Collaborator Author


not sure what you mean, but when you write sum([k1, k2]) it will internally call 0 + k1 first

and prod([k1,k2]) will start out as 1*k1

so there's no way to avoid it
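This can be checked directly: Python's `sum` seeds the fold with `0` and `math.prod` with `1`, so the reflected operators must accept those neutral elements (sketch with hypothetical names, not the actual BayBE code):

```python
import math


class Kernel:
    def __init__(self, name):
        self.name = name

    def __add__(self, other):
        if isinstance(other, Kernel):
            return Kernel(f"Sum({self.name},{other.name})")
        return NotImplemented

    def __mul__(self, other):
        if isinstance(other, Kernel):
            return Kernel(f"Product({self.name},{other.name})")
        return NotImplemented

    def __radd__(self, other):
        # sum([k1, k2]) internally evaluates 0 + k1 first,
        # so treat the neutral element 0 as identity.
        if other == 0:
            return self
        return NotImplemented

    def __rmul__(self, other):
        # math.prod([k1, k2]) starts out as 1 * k1.
        if other == 1:
            return self
        return NotImplemented
```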

Collaborator


Hey @Scienfitz, pls correct me if I'm wrong, but I think this has now become obsolete since adding special treatment is exactly what happened lately, right?

Collaborator Author


? the test is probably not obsolete

Scienfitz and others added 2 commits March 30, 2026 19:18
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: AdrianSosic <adrian.sosic@merckgroup.com>
Collaborator Author

Scienfitz commented Mar 30, 2026

@AdrianSosic just in case you've not seen this: this branch was based off of the gp refactor and merges back into it
since this is basically your area perhaps you should merge the PR when you see fit?

@Scienfitz Scienfitz force-pushed the feature/convenience_kernel_arithmetic branch from a8ab4b7 to 92d53c6 Compare March 31, 2026 14:52
Case is handled when falling back to __mul__
@AdrianSosic AdrianSosic mentioned this pull request Apr 2, 2026
20 tasks

Labels

enhancement Expand / change existing functionality

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Kernel operators and naming

4 participants