-
Notifications
You must be signed in to change notification settings - Fork 70
Projection kernels #689
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Projection kernels #689
Changes from all commits
b686246
a4aa23a
4b9c4c9
32229da
6d9bf6a
6f8a0b1
8bdb46c
4c53379
6f624d2
31a3711
fbeefef
3568807
24c3852
06276d8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| """GPyTorch kernel implementations.""" | ||
|
|
||
| from typing import Any | ||
|
|
||
| import torch | ||
| from gpytorch.kernels import Kernel | ||
| from torch import Tensor | ||
|
|
||
| from baybe.utils.torch import DTypeFloatTorch | ||
|
|
||
| _ConvertibleToTensor = Any | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm confused by this - first, why is |
||
| """A type alias for objects convertible to tensors.""" | ||
|
|
||
|
|
||
| class GPyTorchProjectionKernel(Kernel): | ||
AVHopp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """GPyTorch implementation of :class:`baybe.kernels.composite.ProjectionKernel`.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| base_kernel: Kernel, | ||
| projection_matrix: _ConvertibleToTensor, | ||
| *, | ||
| learn_projection: bool = False, | ||
| ): | ||
| super().__init__() | ||
|
|
||
| self.base_kernel = base_kernel | ||
| self.learn_projection = learn_projection | ||
|
|
||
| matrix = torch.tensor(projection_matrix, dtype=DTypeFloatTorch) | ||
| if self.learn_projection: | ||
| self.register_parameter("projection_matrix", torch.nn.Parameter(matrix)) | ||
| else: | ||
| self.register_buffer("projection_matrix", matrix) | ||
|
|
||
| def forward(self, x1: Tensor, x2: Tensor, **kwargs): | ||
| """Apply the base kernel to the projected input tensors.""" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As discussed, we might want to add a |
||
| x1_proj = x1 @ self.projection_matrix | ||
| x2_proj = x2 @ self.projection_matrix | ||
| return self.base_kernel(x1_proj, x2_proj, **kwargs) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,7 +4,8 @@ | |
| from functools import reduce | ||
| from operator import add, mul | ||
|
|
||
| from attrs import define, field | ||
| import numpy as np | ||
| from attrs import Attribute, cmp_using, define, field | ||
| from attrs.converters import optional as optional_c | ||
| from attrs.validators import deep_iterable, gt, instance_of, min_len | ||
| from attrs.validators import optional as optional_v | ||
|
|
@@ -83,5 +84,45 @@ def to_gpytorch(self, *args, **kwargs): | |
| return reduce(mul, (k.to_gpytorch(*args, **kwargs) for k in self.base_kernels)) | ||
|
|
||
|
|
||
| @define(frozen=True) | ||
| class ProjectionKernel(CompositeKernel): | ||
| """A projection kernel for dimensionality reduction.""" | ||
|
|
||
| base_kernel: Kernel = field(validator=instance_of(Kernel)) | ||
| """The kernel applied to the projected inputs.""" | ||
|
|
||
| projection_matrix: np.ndarray = field( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There seems to be no sort of validation of the exact dimensions of the matrix here other than that the matrix needs to have two dimensions, correct? Hence my question is whether or not this will be auto-derived in general and is thus not intended to be actually set by the user themselves or if this validation happens somewhere else.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As discussed, we might want to make the description here a bit clearer and tell the user that this will not be validated upon creation (as we can't) and/or throw an error at an appropriate point (see other comment in GPyTorch-kernel) |
||
| converter=np.asarray, eq=cmp_using(eq=np.array_equal) | ||
| ) | ||
| """The projection matrix.""" | ||
|
|
||
| learn_projection: bool = field( | ||
| default=True, validator=instance_of(bool), kw_only=True | ||
| ) | ||
| """Boolean specifying if the projection matrix should be learned. | ||
|
|
||
| If ``True``, the provided projection matrix is used as initial value.""" | ||
|
|
||
| @projection_matrix.validator | ||
| def _validate_projection_matrix(self, _: Attribute, value: np.ndarray): | ||
| if value.ndim != 2: | ||
| raise ValueError( | ||
| f"The projection matrix must be 2-dimensional, " | ||
| f"but has shape {value.shape}." | ||
| ) | ||
|
|
||
| @override | ||
| def to_gpytorch(self, **kwargs): | ||
| from baybe.kernels._gpytorch import GPyTorchProjectionKernel | ||
|
|
||
| n_projections = self.projection_matrix.shape[-1] | ||
| gpytorch_kernel = self.base_kernel.to_gpytorch(ard_num_dims=n_projections) | ||
| return GPyTorchProjectionKernel( | ||
| gpytorch_kernel, | ||
| projection_matrix=self.projection_matrix, | ||
| learn_projection=self.learn_projection, | ||
| ) | ||
|
|
||
|
|
||
| # Collect leftover original slotted classes processed by `attrs.define` | ||
| gc.collect() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
|
|
||
| import attrs | ||
| import cattrs | ||
| import numpy as np | ||
| import pandas as pd | ||
| from cattrs.strategies import configure_union_passthrough | ||
|
|
||
|
|
@@ -104,6 +105,15 @@ def _unstructure_dataframe_hook(df: pd.DataFrame) -> str: | |
| return base64.b64encode(pickled_df).decode("utf-8") | ||
|
|
||
|
|
||
| _unstructure_ndarray_hook = _unstructure_dataframe_hook | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Doesn't this now mean that we use a function that explicitly uses |
||
|
|
||
|
|
||
| def _structure_ndarray_hook(obj: str, _) -> np.ndarray: | ||
| """Deserialize a numpy ndarray.""" | ||
| pickled_array = base64.b64decode(obj.encode("utf-8")) | ||
| return pickle.loads(pickled_array) | ||
|
|
||
|
|
||
| def block_serialization_hook(obj: Any) -> NoReturn: # noqa: DOC101, DOC103 | ||
| """Prevent serialization of the passed object. | ||
|
|
||
|
|
@@ -163,6 +173,8 @@ def select_constructor_hook(specs: dict, cls: type[_T]) -> _T: | |
| ) | ||
| converter.register_unstructure_hook(pd.DataFrame, _unstructure_dataframe_hook) | ||
| converter.register_structure_hook(pd.DataFrame, _structure_dataframe_hook) | ||
| converter.register_unstructure_hook(np.ndarray, _unstructure_ndarray_hook) | ||
| converter.register_structure_hook(np.ndarray, _structure_ndarray_hook) | ||
| converter.register_unstructure_hook(datetime, lambda x: x.isoformat()) | ||
| converter.register_structure_hook(datetime, lambda x, _: datetime.fromisoformat(x)) | ||
| converter.register_unstructure_hook(timedelta, lambda x: f"{x.total_seconds()}s") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,20 +3,118 @@ | |
| from __future__ import annotations | ||
|
|
||
| import gc | ||
| from enum import Enum | ||
| from typing import TYPE_CHECKING, Protocol | ||
|
|
||
| import numpy as np | ||
| from attrs import define, field | ||
| from attrs.validators import instance_of | ||
| from typing_extensions import override | ||
| from typing_extensions import assert_never, override | ||
|
|
||
| from baybe.kernels.base import Kernel | ||
| from baybe.kernels.composite import AdditiveKernel, ProjectionKernel | ||
| from baybe.searchspace import SearchSpace | ||
| from baybe.serialization.mixin import SerialMixin | ||
|
|
||
| if TYPE_CHECKING: | ||
| from torch import Tensor | ||
|
|
||
|
|
||
| def to_kernel_factory(x: Kernel | KernelFactory, /) -> KernelFactory: | ||
| """Wrap a kernel into a plain kernel factory (with factory passthrough).""" | ||
| return x.to_factory() if isinstance(x, Kernel) else x | ||
|
|
||
|
|
||
| class ProjectionMatrixInitialization(Enum): | ||
| """Initialization strategies for kernel projection matrices.""" | ||
|
|
||
| MASKING = "MASKING" | ||
| """Axis-aligned masking (random selection of input dimensions).""" | ||
|
Comment on lines
+31
to
+32
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why the name
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Feel free to suggest better alternatives, was struggling here a little. (context: "masking" was meant as it "hides" the other dimensions)
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ok, I just wanted to know where the name comes from as "masking" has a special meaning in some other projects and contexts. How about "reducing" maybe or "trimming"? |
||
|
|
||
| ORTHONORMAL = "ORTHONORMAL" | ||
| """Random orthonormal basis.""" | ||
|
|
||
| PLS = "PLS" | ||
| """Partial Least Squares (PLS) directions.""" | ||
|
|
||
| SPHERICAL = "SPHERICAL" | ||
| """Uniform random sampling on the unit sphere.""" | ||
|
|
||
|
|
||
| def _make_projection_matrices( | ||
| n_projections: int, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'd prefer |
||
| n_matrices: int, | ||
| initialization: ProjectionMatrixInitialization, | ||
| train_x: np.ndarray, | ||
| train_y: np.ndarray, | ||
| ) -> np.ndarray: | ||
| """Create a collection of projection matrices. | ||
|
|
||
| Args: | ||
| n_projections: The number of projections in each matrix. | ||
| n_matrices: The number of projection matrices to create. | ||
| initialization: The initialization strategy to use. | ||
| train_x: The training inputs. | ||
| train_y: The training outputs. | ||
|
|
||
| Returns: | ||
| An array of shape ``(n_matrices, n_input_dims, n_projections)`` containing the | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| created matrices. | ||
| """ | ||
| n_input_dims = train_x.shape[-1] | ||
|
|
||
| if n_matrices == 0: | ||
| return np.empty((0, n_input_dims, n_projections)) | ||
|
|
||
| if initialization is ProjectionMatrixInitialization.MASKING: | ||
| matrices = [] | ||
| for _ in range(n_matrices): | ||
| matrix = np.eye(n_input_dims) | ||
| matrix = matrix[ | ||
| :, np.random.choice(n_input_dims, n_projections, replace=False) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Doesn't this require a check for |
||
| ] | ||
| matrices.append(matrix) | ||
|
|
||
| elif initialization is ProjectionMatrixInitialization.ORTHONORMAL: | ||
| matrices = [] | ||
| for _ in range(n_matrices): | ||
| random_matrix = np.random.randn(n_input_dims, n_projections) | ||
| q, _ = np.linalg.qr(random_matrix) | ||
| matrices.append(q[:, :n_projections]) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In my quick test, this step was not necessary and |
||
|
|
||
| elif initialization is ProjectionMatrixInitialization.PLS: | ||
| from sklearn.cross_decomposition import PLSRegression | ||
|
|
||
| pls = PLSRegression(n_components=n_projections) | ||
| pls.fit(train_x, train_y) | ||
| M = pls.x_rotations_ | ||
|
|
||
| # IMPROVE: One could use the remaining PLS directions for the next matrices | ||
| # until they are exhausted, then switch to orthonormal. | ||
| matrices = [ | ||
| M, | ||
| *_make_projection_matrices( | ||
| n_projections=n_projections, | ||
| n_matrices=n_matrices - 1, | ||
| initialization=ProjectionMatrixInitialization.ORTHONORMAL, | ||
| train_x=train_x, | ||
| train_y=train_y, | ||
| ), | ||
| ] | ||
|
|
||
| elif initialization is ProjectionMatrixInitialization.SPHERICAL: | ||
| matrices = [] | ||
| for _ in range(n_matrices): | ||
| matrix = np.random.randn(n_input_dims, n_projections) | ||
| matrix = matrix / np.linalg.norm(matrix, axis=0, keepdims=True) | ||
| matrices.append(matrix) | ||
|
|
||
| else: | ||
| assert_never(initialization) | ||
|
|
||
| return np.stack(matrices) if n_matrices > 1 else matrices[0][None, ...] | ||
|
|
||
|
|
||
| class KernelFactory(Protocol): | ||
| """A protocol defining the interface expected for kernel factories.""" | ||
|
|
||
|
|
@@ -41,9 +139,61 @@ def __call__( | |
| return self.kernel | ||
|
|
||
|
|
||
| def to_kernel_factory(x: Kernel | KernelFactory, /) -> KernelFactory: | ||
| """Wrap a kernel into a plain kernel factory (with factory passthrough).""" | ||
| return x.to_factory() if isinstance(x, Kernel) else x | ||
| @define(frozen=True) | ||
| class ProjectionKernelFactory(KernelFactory, SerialMixin): | ||
| """A factory producing projected kernels.""" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this needs a bit more explanation regarding the exact calculation of the projection as this is not clear imo. |
||
|
|
||
| n_projections: int = field(validator=instance_of(int)) | ||
| """The number of projections to be used in each projection matrix.""" | ||
|
|
||
| n_matrices: int = field(validator=instance_of(int)) | ||
| """The number of projection matrices to be used.""" | ||
|
|
||
| initialization: ProjectionMatrixInitialization = field( | ||
| converter=ProjectionMatrixInitialization | ||
| ) | ||
| """The initialization strategy for the projection matrices.""" | ||
|
|
||
| base_kernel_factory: KernelFactory = field( | ||
| alias="kernel_or_factory", | ||
| converter=to_kernel_factory, | ||
| ) | ||
| """The factory creating the base kernel to be applied to the projected inputs.""" | ||
|
|
||
| learn_projection: bool = field( | ||
| default=True, validator=instance_of(bool), kw_only=True | ||
| ) | ||
| """See :attr:`baybe.kernels.composite.ProjectionKernel.learn_projection`.""" | ||
|
|
||
| @base_kernel_factory.default | ||
| def _default_base_kernel_factory(self) -> KernelFactory: | ||
| from baybe.surrogates.gaussian_process.presets.default import ( | ||
| DefaultKernelFactory, | ||
| ) | ||
|
|
||
| return DefaultKernelFactory() | ||
|
|
||
| @override | ||
| def __call__( | ||
| self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor | ||
| ) -> Kernel: | ||
| base_kernel = self.base_kernel_factory(searchspace, train_x, train_y) | ||
| projection_matrices = _make_projection_matrices( | ||
| n_projections=self.n_projections, | ||
| n_matrices=self.n_matrices, | ||
| initialization=self.initialization, | ||
| train_x=train_x.numpy(), | ||
| train_y=train_y.numpy(), | ||
| ) | ||
| kernels = [ | ||
| ProjectionKernel( | ||
| base_kernel=base_kernel, | ||
| projection_matrix=m, | ||
| learn_projection=self.learn_projection, | ||
| ) | ||
| for m in projection_matrices | ||
| ] | ||
| return AdditiveKernel(kernels) if self.n_matrices > 1 else kernels[0] | ||
|
|
||
|
|
||
| # Collect leftover original slotted classes processed by `attrs.define` | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No lazy import? I assume this is since this file itself will only be imported lazily, or has this been overlooked?