Skip to content
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]
### Added
- `coefficients` attribute for `DiscreteSumConstraint`, enabling weighted sums.
Defaults to all-ones (unweighted), preserving existing behavior. Follows the
same pattern as `ContinuousLinearConstraint.coefficients`
- `simplex_coefficients` keyword argument to `SubspaceDiscrete.from_simplex` for
weighted simplex sum constraints. Defaults to all-ones
- Support for Python 3.14
- `Settings` class for unified and streamlined settings management
- Settings options to (de-)activate recommendation caching / dataframe preprocessing
Expand All @@ -21,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Breaking Changes
- `parameter_cartesian_prod_pandas` and `parameter_cartesian_prod_polars` moved
from `baybe.searchspace.discrete` to `baybe.searchspace.utils`
- All optional arguments of `SubspaceDiscrete.from_simplex` after `simplex_parameters`
are now keyword-only
- `ContinuousLinearConstraint.to_botorch` now returns a collection of constraint tuples
instead of a single tuple (needed for interpoint constraints)

Expand Down
51 changes: 46 additions & 5 deletions baybe/constraints/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from __future__ import annotations

import gc
from collections.abc import Callable
from collections.abc import Callable, Sequence
from functools import reduce
from typing import TYPE_CHECKING, Any, ClassVar, cast

import cattrs
import pandas as pd
from attrs import define, field
from attrs.validators import in_, min_len
from attrs.validators import deep_iterable, in_, min_len
from typing_extensions import override

from baybe.constraints.base import CardinalityConstraint, DiscreteConstraint
Expand All @@ -25,6 +26,7 @@
converter,
)
from baybe.utils.basic import Dummy
from baybe.utils.validation import finite_float

if TYPE_CHECKING:
import polars as pl
Expand Down Expand Up @@ -76,7 +78,11 @@ def get_invalid_polars(self) -> pl.Expr:

@define
class DiscreteSumConstraint(DiscreteConstraint):
"""Class for modelling sum constraints."""
"""Class for modelling sum constraints.

The constraint evaluates whether the (optionally weighted) sum of the specified
parameters satisfies the given threshold condition.
"""

# IMPROVE: refactor `SumConstraint` and `ProdConstraint` to avoid code copying

Expand All @@ -93,9 +99,43 @@ class DiscreteSumConstraint(DiscreteConstraint):
condition: ThresholdCondition = field()
"""The condition modeled by this constraint."""

coefficients: tuple[float, ...] = field(
converter=lambda x: cattrs.structure(x, tuple[float, ...]),
validator=deep_iterable(member_validator=finite_float),
)
"""The coefficients for the weighted sum, one per entry in ``parameters``.

Defaults to all-ones, i.e. an unweighted sum."""

@coefficients.default
def _default_coefficients(self) -> tuple[float, ...]:
    """Build the default coefficients: a weight of one for every parameter."""
    # One unit weight per entry in ``parameters`` yields a plain, unweighted sum.
    return tuple(1.0 for _ in self.parameters)

@coefficients.validator
def _validate_coefficients(  # noqa: DOC101, DOC103
    self, _: Any, coefficients: Sequence[float]
) -> None:
    """Ensure there is exactly one coefficient per constrained parameter.

    Raises:
        ValueError: If ``coefficients`` and ``parameters`` differ in length.
    """
    n_parameters = len(self.parameters)
    n_coefficients = len(coefficients)

    # Guard clause: matching lengths mean there is nothing to complain about.
    if n_parameters == n_coefficients:
        return

    raise ValueError(
        "The given 'coefficients' list must have one floating point entry for "
        "each entry in 'parameters'."
    )

@override
def _get_invalid(self, df: pd.DataFrame, /) -> pd.Index:
evaluate_df = df[self.parameters].sum(axis=1)
evaluate_df = pd.Series(
sum(
df[p].to_numpy() * c for p, c in zip(self.parameters, self.coefficients)
),
index=df.index,
)
mask_bad = ~self.condition.evaluate(evaluate_df)

return df.index[mask_bad]
Expand All @@ -104,7 +144,8 @@ def _get_invalid(self, df: pd.DataFrame, /) -> pd.Index:
def get_invalid_polars(self) -> pl.Expr:
    """Build a polars expression flagging rows that violate the sum constraint.

    Returns:
        The negation of ``condition`` applied to the weighted horizontal sum of
        the ``parameters`` columns.
    """
    # polars is imported here rather than at module import time (at module level
    # it is only imported under ``TYPE_CHECKING``).
    from baybe._optional.polars import polars as pl

    # Scale every column by its coefficient before summing; returning the plain
    # `pl.sum_horizontal(self.parameters)` would silently ignore the weights.
    weighted = [pl.col(p) * c for p, c in zip(self.parameters, self.coefficients)]
    return self.condition.to_polars(pl.sum_horizontal(weighted)).not_()


@define
Expand Down
197 changes: 113 additions & 84 deletions baybe/searchspace/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ def from_simplex(
cls,
max_sum: float,
simplex_parameters: Sequence[NumericalDiscreteParameter],
*,
simplex_coefficients: Sequence[float] | None = None,
product_parameters: Sequence[DiscreteParameter] | None = None,
constraints: Sequence[DiscreteConstraint] | None = None,
min_nonzero: int = 0,
Expand All @@ -286,8 +288,12 @@ def from_simplex(
significantly faster construction.

Args:
max_sum: The maximum sum of the parameter values defining the simplex size.
max_sum: The maximum (weighted) sum of the parameter values defining the
simplex size.
simplex_parameters: The parameters to be used for the simplex construction.
simplex_coefficients: Optional coefficients for the weighted sum, one per
entry in ``simplex_parameters``. Defaults to all-ones, i.e. an
unweighted sum.
product_parameters: Optional parameters that enter in form of a Cartesian
product.
constraints: See :class:`baybe.searchspace.core.SearchSpace`.
Expand All @@ -302,6 +308,8 @@ def from_simplex(
Raises:
ValueError: If the passed simplex parameters are not suitable for a simplex
construction.
ValueError: If the length of ``simplex_coefficients`` does not match the
number of ``simplex_parameters``.
ValueError: If the passed product parameters are not discrete.
ValueError: If the passed simplex parameters and product parameters are
not disjoint.
Expand All @@ -321,6 +329,8 @@ def from_simplex(
constraints = []
if max_nonzero is None:
max_nonzero = len(simplex_parameters)
if simplex_coefficients is None:
simplex_coefficients = [1.0] * len(simplex_parameters)

# Validate constraints
validate_constraints(constraints, [*simplex_parameters, *product_parameters])
Expand All @@ -339,6 +349,14 @@ def from_simplex(
f"must be of subclasses of '{DiscreteParameter.__name__}'."
)

# Validate coefficients length
if len(simplex_coefficients) != len(simplex_parameters):
raise ValueError(
f"'simplex_coefficients' must have one entry per 'simplex_parameters' "
f"entry, but got {len(simplex_coefficients)} coefficient(s) for "
f"{len(simplex_parameters)} parameter(s)."
)

# Validate no overlap between simplex parameters and product parameters
simplex_parameters_names = {p.name for p in simplex_parameters}
product_parameters_names = {p.name for p in product_parameters}
Expand All @@ -360,79 +378,54 @@ def from_simplex(
if len(simplex_parameters) < 1:
return cls.from_product(product_parameters, constraints)

# Validate non-negativity
min_values = [min(p.values) for p in simplex_parameters]
max_values = [max(p.values) for p in simplex_parameters]
if not (min(min_values) >= 0.0):
# Validate non-negativity of raw parameter values (required by the algorithm)
min_raw = [min(p.values) for p in simplex_parameters]
max_raw = [max(p.values) for p in simplex_parameters]
if any(v < 0.0 for v in min_raw):
raise ValueError(
f"All simplex_parameters passed to '{cls.from_simplex.__name__}' "
f"must have non-negative values only."
)

def drop_invalid(
df: pd.DataFrame,
max_sum: float,
boundary_only: bool,
min_nonzero: int | None = None,
max_nonzero: int | None = None,
) -> None:
"""Drop rows that violate the specified simplex constraint.

Args:
df: The dataframe whose rows should satisfy the simplex constraint.
max_sum: The maximum row sum defining the simplex size.
boundary_only: Flag to control if the points represented by the rows
may lie inside the simplex or on its boundary only.
min_nonzero: Minimum number of nonzero parameters required per row.
max_nonzero: Maximum number of nonzero parameters allowed per row.
"""
# Apply sum constraints
row_sums = df.sum(axis=1)
mask_violated = row_sums > max_sum + tolerance
if boundary_only:
mask_violated |= row_sums < max_sum - tolerance

# Apply optional nonzero constraints
if (min_nonzero is not None) or (max_nonzero is not None):
n_nonzero = (df != 0.0).sum(axis=1)
if min_nonzero is not None:
mask_violated |= n_nonzero < min_nonzero
if max_nonzero is not None:
mask_violated |= n_nonzero > max_nonzero

# Remove violating rows
idxs_to_drop = df[mask_violated].index
df.drop(index=idxs_to_drop, inplace=True)

# Get the minimum sum contributions to come in the upcoming joins (the
# first item is the minimum possible sum of all parameters starting from the
# second parameter, the second item is the minimum possible sum starting from
# the third parameter, and so on ...)
min_sum_upcoming = np.cumsum(min_values[:0:-1])[::-1]

# Get the min/max number of nonzero values to come in the upcoming joins (the
# first item is the min/max number of nonzero parameters starting from the
# second parameter, the second item is the min/max number starting from
# the third parameter, and so on ...)
min_nonzero_upcoming = np.cumsum((np.asarray(min_values) > 0.0)[:0:-1])[::-1]
max_nonzero_upcoming = np.cumsum((np.asarray(max_values) > 0.0)[:0:-1])[::-1]

# Incrementally build up the space, dropping invalid configuration along the
# way. More specifically:
# * After having cross-joined a new parameter, there must
# be enough "room" left for the remaining parameters to fit. That is,
# configurations of the current parameter subset that exceed the desired
# total value minus the minimum contribution to come from the yet-to-be-added
# parameters can be already discarded, because it is already clear that
# the total sum will be exceeded once all joins are completed.
# * Analogously, there must be enough "nonzero slots" left for the yet to be
# joined parameters, i.e. parameter subset configurations can be discarded
# where the number of nonzero parameters already exceeds the maximum number
# of nonzeros minus the number of nonzeros to come, because it is already
# clear that the maximum will be exceeded once all joins are completed.
# * Similarly, it can be verified for each parameter that there are still
# enough nonzero parameters to come to even reach the minimum
# desired number of nonzero after all joins.
# Compute per-parameter minimum weighted contributions.
# For a positive coefficient c the minimum contribution is c*min_raw; for a
# negative coefficient the ordering flips and it becomes c*max_raw. Taking
# min of both products handles any real coefficient correctly.
coeffs = np.asarray(simplex_coefficients, dtype=float)
Comment thread
Scienfitz marked this conversation as resolved.
if not np.isfinite(coeffs).all():
raise ValueError(
f"All simplex_coefficients passed to '{cls.from_simplex.__name__}' "
f"must be finite numbers."
)
min_weighted = np.array(
[min(c * lo, c * hi) for c, lo, hi in zip(coeffs, min_raw, max_raw)]
)

# Get the minimum weighted sum contributions to come in the upcoming joins (the
# first item is the minimum possible weighted sum of all parameters starting
# from the second parameter, the second item is the minimum possible weighted
# sum starting from the third parameter, and so on ...)
min_sum_upcoming = np.cumsum(min_weighted[:0:-1])[::-1]

# Get the min/max number of nonzero values to come in the upcoming joins.
# Nonzero counting is based on raw parameter values, not weighted values,
# because the cardinality constraint counts zero/nonzero entries regardless
# of the coefficient signs.
min_nonzero_upcoming = np.cumsum((np.asarray(min_raw) > 0.0)[:0:-1])[::-1]
max_nonzero_upcoming = np.cumsum((np.asarray(max_raw) > 0.0)[:0:-1])[::-1]

# Incrementally build up the space as a numpy array, dropping invalid
# configurations along the way. Working with raw numpy avoids pandas overhead
# (index management, BlockManager, merge machinery) in the hot loop.
#
# After having cross-joined a new parameter, there must be enough "room" left
# for the remaining parameters to fit. That is, configurations of the current
# parameter subset that exceed the desired total value minus the minimum
# contribution to come from the yet-to-be-added parameters can be already
# discarded, because it is already clear that the total sum will be exceeded
# once all joins are completed. Analogously, nonzero cardinality bounds are
# checked at each step.
arr: np.ndarray
for i, (
param,
min_sum_to_go,
Expand All @@ -446,27 +439,44 @@ def drop_invalid(
np.append(max_nonzero_upcoming, 0),
)
):
values = np.asarray(param.values, dtype=float)

if i == 0:
exp_rep = pd.DataFrame({param.name: param.values})
arr = values.reshape(-1, 1)
else:
exp_rep = pd.merge(
exp_rep, pd.DataFrame({param.name: param.values}), how="cross"
n_old = arr.shape[0]
n_new = len(values)
arr = np.column_stack(
[
np.repeat(arr, n_new, axis=0),
np.tile(values, n_old),
]
)
drop_invalid(
exp_rep,
max_sum=max_sum - min_sum_to_go,
# the maximum possible number of nonzeros to come dictates if we
# can achieve our minimum constraint in the end:
min_nonzero=min_nonzero - max_nonzero_to_go,
# the minimum possible number of nonzeros to come dictates if we
# can stay below the targeted maximum in the end:
max_nonzero=max_nonzero - min_nonzero_to_go,
boundary_only=False,
)

# Compute weighted row sums and build validity mask
row_sums = arr @ coeffs[: i + 1]
mask = row_sums <= (max_sum - min_sum_to_go) + tolerance

# Apply nonzero cardinality bounds
effective_min = min_nonzero - max_nonzero_to_go
effective_max = max_nonzero - min_nonzero_to_go
if effective_min > 0 or effective_max < len(simplex_parameters):
n_nz = np.count_nonzero(arr, axis=1)
if effective_min > 0:
mask &= n_nz >= effective_min
if effective_max < len(simplex_parameters):
mask &= n_nz <= effective_max

arr = arr[mask]

# If requested, keep only the boundary values
if boundary_only:
drop_invalid(exp_rep, max_sum, boundary_only=True)
row_sums = arr @ coeffs
mask = np.abs(row_sums - max_sum) <= tolerance
arr = arr[mask]

# Wrap in DataFrame
exp_rep = pd.DataFrame(arr, columns=[p.name for p in simplex_parameters])

# Merge product parameters and apply constraints incrementally
exp_rep = build_constrained_product(
Expand Down Expand Up @@ -656,6 +666,25 @@ def validate_simplex_subspace_from_config(specs: dict, _) -> None:
f"values only."
)

simplex_coefficients = specs.get("simplex_coefficients", None)
if simplex_coefficients is not None:
try:
simplex_coefficients = converter.structure(
simplex_coefficients, list[float]
)
except (IterableValidationError, TypeError, ValueError) as exc:
raise ValueError(
"'simplex_coefficients' must be a list of numeric values."
) from exc

if len(simplex_coefficients) != len(simplex_parameters):
raise ValueError(
f"'simplex_coefficients' must have one entry per "
f"'simplex_parameters' entry, but got "
f"{len(simplex_coefficients)} coefficient(s) for "
f"{len(simplex_parameters)} parameter(s)."
)

product_parameters = specs.get("product_parameters", [])
if product_parameters:
product_parameters = converter.structure(
Expand Down
Loading
Loading