Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b866f56
Add human readable id to example tests
Scienfitz Apr 12, 2024
71d4d62
Adapt deprecated assignment
Scienfitz Apr 12, 2024
a079ae7
Make non-MC ACQF error more precise
Scienfitz Apr 12, 2024
d69905c
Remove debotorchize
Scienfitz Apr 12, 2024
8503d47
Fix streamlit example
Scienfitz Apr 12, 2024
0cdef0f
Implement to_botorch conversion
Scienfitz Apr 12, 2024
d7598e7
Add acqf iteration tests
Scienfitz Apr 12, 2024
e38f02a
Reorder acqfs
Scienfitz Apr 12, 2024
18e2c60
Add simple regret
Scienfitz Apr 12, 2024
f54793b
Add q-noisy variants
Scienfitz Apr 12, 2024
99dd0d6
Add log variants
Scienfitz Apr 12, 2024
eedecf3
Extend hypothesis
Scienfitz Apr 12, 2024
b163f91
Update CHANGELOG.md
Scienfitz Apr 12, 2024
d7ef923
Use custom classproperty
Scienfitz Apr 25, 2024
6e40556
Replace botorch factory
Scienfitz Apr 25, 2024
f2204f8
Add pruning option
Scienfitz Apr 25, 2024
0b9b27f
Ignore is_mc in docs
Scienfitz Apr 26, 2024
064c61a
Add chaining deprecation reference
Scienfitz Apr 26, 2024
8e2dcf3
Update prune_basline validator
Scienfitz Apr 26, 2024
11fc52a
Make adapter module private
Scienfitz Apr 25, 2024
95ade5e
Reorder acqfs in init
Scienfitz Apr 26, 2024
21a48a3
Simplify test file parameterization
Scienfitz Apr 26, 2024
81b422a
Reorder acqf hypothesis according to complexity
Scienfitz Apr 26, 2024
4168475
Improve seq greedy error handling
Scienfitz Apr 26, 2024
16db826
Adjust seq greedy continuous error handling
Scienfitz Apr 29, 2024
580c3ff
Adjust seq greedy hybrid error handling
Scienfitz Apr 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.md
Comment thread
Scienfitz marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `mypy` for search space and objectives
- Class hierarchy for objectives
- Deserialization is now also possible from optional class name abbreviations
- Hypothesis strategies for acquisition functions
- `Kernel` base class allowing to specify kernels
- `MaternKernel` class can be chosen for GP surrogates
- `hypothesis` strategies and roundtrip test for kernels, constraints and objectives
- `hypothesis` strategies and roundtrip test for kernels, constraints, objectives and acquisition
functions
- New acquisition functions: `qSR`, `qNEI`, `LogEI`, `qLogEI`, `qLogNEI`
Comment thread
AVHopp marked this conversation as resolved.

### Changed
- Reorganized acquisition.py into `acquisition` subpackage
Expand Down
52 changes: 42 additions & 10 deletions baybe/acquisition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,67 @@

from baybe.acquisition.acqfs import (
ExpectedImprovement,
LogExpectedImprovement,
PosteriorMean,
ProbabilityOfImprovement,
UpperConfidenceBound,
qExpectedImprovement,
qLogExpectedImprovement,
qLogNoisyExpectedImprovement,
qNoisyExpectedImprovement,
qProbabilityOfImprovement,
qSimpleRegret,
qUpperConfidenceBound,
)

PM = PosteriorMean
qSR = qSimpleRegret
EI = ExpectedImprovement
PI = ProbabilityOfImprovement
UCB = UpperConfidenceBound
qEI = qExpectedImprovement
LogEI = LogExpectedImprovement
qLogEI = qLogExpectedImprovement
qNEI = qNoisyExpectedImprovement
qLogNEI = qLogNoisyExpectedImprovement
PI = ProbabilityOfImprovement
qPI = qProbabilityOfImprovement
UCB = UpperConfidenceBound
qUCB = qUpperConfidenceBound

__all__ = [
# ---------------------------
# Acquisition functions
######################### Acquisition functions
# Posterior Mean
"PosteriorMean",
# Simple Regret
"qSimpleRegret",
# Expected Improvement
"ExpectedImprovement",
"ProbabilityOfImprovement",
"UpperConfidenceBound",
"qExpectedImprovement",
"LogExpectedImprovement",
"qLogExpectedImprovement",
"qNoisyExpectedImprovement",
"qLogNoisyExpectedImprovement",
# Probability of Improvement
"ProbabilityOfImprovement",
"qProbabilityOfImprovement",
# Upper Confidence Bound
"UpperConfidenceBound",
"qUpperConfidenceBound",
# ---------------------------
# Abbreviations
######################### Abbreviations
# Posterior Mean
"PM",
# Simple Regret
"qSR",
# Expected Improvement
"EI",
"PI",
"UCB",
"qEI",
"LogEI",
"qLogEI",
"qNEI",
"qLogNEI",
# Probability of Improvement
"PI",
"qPI",
# Upper Confidence Bound
"UCB",
"qUCB",
]
45 changes: 45 additions & 0 deletions baybe/acquisition/_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Adapter for making BoTorch's acquisition functions work with BayBE models."""

from typing import Any, Callable, Optional

import gpytorch.distributions
from botorch.models.gpytorch import Model
from botorch.posteriors import Posterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from torch import Tensor

from baybe.surrogates.base import Surrogate


class AdapterModel(Model):
"""A BoTorch model that uses a BayBE surrogate model for posterior computation.

Can be used, for example, as an adapter layer for making a BayBE
surrogate model usable in conjunction with BoTorch acquisition functions.

Args:
surrogate: The internal surrogate model
"""

def __init__(self, surrogate: Surrogate):
super().__init__()
self._surrogate = surrogate

@property
def num_outputs(self) -> int: # noqa: D102
# See base class.
# TODO: So far, the usage is limited to single-output models.
return 1

def posterior( # noqa: D102
self,
X: Tensor,
output_indices: Optional[list[int]] = None,
observation_noise: bool = False,
posterior_transform: Optional[Callable[[Posterior], Posterior]] = None,
**kwargs: Any,
) -> Posterior:
# See base class.
mean, var = self._surrogate.posterior(X)
mvn = gpytorch.distributions.MultivariateNormal(mean, var)
return GPyTorchPosterior(mvn)
77 changes: 64 additions & 13 deletions baybe/acquisition/acqfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,31 @@
from typing import ClassVar

from attrs import define, field
from attrs.validators import ge
from attrs.validators import ge, instance_of

from baybe.acquisition.base import AcquisitionFunction


########################################################################################
### Posterior Mean
@define(frozen=True)
class PosteriorMean(AcquisitionFunction):
"""Posterior mean."""

_abbreviation: ClassVar[str] = "PM"


########################################################################################
### Simple Regret
Comment thread
AVHopp marked this conversation as resolved.
@define(frozen=True)
class qExpectedImprovement(AcquisitionFunction):
"""Monte Carlo based expected improvement."""
class qSimpleRegret(AcquisitionFunction):
"""Monte Carlo based simple regret."""

_abbreviation: ClassVar[str] = "qEI"
_abbreviation: ClassVar[str] = "qSR"


########################################################################################
### Expected Improvement
@define(frozen=True)
class ExpectedImprovement(AcquisitionFunction):
"""Analytical expected improvement."""
Expand All @@ -30,12 +36,48 @@ class ExpectedImprovement(AcquisitionFunction):


@define(frozen=True)
class qProbabilityOfImprovement(AcquisitionFunction):
"""Monte Carlo based probability of improvement."""
class qExpectedImprovement(AcquisitionFunction):
"""Monte Carlo based expected improvement."""

_abbreviation: ClassVar[str] = "qEI"


@define(frozen=True)
class LogExpectedImprovement(AcquisitionFunction):
"""Logarithmic analytical expected improvement."""

_abbreviation: ClassVar[str] = "LogEI"


@define(frozen=True)
class qLogExpectedImprovement(AcquisitionFunction):
"""Logarithmic Monte Carlo based expected improvement."""

_abbreviation: ClassVar[str] = "qLogEI"


@define(frozen=True)
class qNoisyExpectedImprovement(AcquisitionFunction):
"""Monte Carlo based noisy expected improvement."""

_abbreviation: ClassVar[str] = "qNEI"

prune_baseline: bool = field(default=True, validator=instance_of(bool))
"""Auto-prune candidates that are unlikely to be the best."""

_abbreviation: ClassVar[str] = "qPI"

@define(frozen=True)
class qLogNoisyExpectedImprovement(AcquisitionFunction):
"""Logarithmic Monte Carlo based noisy expected improvement."""

_abbreviation: ClassVar[str] = "qLogNEI"

prune_baseline: bool = field(default=True, validator=instance_of(bool))
"""Auto-prune candidates that are unlikely to be the best."""


########################################################################################
### Probability of Improvement
@define(frozen=True)
class ProbabilityOfImprovement(AcquisitionFunction):
"""Analytical probability of improvement."""
Expand All @@ -44,10 +86,19 @@ class ProbabilityOfImprovement(AcquisitionFunction):


@define(frozen=True)
class qUpperConfidenceBound(AcquisitionFunction):
"""Monte Carlo based upper confidence bound."""
class qProbabilityOfImprovement(AcquisitionFunction):
"""Monte Carlo based probability of improvement."""

_abbreviation: ClassVar[str] = "qPI"

_abbreviation: ClassVar[str] = "qUCB"

########################################################################################
### Upper Confidence Bound
@define(frozen=True)
class UpperConfidenceBound(AcquisitionFunction):
"""Analytical upper confidence bound."""

_abbreviation: ClassVar[str] = "UCB"

beta: float = field(converter=float, validator=ge(0.0), default=0.2)
"""Trade-off parameter for mean and variance.
Expand All @@ -59,10 +110,10 @@ class qUpperConfidenceBound(AcquisitionFunction):


@define(frozen=True)
class UpperConfidenceBound(AcquisitionFunction):
"""Analytical upper confidence bound."""
class qUpperConfidenceBound(AcquisitionFunction):
"""Monte Carlo based upper confidence bound."""

_abbreviation: ClassVar[str] = "UCB"
_abbreviation: ClassVar[str] = "qUCB"

beta: float = field(converter=float, validator=ge(0.0), default=0.2)
"""Trade-off parameter for mean and variance.
Expand Down
94 changes: 0 additions & 94 deletions baybe/acquisition/adapter.py

This file was deleted.

Loading