Skip to content

Commit f223793

Browse files
authored
Merge: Fix CI (emdgroup#704)
- fixes an issue that came up after latest mypy release - reverses the sorting order in our FPS utility to make the results consistent with the new `fpsample>=1.0.0` release
2 parents 1f4d1de + d77dff8 commit f223793

2 files changed

Lines changed: 6 additions & 3 deletions

File tree

baybe/surrogates/gaussian_process/core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
145145
import botorch
146146
import gpytorch
147147
import torch
148+
from botorch.models.transforms import Normalize, Standardize
148149

149150
# FIXME[typing]: It seems there is currently no better way to inform the type
150151
# checker that the attribute is available at the time of the function call
@@ -155,12 +156,12 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
155156
numerical_idxs = context.get_numerical_indices(train_x.shape[-1])
156157

157158
# For GPs, we let botorch handle the scaling. See [Scaling Workaround] above.
158-
input_transform = botorch.models.transforms.Normalize(
159+
input_transform = Normalize(
159160
train_x.shape[-1],
160161
bounds=context.parameter_bounds,
161162
indices=list(numerical_idxs),
162163
)
163-
outcome_transform = botorch.models.transforms.Standardize(train_y.shape[-1])
164+
outcome_transform = Standardize(train_y.shape[-1])
164165

165166
# extract the batch shape of the training data
166167
batch_shape = train_x.shape[:-2]

baybe/utils/sampling_algorithms.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ def farthest_point_sampling(
114114
return list(range(n_samples))
115115

116116
# Sort the points to produce the same result regardless of the input order
117-
sort_idx = np.lexsort(tuple(points.T))
117+
# The choice here is done to produce the same results as fpsample==1.0.0
118+
# See https://github.com/leonardodalinky/fpsample/issues/10
119+
sort_idx = np.lexsort(tuple(points.T))[::-1]
118120
points = points[sort_idx]
119121

120122
# Pre-compute the pairwise distances between all points

0 commit comments

Comments
 (0)