Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
b1d4589
Add UnsupportedEarlyFilteringError exception
Scienfitz Apr 1, 2026
0312a0b
Add incremental filtering interface to DiscreteConstraint
Scienfitz Apr 1, 2026
55bdf24
Add incremental filtering to discrete constraint implementations
Scienfitz Apr 1, 2026
aec6686
Add searchspace utility functions
Scienfitz Apr 1, 2026
a7b8594
Rework from_product and from_simplex to use constrained Cartesian pro…
Scienfitz Apr 1, 2026
5123372
Update test imports for moved utility functions
Scienfitz Apr 1, 2026
6006ac4
Add tests for constrained Cartesian product
Scienfitz Apr 1, 2026
beb23a1
Update CHANGELOG
Scienfitz Apr 1, 2026
05f92fd
Simplify logic
Scienfitz Apr 9, 2026
2bf0f55
Make has_polars_implementation a classproperty
Scienfitz Apr 9, 2026
232d4d0
Rename _required_filtering_parameters to _required_parameters
Scienfitz Apr 9, 2026
8973d3f
Add breaking changes entry for moved utility functions
Scienfitz Apr 9, 2026
2c634b5
Improve UnsupportedEarlyFilteringError messages to reference the data…
Scienfitz Apr 10, 2026
05fe304
Add allow_missing flag to get_invalid and remove UnsupportedEarlyFilt…
Scienfitz Apr 13, 2026
c1c53f3
Avoid FutureWarning in PermutationInvarianceConstraint concat
Scienfitz Apr 13, 2026
3360cd5
Lift _required_parameters to Constraint base and use it in validation
Scienfitz Apr 22, 2026
5e5d731
Use warning instead of error
Scienfitz Apr 22, 2026
c4008d9
Remove dead Polars constraint fallback code
Scienfitz Apr 22, 2026
8f2e521
Rename compute_parameter_order to optimize_parameter_order
Scienfitz Apr 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `identify_non_dominated_configurations` method to `Campaign` and `Objective`
for determining the Pareto front
- Interpoint constraints for continuous search spaces
- `has_polars_implementation` property on `DiscreteConstraint`
- `allow_missing` flag on `DiscreteConstraint.get_invalid` and `get_valid`

### Changed
- Discrete search space construction now applies constraints incrementally during
Cartesian product building, significantly reducing memory usage and construction
time for constrained spaces
- Polars path in discrete search space construction now builds the Cartesian product
only for parameters involved in Polars-capable constraints, merging the rest
incrementally via pandas
Comment thread
Scienfitz marked this conversation as resolved.
Comment on lines +23 to +25
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- Polars path in discrete search space construction now builds the Cartesian product
only for parameters involved in Polars-capable constraints, merging the rest
incrementally via pandas
- `Polars` path in discrete search space construction now builds the Cartesian product
only for parameters involved in `Polars`-capable constraints, merging the rest
incrementally via `pandas`

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that would be quite inconsistent with the existing CHANGELOG fyi


### Breaking Changes
- `parameter_cartesian_prod_pandas` and `parameter_cartesian_prod_polars` moved
from `baybe.searchspace.discrete` to `baybe.searchspace.utils`
- `ContinuousLinearConstraint.to_botorch` now returns a collection of constraint tuples
instead of a single tuple (needed for interpoint constraints)

Expand Down
69 changes: 64 additions & 5 deletions baybe/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from baybe.serialization.core import (
converter,
)
from baybe.utils.basic import classproperty

if TYPE_CHECKING:
import polars as pl
Expand Down Expand Up @@ -81,6 +82,17 @@ def is_discrete(self) -> bool:
"""Boolean indicating if this is a constraint over discrete parameters."""
return isinstance(self, DiscreteConstraint)

@property
def _required_parameters(self) -> set[str]:
    """All parameter names needed for full constraint evaluation.

    By default, this is exactly the collection of names stored in
    :attr:`~baybe.constraints.base.Constraint.parameters`. Constraint
    classes that reference additional parameters (e.g., the affected
    parameters of dependency constraints) override this property so that
    those names are included as well.
    """
    names: set[str] = set()
    names.update(self.parameters)
    return names


@define
class DiscreteConstraint(Constraint, ABC):
Expand All @@ -97,29 +109,76 @@ class DiscreteConstraint(Constraint, ABC):
eval_during_modeling: ClassVar[bool] = False
# See base class.

def get_valid(
    self, df: pd.DataFrame, /, *, allow_missing: bool = False
) -> pd.Index:
    """Get the indices of dataframe entries that are valid under the constraint.

    Args:
        df: A dataframe where each row represents a parameter configuration.
        allow_missing: If ``False``, a :class:`ValueError` is raised when
            the dataframe is missing required parameter columns. If
            ``True``, the constraint performs partial filtering on the
            available columns.

    Returns:
        The dataframe indices of rows that fulfill the constraint.
    """
    # Valid rows are the complement of the invalid ones; column validation
    # is delegated to `get_invalid`.
    invalid = self.get_invalid(df, allow_missing=allow_missing)
    return df.index.drop(invalid)

def get_invalid(
    self, data: pd.DataFrame, /, *, allow_missing: bool = False
) -> pd.Index:
    """Get the indices of dataframe entries that are invalid under the constraint.

    Args:
        data: A dataframe where each row represents a parameter
            configuration.
        allow_missing: If ``False``, a :class:`ValueError` is raised when
            the dataframe is missing required parameter columns. If
            ``True``, the constraint performs partial filtering on the
            available columns, returning an empty index when insufficient
            columns are present.

    Raises:
        ValueError: If ``allow_missing`` is ``False`` and the dataframe
            is missing required parameter columns.

    Returns:
        The dataframe indices of rows that violate the constraint.
    """
    # TODO: Should switch backends (pandas/polars/...) behind the scenes
    if not allow_missing:
        # Fail fast when required columns are absent so callers do not get
        # a silently unfiltered result.
        if missing := self._required_parameters - set(data.columns):
            raise ValueError(
                f"'{self.__class__.__name__}' requires columns {missing} "
                f"which are missing from the dataframe."
            )
    return self._get_invalid(data)

@abstractmethod
def _get_invalid(self, data: pd.DataFrame) -> pd.Index:
    """Get the indices of invalid entries (implementation for subclasses).

    Subclasses implement this method with their specific filtering logic.
    Required-column validation is handled by the public
    :meth:`get_invalid` wrapper, so implementations may receive a partial
    dataframe. When the dataframe contains only a subset of the
    constraint's parameters, implementations should return an empty index
    if they cannot perform useful filtering on the available columns.

    Args:
        data: A dataframe where each row represents a parameter
            configuration. May contain all or a subset of the constraint's
            parameters.

    Returns:
        The dataframe indices of rows that violate the constraint.
    """

@classproperty
def has_polars_implementation(cls) -> bool:
    """Whether this constraint class has a Polars implementation."""
    # The base-class attribute acts as a sentinel: any subclass providing
    # a Polars translation rebinds ``get_invalid_polars``.
    base_impl = DiscreteConstraint.get_invalid_polars
    return cls.get_invalid_polars is not base_impl

def get_invalid_polars(self) -> pl.Expr:
"""Translate the constraint to Polars expression identifying undesired rows.
Expand Down
123 changes: 95 additions & 28 deletions baybe/constraints/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,19 @@ class DiscreteExcludeConstraint(DiscreteConstraint):
"""Operator encoding how to combine the individual conditions."""

@override
def _get_invalid(self, data: pd.DataFrame) -> pd.Index:
    """Return indices of rows violating the exclusion conditions.

    Supports partial dataframes: conditions whose parameter column is
    absent are skipped where this is safe for the combiner.
    """
    # Pair each condition with its parameter, keeping only those whose
    # column is actually present in the dataframe.
    pairs = [(p, c) for p, c in zip(self.parameters, self.conditions) if p in data]
    if not pairs:
        return pd.Index([])

    # Only the OR combiner supports incremental filtering: a single
    # true condition is sufficient to mark a row as invalid.
    if self.combiner != "OR" and len(pairs) < len(self.parameters):
        return pd.Index([])

    satisfied = [cond.evaluate(data[p]) for p, cond in pairs]
    res = reduce(_valid_logic_combiners[self.combiner], satisfied)

    return data.index[res]

@override
Expand Down Expand Up @@ -78,7 +85,13 @@ class DiscreteSumConstraint(DiscreteConstraint):
"""The condition modeled by this constraint."""

@override
def get_invalid(self, data: pd.DataFrame) -> pd.Index:
def _get_invalid(self, data: pd.DataFrame) -> pd.Index:
# IMPROVE: Look-ahead filtering would be possible if parameter
# value ranges (min/max) were available to the constraint, allowing
# bound-based pruning of partial sums before all parameters are
# present.
if not set(self.parameters) <= set(data.columns):
return pd.Index([])
evaluate_data = data[self.parameters].sum(axis=1)
mask_bad = ~self.condition.evaluate(evaluate_data)

Expand Down Expand Up @@ -106,7 +119,13 @@ class DiscreteProductConstraint(DiscreteConstraint):
"""The condition that is used for this constraint."""

@override
def get_invalid(self, data: pd.DataFrame) -> pd.Index:
def _get_invalid(self, data: pd.DataFrame) -> pd.Index:
# IMPROVE: Look-ahead filtering would be possible if parameter
# value ranges (min/max) were available to the constraint, allowing
# bound-based pruning of partial products before all parameters are
# present.
if not set(self.parameters) <= set(data.columns):
return pd.Index([])
evaluate_data = data[self.parameters].prod(axis=1)
mask_bad = ~self.condition.evaluate(evaluate_data)

Expand Down Expand Up @@ -140,8 +159,11 @@ class DiscreteNoLabelDuplicatesConstraint(DiscreteConstraint):
"""

@override
def _get_invalid(self, data: pd.DataFrame) -> pd.Index:
    """Return indices of rows whose parameter labels contain duplicates."""
    # Restrict to the constraint parameters present in the dataframe;
    # duplicate detection needs at least two columns to be meaningful.
    params = [p for p in self.parameters if p in data]
    if len(params) < 2:
        return pd.Index([])
    # A row is invalid when any label occurs more than once across the
    # available parameter columns.
    mask_bad = data[params].nunique(axis=1) != len(params)

    return data.index[mask_bad]

Expand All @@ -158,6 +180,7 @@ def get_invalid_polars(self) -> pl.Expr:
return expr


@define
class DiscreteLinkedParametersConstraint(DiscreteConstraint):
"""Constraint class for linking the values of parameters.

Expand All @@ -168,8 +191,11 @@ class DiscreteLinkedParametersConstraint(DiscreteConstraint):
"""

@override
def _get_invalid(self, data: pd.DataFrame) -> pd.Index:
    """Return indices of rows whose linked parameters do not share one value."""
    # Membership test uses `p in data` (column lookup) for consistency with
    # the other discrete constraints and to avoid rebuilding
    # `set(data.columns)` for every parameter.
    params = [p for p in self.parameters if p in data]
    if len(params) < 2:
        return pd.Index([])
    # Linked parameters must all carry the same value in a valid row.
    mask_bad = data[params].nunique(axis=1) != 1

    return data.index[mask_bad]

Expand Down Expand Up @@ -228,8 +254,19 @@ def _validate_affected_parameters( # noqa: DOC101, DOC103
f"the conditions list."
)

@property
@override
def _required_parameters(self) -> set[str]:
    """See base class."""
    # The affected-parameter groups reference names beyond ``parameters``
    # itself, so fold them all into the required set.
    return set(self.parameters).union(*self.affected_parameters)

@override
def get_invalid(self, data: pd.DataFrame) -> pd.Index:
def _get_invalid(self, data: pd.DataFrame) -> pd.Index:
if not self._required_parameters <= set(data.columns):
return pd.Index([])
# Create data copy and mark entries where the dependency conditions are negative
# with a dummy value to cause degeneracy.
censored_data = data.copy()
Expand Down Expand Up @@ -293,28 +330,45 @@ class DiscretePermutationInvarianceConstraint(DiscreteConstraint):
dependencies: DiscreteDependenciesConstraint | None = field(default=None)
"""Dependencies connected with the invariant parameters."""

@property
@override
def _required_parameters(self) -> set[str]:
    """See base class."""
    params = set(self.parameters)
    # Dependencies (if any) contribute their own required names, e.g. the
    # affected parameters of the dependency constraint.
    if self.dependencies:
        params.update(self.dependencies._required_parameters)
    return params

@override
def _get_invalid(self, data: pd.DataFrame) -> pd.Index:
cols = set(data.columns)
params = [p for p in self.parameters if p in cols]
if len(params) < 2:
return pd.Index([])
# When dependencies exist, permutation dedup on a partial set of
# parameters is not safe because the dependency logic can change
# which permutations are equivalent. In this case, only the
# label-dedup part (which is always safe incrementally) is applied.
if self.dependencies:
if not self._required_parameters <= cols:
return DiscreteNoLabelDuplicatesConstraint(
Comment thread
Scienfitz marked this conversation as resolved.
parameters=params
).get_invalid(data)

# Get indices of entries with duplicate label entries. These will also be
# dropped by this constraint.
mask_duplicate_labels = pd.Series(False, index=data.index)
mask_duplicate_labels[
DiscreteNoLabelDuplicatesConstraint(parameters=self.parameters).get_invalid(
data
)
DiscreteNoLabelDuplicatesConstraint(parameters=params).get_invalid(data)
] = True

# Merge a permutation invariant representation of all affected parameters with
# the other parameters and indicate duplicates. This ensures that variation in
# other parameters is also accounted for.
other_params = data.columns.drop(self.parameters).tolist()
df_eval = pd.concat(
[
data[other_params].copy(),
data[self.parameters].apply(cast(Callable, frozenset), axis=1),
],
axis=1,
).loc[
other_params = data.columns.drop(params).tolist()
frozen = data[params].apply(cast(Callable, frozenset), axis=1)
parts = [data[other_params].copy(), frozen] if other_params else [frozen]
df_eval = pd.concat(parts, axis=1).loc[
~mask_duplicate_labels # only consider label-duplicate-free part
]
mask_duplicate_permutations = df_eval.duplicated(keep="first")
Expand Down Expand Up @@ -349,7 +403,9 @@ class DiscreteCustomConstraint(DiscreteConstraint):
you want to keep/remove."""

@override
def _get_invalid(self, data: pd.DataFrame) -> pd.Index:
    """Return indices of rows rejected by the user-provided validator."""
    # A custom validator is an opaque callable: it needs the complete
    # parameter set, so partial dataframes cannot be filtered.
    if not set(self.parameters) <= set(data.columns):
        return pd.Index([])
    mask_bad = ~self.validator(data[self.parameters])

    return data.index[mask_bad]
Expand All @@ -364,10 +420,21 @@ class DiscreteCardinalityConstraint(CardinalityConstraint, DiscreteConstraint):
# See base class.

@override
def _get_invalid(self, data: pd.DataFrame) -> pd.Index:
    """Return indices of rows violating the cardinality bounds."""
    cols = set(data.columns)
    params = [p for p in self.parameters if p in cols]
    if not params:
        return pd.Index([])
    all_present = len(params) == len(self.parameters)

    non_zeros = (data[params] != 0.0).sum(axis=1)
    # The max_cardinality check is safe on a partial subset: the nonzero
    # count can only increase as more parameters are added.
    mask_bad = non_zeros > self.max_cardinality
    # The min_cardinality check can only be applied when all parameters
    # are present, since missing parameters could still add nonzero values.
    if all_present:
        mask_bad |= non_zeros < self.min_cardinality
    return data.index[mask_bad]


Expand Down
13 changes: 7 additions & 6 deletions baybe/constraints/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,29 +54,30 @@ def validate_constraints( # noqa: DOC101, DOC103
]

for constraint in constraints:
if not all(p in param_names_all for p in constraint.parameters):
if not all(p in param_names_all for p in constraint._required_parameters):
raise ValueError(
f"You are trying to create a constraint with at least one parameter "
f"name that does not exist in the list of defined parameters. "
f"Parameter list of the affected constraint: {constraint.parameters}"
f"Parameter list of the affected constraint: "
f"{constraint._required_parameters}"
)

if constraint.is_continuous and any(
p in param_names_discrete for p in constraint.parameters
p in param_names_discrete for p in constraint._required_parameters
):
raise ValueError(
f"You are trying to initialize a continuous constraint over a "
f"parameter that is discrete. Parameter list of the affected "
f"constraint: {constraint.parameters}"
f"constraint: {constraint._required_parameters}"
)

if constraint.is_discrete and any(
p in param_names_continuous for p in constraint.parameters
p in param_names_continuous for p in constraint._required_parameters
):
raise ValueError(
f"You are trying to initialize a discrete constraint over a parameter "
f"that is continuous. Parameter list of the affected constraint: "
f"{constraint.parameters}"
f"{constraint._required_parameters}"
)

if constraint.numerical_only and any(
Expand Down
Loading
Loading