-
Notifications
You must be signed in to change notification settings - Fork 76
Kernel dimension control #748
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
44 commits
Select commit
Hold shift + click to select a range
b84af1d
Absorb index kernel construction into ICMKernelFactory
AdrianSosic e365862
Add parameter selectors
AdrianSosic dc9ee90
Rename protocols
AdrianSosic 63e64fa
Implement active kernel dimension control
AdrianSosic da8ffc0
Move parameter_names attribute down to BasicKernel subclass
AdrianSosic 86e110b
Fix condition in BayBEKernelFactory
AdrianSosic 3714c5f
Add deprecation mechanism for breaking change in kernel logic
AdrianSosic 66b3dfb
Import KernelFactory to components/__init__.py
AdrianSosic d91623f
Add citation to docstring
AdrianSosic 91bd654
Fix typing
AdrianSosic 02bfc97
Add parameter selection to kernel hypothesis strategies
AdrianSosic ca37ecf
Drop batch_shape argument from Kernel.to_gpytorch
AdrianSosic cec1e5d
Update kernel assembly test
AdrianSosic daaa8f3
Refactor handling of constructor-only attributes in test
AdrianSosic 90d4641
Fix logic of custom kernel converter helper
AdrianSosic 2dd1316
Validate that GP component factories are callable
AdrianSosic 118e019
Update CHANGELOG.md
AdrianSosic aa07e2a
Simplify converter logic
AdrianSosic 70a4a3e
Make ParameterSelector class abstract
AdrianSosic eee1504
Fix test parametrization
AdrianSosic d486c84
Rename selector.py to selectors.py
AdrianSosic 8480607
Fix variable reference in kernel translation test
AdrianSosic 7cd2321
Use chain.from_iterable instead of unpacking
AdrianSosic 9cc4647
Replace hard-coded parameter type in deprecation error message
AdrianSosic 9c5cf18
Add comment explaining the role of active_dims=None
AdrianSosic d13e365
Fix keyword logic in Kernel.to_gpytorch
AdrianSosic 1707e3f
Add NameSelector class
AdrianSosic c859d85
Refine NameSelector.parameter_names field specification
AdrianSosic 3a1c421
Make KernelFactory class abstract
AdrianSosic 66b6f06
Remove unnecessary keyword assignment
AdrianSosic 4edf552
Add converters to ICMKernelFactory
AdrianSosic a1bc442
Revise comment on kernel attribute matching
AdrianSosic 4811dae
Fix converter specification
AdrianSosic 4347c76
Add convenience converter for parameter selectors
AdrianSosic 88468a1
Turn KernelFactory into mixin class for parameter selection
AdrianSosic 2d414a2
Override default selectors for existing factories
AdrianSosic 1b248a5
Turn assert statement into proper AssertionError
AdrianSosic f18cd32
Use regex by default for NameSelector
AdrianSosic a6d611e
Rename parameter_types attribute of TypeSelector to types
AdrianSosic 55582fc
Rephrase changelog entry
AdrianSosic 54a5674
Merge branch 'dev/gp' into refactor/multi_task
AdrianSosic 20f106b
Fix typing issues
AdrianSosic 3c38d34
Add ignore to docs/conf.py
AdrianSosic f943eeb
Add missing searchspace argument
AdrianSosic File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
|
AdrianSosic marked this conversation as resolved.
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,167 @@ | ||
| """Parameter selectors.""" | ||
|
|
||
| import re | ||
| from abc import ABC, abstractmethod | ||
| from collections.abc import Collection | ||
| from typing import ClassVar, Protocol | ||
|
|
||
| from attrs import Converter, define, field | ||
| from attrs.converters import optional | ||
| from attrs.validators import deep_iterable, instance_of, min_len | ||
| from typing_extensions import override | ||
|
|
||
| from baybe.parameters.base import Parameter | ||
| from baybe.searchspace.core import SearchSpace | ||
| from baybe.utils.basic import to_tuple | ||
| from baybe.utils.conversion import nonstring_to_tuple | ||
|
|
||
|
|
||
| class ParameterSelectorProtocol(Protocol): | ||
| """Type protocol specifying the interface parameter selectors need to implement.""" | ||
|
|
||
| def __call__(self, parameter: Parameter) -> bool: | ||
| """Determine if a parameter should be included in the selection.""" | ||
|
|
||
|
|
||
| @define | ||
| class ParameterSelector(ParameterSelectorProtocol, ABC): | ||
| """Base class for parameter selectors.""" | ||
|
|
||
| exclude: bool = field(default=False, validator=instance_of(bool), kw_only=True) | ||
| """Boolean flag indicating whether invert the selection criterion.""" | ||
|
|
||
| @abstractmethod | ||
|
AdrianSosic marked this conversation as resolved.
|
||
| def _is_match(self, parameter: Parameter) -> bool: | ||
| """Determine if a parameter meets the selection criterion.""" | ||
|
AdrianSosic marked this conversation as resolved.
|
||
|
|
||
| @override | ||
| def __call__(self, parameter: Parameter) -> bool: | ||
| """Determine if a parameter should be included in the selection.""" | ||
| result = self._is_match(parameter) | ||
| return not result if self.exclude else result | ||
|
|
||
|
Scienfitz marked this conversation as resolved.
|
||
|
|
||
| @define | ||
| class TypeSelector(ParameterSelector): | ||
| """Select parameters by type.""" | ||
|
|
||
| types: tuple[type[Parameter], ...] = field(converter=to_tuple) | ||
| """The parameter types to be selected.""" | ||
|
|
||
| @override | ||
| def _is_match(self, parameter: Parameter) -> bool: | ||
| return isinstance(parameter, self.types) | ||
|
|
||
|
|
||
| @define | ||
| class NameSelector(ParameterSelector): | ||
| """Select parameters by name patterns.""" | ||
|
|
||
| patterns: tuple[str, ...] = field( | ||
| converter=Converter( # type: ignore | ||
| nonstring_to_tuple, takes_self=True, takes_field=True | ||
| ), | ||
| validator=[ | ||
| min_len(1), | ||
| deep_iterable(member_validator=instance_of(str)), | ||
| ], | ||
| ) | ||
| """The patterns to be matched against.""" | ||
|
|
||
| regex: bool = field(default=True, validator=instance_of(bool), kw_only=True) | ||
| """If ``False``, the provided patterns are interpreted as literal strings.""" | ||
|
|
||
| @override | ||
| def _is_match(self, parameter: Parameter) -> bool: | ||
| if self.regex: | ||
| return any(re.fullmatch(p, parameter.name) for p in self.patterns) | ||
| return parameter.name in self.patterns | ||
|
|
||
|
|
||
| def to_parameter_selector( | ||
| x: ( | ||
| str | ||
| | type[Parameter] | ||
| | Collection[str] | ||
| | Collection[type[Parameter]] | ||
| | ParameterSelectorProtocol | ||
| ), | ||
| /, | ||
| ) -> ParameterSelectorProtocol: | ||
| """Convert shorthand notations to parameter selectors. | ||
|
|
||
| Convenience converter that allows users to specify parameter selectors using | ||
| simpler types: | ||
|
|
||
| * A callable (i.e., an existing selector or any object satisfying | ||
| :class:`ParameterSelectorProtocol`) is passed through unchanged. | ||
| * A single string is interpreted as a parameter name and wrapped into a | ||
| :class:`NameSelector`. | ||
| * A single :class:`~baybe.parameters.base.Parameter` subclass is wrapped into a | ||
| :class:`TypeSelector`. | ||
| * A collection of strings is converted to a :class:`NameSelector`. | ||
| * A collection of :class:`~baybe.parameters.base.Parameter` subclasses is converted | ||
| to a :class:`TypeSelector`. | ||
|
|
||
| Args: | ||
| x: The object to convert. | ||
|
|
||
| Returns: | ||
| The corresponding parameter selector. | ||
|
|
||
| Raises: | ||
| TypeError: If the input cannot be converted to a parameter selector. | ||
| """ | ||
| if isinstance(x, str): | ||
| return NameSelector([x]) | ||
|
|
||
| if isinstance(x, type) and issubclass(x, Parameter): | ||
| return TypeSelector([x]) | ||
|
|
||
| if callable(x): | ||
| return x | ||
|
|
||
| # At this point, x should be a collection of strings or parameter types | ||
| items = tuple(x) | ||
|
|
||
| if all(isinstance(item, str) for item in items): | ||
| return NameSelector(items) | ||
|
|
||
| if all(isinstance(item, type) and issubclass(item, Parameter) for item in items): | ||
| return TypeSelector(items) | ||
|
|
||
| raise TypeError(f"Cannot convert {x!r} to a parameter selector.") | ||
|
|
||
|
|
||
| @define | ||
| class _ParameterSelectorMixin: | ||
| """A mixin class to enable parameter selection.""" | ||
|
|
||
| # For internal use only: sanity check mechanism to remind developers of new | ||
| # subclasses to actually use the parameter selector when it is provided | ||
| # TODO: Perhaps we can find a more elegant way to enforce this by design | ||
| _uses_parameter_names: ClassVar[bool] = False | ||
|
|
||
| parameter_selector: ParameterSelectorProtocol | None = field( | ||
| default=None, converter=optional(to_parameter_selector), kw_only=True | ||
| ) | ||
| """An optional selector to specify which parameters are to be considered.""" | ||
|
|
||
| def get_parameter_names(self, searchspace: SearchSpace) -> tuple[str, ...] | None: | ||
| """Get the names of the parameters to be considered.""" | ||
| if self.parameter_selector is None: | ||
| return None | ||
|
|
||
| return tuple( | ||
| p.name for p in searchspace.parameters if self.parameter_selector(p) | ||
| ) | ||
|
|
||
| def __attrs_post_init__(self): | ||
| if self.parameter_selector is not None and not self._uses_parameter_names: | ||
| raise AssertionError( | ||
| f"A `parameter_selector` was provided to " | ||
| f"`{type(self).__name__}`, but the class does not set " | ||
| f"`_uses_parameter_names = True`. Subclasses that accept a " | ||
| f"parameter selector must explicitly set this flag to confirm " | ||
| f"they actually use the selected parameter names." | ||
| ) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.