From eb4545ae90d6973caff2a43651eb100a49009cf8 Mon Sep 17 00:00:00 2001 From: Myra Zmarsly Date: Wed, 13 May 2026 16:31:51 +0200 Subject: [PATCH 01/30] Add narwhal lazyframe converter --- baybe/utils/dataframe.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py index 84b831a406..4a21ebd4c6 100644 --- a/baybe/utils/dataframe.py +++ b/baybe/utils/dataframe.py @@ -7,8 +7,10 @@ from collections.abc import Callable, Collection, Iterable, Sequence from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload +import narwhals as nw import numpy as np import pandas as pd +from narwhals.typing import IntoDataFrame from typing_extensions import assert_never from baybe.exceptions import InputDataTypeWarning, SearchSpaceMatchWarning @@ -778,3 +780,32 @@ def needs_float_dtype(obj) -> bool: for col in cols_to_convert: df[col] = df[col].astype(active_settings.DTypeFloatNumpy) return df + + +def to_lazy_narwhals( + df: IntoDataFrame, +) -> nw.LazyFrame: + """Convert a native dataframe to a lazyframe, if it is not already a lazyframe. + + Args: + df: A dataframe in native format (e.g. pandas or polars) or already in narwhals + lazy format. + + Returns: + A lazy dataframe in narwhals format. + """ + return nw.from_native(df).lazy() + + +def from_lazy_narwhals( + ldf: nw.LazyFrame, +) -> IntoDataFrame: + """Convert a lazy dataframe to its native dataframe. + + Args: + ldf: A lazy dataframe + + Returns: + A dataframe in native format (e.g. pandas or polars) + """ + return ldf.collect().to_native() From 507a38b9b2e2576393f3254fa2a2033e723773ee Mon Sep 17 00:00:00 2001 From: Myra Zmarsly Date: Wed, 13 May 2026 16:43:22 +0200 Subject: [PATCH 02/30] Add polars(pyarrow) and narwhals as hard dependencies Using version from beginnign of 2026 --- pyproject.toml | 2 ++ uv.lock | 12 ++++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e5a21e70c1..8e5729caeb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,8 +36,10 @@ dependencies = [ "exceptiongroup", "gpytorch>=1.9.1,<2", "joblib>1.4.0,<2", + "narwhals>2.15.0", "numpy>=1.24.1,<3", "pandas>=1.4.2,<3", + "polars[pyarrow]>=0.19.19,<2", "scikit-learn>=1.1.1,<2", "scipy>=1.10.1", "torch>=1.13.1,<3", diff --git a/uv.lock b/uv.lock index 55018fb93b..97ae803590 100644 --- a/uv.lock +++ b/uv.lock @@ -10,7 +10,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-04-03T08:31:57.919196Z" +exclude-newer = "2026-05-06T14:42:46.862591Z" exclude-newer-span = "P7D" [[package]] @@ -207,9 +207,11 @@ dependencies = [ { name = "exceptiongroup" }, { name = "gpytorch" }, { name = "joblib" }, + { name = "narwhals" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pandas" }, + { name = "polars", extra = ["pyarrow"] }, { name = "scikit-learn", version = "1.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scikit-learn", version = "1.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, @@ -428,6 +430,7 @@ requires-dist = [ { name = "matplotlib", marker = "extra == 'examples'", specifier = ">=3.7.3,!=3.9.1" }, { name = "mypy", marker = "extra == 'mypy'", specifier = ">=1.19.1" }, { name = "myst-parser", marker = "extra == 'docs'", specifier = ">=4.0.0" }, + { name = "narwhals", specifier = ">2.15.0" }, { name = "ngboost", marker = "extra == 'extras'", specifier = ">=0.3.12,<1" }, { name = "numpy", specifier = ">=1.24.1,<3" }, { name = "onnx", marker = "extra == 'onnx'", specifier = ">=1.16.0" }, @@ -439,6 +442,7 @@ requires-dist = [ { name = "pillow", marker = "extra == 'examples'", specifier = ">=10.0.1" }, { name = "pip-audit", marker = "extra == 'dev'", specifier = ">=2.5.5" }, { name = "plotly", marker = "extra == 'examples'", specifier = ">=5.10.0" }, + { name = "polars", extras = ["pyarrow"], specifier = ">=0.19.19,<2" }, { name = "polars", extras = ["pyarrow"], marker = "extra == 'polars'", specifier = ">=0.19.19,<2" }, { name = "pre-commit", marker = "extra == 'lint'", specifier = "==4.2.0" }, { name = "psutil", marker = "extra == 'benchmarking'", specifier = ">=7.0.0" }, @@ -3082,11 +3086,11 @@ wheels = [ [[package]] name = "narwhals" -version = "2.15.0" +version = "2.20.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/47/6d/b57c64e5038a8cf071bce391bb11551657a74558877ac961e7fa905ece27/narwhals-2.15.0.tar.gz", hash = "sha256:a9585975b99d95084268445a1fdd881311fa26ef1caa18020d959d5b2ff9a965", size = 603479, upload-time = "2026-01-06T08:10:13.27Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e9/f3/257adc69a71011b4c8cda321b00f02c5bf1980ae38ffd05a58d9632d4de8/narwhals-2.20.0.tar.gz", hash = "sha256:c10994975fa7dc5a68c2cffcddbd5908fc8ebb2d463c5bab085309c0ee1f551e", size = 627848, upload-time = "2026-04-20T12:11:45.427Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3d/2e/cf2ffeb386ac3763526151163ad7da9f1b586aac96d2b4f7de1eaebf0c61/narwhals-2.15.0-py3-none-any.whl", hash = "sha256:cbfe21ca19d260d9fd67f995ec75c44592d1f106933b03ddd375df7ac841f9d6", size = 432856, upload-time = "2026-01-06T08:10:11.511Z" }, + { url = "https://files.pythonhosted.org/packages/d0/69/f24d3d1c38ad69e256138b4ec2452a8c7cf66be49dc214771ae99dd4f0a0/narwhals-2.20.0-py3-none-any.whl", hash = "sha256:16e750ea5507d4ba6e8d03455b5f93a535e0405976561baea235bca5dc9f475d", size = 449373, upload-time = "2026-04-20T12:11:43.596Z" }, ] [[package]] From f499cbb865bf6247be21fd300ae4c1ca0f4feb45 Mon Sep 17 00:00:00 2001 From: Myra Zmarsly Date: Wed, 13 May 2026 16:50:35 +0200 Subject: [PATCH 03/30] Add InfiniteSpaceError --- baybe/exceptions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/baybe/exceptions.py b/baybe/exceptions.py index 0be2273341..5221ba479a 100644 --- a/baybe/exceptions.py +++ b/baybe/exceptions.py @@ -179,5 +179,9 @@ class UnsupportedEarlyFilteringError(Exception): """A constraint does not support early filtering with the given parameters.""" +class InfiniteSpaceError(Exception): + """An operation requires an enumerable search space, but the space is infinite.""" + + # Collect leftover original slotted classes processed by `attrs.define` gc.collect() From d920ae261fbe24f91faddef49256dc7a7bd0d284 Mon Sep 17 00:00:00 2001 From: Myra Zmarsly Date: Wed, 13 May 2026 16:57:51 +0200 Subject: [PATCH 04/30] Add CandidateProtocol as well as TabelCandidates and ProductCandidates --- baybe/searchspace/candidates.py | 159 ++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 baybe/searchspace/candidates.py diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py new file mode 100644 index 0000000000..aee1f3ef1d --- /dev/null +++ b/baybe/searchspace/candidates.py @@ -0,0 +1,159 @@ +"""Candidates module for managing the lazy candidate generation.""" + +from typing import Protocol + +import narwhals as nw +from attr.validators import deep_iterable, instance_of, min_len +from attrs import define, field +from typing_extensions import override + +from baybe.constraints import DISCRETE_CONSTRAINTS_FILTERING_ORDER, validate_constraints +from baybe.constraints.base import DiscreteConstraint +from baybe.exceptions import InfiniteSpaceError +from baybe.parameters.base import DiscreteParameter +from baybe.searchspace.utils import build_constrained_product +from baybe.utils.basic import to_tuple +from baybe.utils.dataframe import to_lazy_narwhals +from baybe.utils.validation import validate_parameter_input + + +class CandidatesProtocol(Protocol): + """Type protocol specifying the interface for Candidates to implement.""" + + parameters: tuple[DiscreteParameter, ...] = field( + converter=to_tuple, + validator=[ + min_len(1), + deep_iterable( + member_validator=instance_of(DiscreteParameter), + ), + ], + ) + """ + The parameters that define the search space for which candidates are generated. + """ + + @property + def is_finite(self) -> bool: + """Define whether the candidate set is finite or infinite. + + Returns: + Whether the candidate set is finite. + """ + + def to_lazy_candidates(self) -> nw.LazyFrame: + """Generate the candidates from the given parameters and constraints. + + Returns: + The candidates as a lazy dataframe. + """ + + +@define(frozen=True) +class ProductCandidates(CandidatesProtocol): + """Class for managing product candidates. + + The candidates are generated by calculating the cartesian product of the parameter + values and applying constraints if available. + """ + + parameters: tuple[DiscreteParameter, ...] = field( + converter=to_tuple, + validator=[ + min_len(1), + deep_iterable( + member_validator=instance_of(DiscreteParameter), + ), + ], + ) + + constraints: tuple[DiscreteConstraint, ...] = field( + converter=lambda x: ( + to_tuple( + sorted( + x, + key=lambda c: DISCRETE_CONSTRAINTS_FILTERING_ORDER.index( + c.__class__ + ), + ) + ) + if x is not None + else () + ), + factory=tuple, + validator=deep_iterable( + member_validator=instance_of(DiscreteConstraint), + ), + ) + """The constraints to apply to the cartesian product of the parameter values.""" + + @override + @property + def is_finite(self) -> bool: + # TODO: Need a property for each DiscreteParameter to check if it's finite then + # replace here with: return all(p.is_finite for p in self.parameters) + return True + + @override + def to_lazy_candidates(self) -> nw.LazyFrame: + """Create a lazy data frame to represent candidates. + + Calculate the cartesian product of the parameter values and apply the + constraints. + + Raises: + InfiniteSpaceError: If the search space is infinite. + + Returns: + The finite set of candidates as a lazy dataframe. + """ + if not self.is_finite: + raise InfiniteSpaceError( + "Cannot create candidates for an infinite search space." + ) + + if len(self.constraints) >= 1: + validate_constraints(self.constraints, self.parameters) + + candidates_df = build_constrained_product(self.parameters, self.constraints) + # TODO: Remove to lazy once build_constrained_product returns a nw.LazyFrame + return to_lazy_narwhals(candidates_df) + + +@define(frozen=True) +class TableCandidates(CandidatesProtocol): + """Class for managing candidates provided as a table directly.""" + + parameters: tuple[DiscreteParameter, ...] = field( + converter=to_tuple, + validator=[ + min_len(1), + deep_iterable( + member_validator=instance_of(DiscreteParameter), + ), + ], + ) + + dataframe: nw.LazyFrame = field( + validator=instance_of(nw.LazyFrame), + converter=to_lazy_narwhals, + ) + """The dataframe containing the candidates.""" + + def __attrs_post_init__(self): + # TODO: Remove .collect().to_pandas() once we supports validation on lazy frames + validate_parameter_input(self.dataframe.collect().to_pandas(), self.parameters) + + @override + @property + def is_finite(self) -> bool: + return True + + @override + def to_lazy_candidates(self) -> nw.LazyFrame: + """Return the candidates as a lazy dataframe. + + Returns: + The finite set of candidates as a lazy dataframe. + """ + return self.dataframe From 2c724e20ca7d60f41e3b96b94baf5882a857e246 Mon Sep 17 00:00:00 2001 From: Myra Zmarsly Date: Wed, 13 May 2026 17:10:18 +0200 Subject: [PATCH 05/30] Add tests for ProductCandidates and TableCandidates --- tests/test_candidates.py | 268 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 268 insertions(+) create mode 100644 tests/test_candidates.py diff --git a/tests/test_candidates.py b/tests/test_candidates.py new file mode 100644 index 0000000000..01b26d020b --- /dev/null +++ b/tests/test_candidates.py @@ -0,0 +1,268 @@ +"""Tests for the Candidate classes.""" + +import narwhals as nw +import pandas as pd +import polars as pl +import pytest + +from baybe.constraints import DiscreteSumConstraint, ThresholdCondition +from baybe.parameters import ( + CategoricalParameter, + NumericalContinuousParameter, + NumericalDiscreteParameter, +) +from baybe.searchspace.candidates import ProductCandidates, TableCandidates +from baybe.utils.interval import Interval + + +@pytest.mark.parametrize( + "parameter_names", + [ + ["Num_disc_1", "Num_disc_2"], + ["Categorical_1", "Num_disc_1"], + ["Categorical_1", "Categorical_2"], + ], + ids=["numerical", "mixed", "categorical"], +) +@pytest.mark.parametrize( + "dataframe_factory", + [ + pytest.param(lambda pd_df: pd_df, id="pandas_eager"), + pytest.param(pl.DataFrame, id="polars_eager"), + pytest.param( + pl.LazyFrame, + id="polars_lazy", + ), + pytest.param( + nw.from_native, + id="narwhals_eager", + ), + pytest.param( + lambda pd_df: nw.from_native(pd_df).lazy(), + id="narwhals_lazy", + ), + ], +) +@pytest.mark.parametrize( + "batch_size", + [16], + ids=["b16"], +) +def test_table_candidates_creation(parameters, dataframe_factory, fake_measurements): + """TableCandidates can be created with valid parameters and the compatible data.""" + df = dataframe_factory(fake_measurements) + candidates = TableCandidates(parameters=tuple(parameters), dataframe=df) + candidates_ldf = candidates.to_lazy_candidates() + + if isinstance(df, pl.LazyFrame) or isinstance(df, nw.LazyFrame): + df_shape = df.collect().shape + else: + df_shape = df.shape + assert isinstance(candidates_ldf, nw.LazyFrame) + assert all([p.name in candidates_ldf.collect().columns for p in parameters]) + assert candidates_ldf.collect().shape == df_shape + + +@pytest.mark.parametrize( + ("parameters", "dataframe"), + [ + pytest.param([1], pd.DataFrame({"x": [1, 2, 3]}), id="invalid_parameter_input"), + pytest.param( + NumericalDiscreteParameter( + name="Num_disc_1", + values=(1, 2, 7), + tolerance=0.3, + ), + pd.DataFrame({"Num_disc_1": [1, 2, 3]}), + id="parameter_not_a_sequence", + ), + pytest.param( + [ + NumericalDiscreteParameter( + name="Num_disc_1", + values=(1, 2, 7), + tolerance=0.3, + ) + ], + pd.DataFrame({"x": [1, 2, 3]}), + id="unmatched_dataframe_columns", + ), + pytest.param( + NumericalContinuousParameter( + name="Conti_finite1", + bounds=Interval(0, 1), + ), + pd.DataFrame({"x": [1, 2, 3]}), + id="invalid_parameter_type_continuous", + ), + pytest.param([], pd.DataFrame({"x": [1, 2, 3]}), id="empty_parameter"), + pytest.param(None, pd.DataFrame({"x": [1, 2, 3]}), id="none_parameter"), + pytest.param( + [ + NumericalDiscreteParameter( + name="Num_disc_1", values=(1, 2, 7), tolerance=0.3 + ) + ], + 123, # Not a DataFrame or compatible type + id="invalid_dataframe_type", + ), + ], +) +def test_table_candidates_invalid_input(parameters, dataframe): + """Invalid parameter and dataframe inputs raise appropriate errors.""" + with pytest.raises((TypeError, ValueError)): + TableCandidates(parameters=parameters, dataframe=dataframe) + + +@pytest.mark.parametrize( + "parameter_names", + [ + ["Num_disc_1", "Num_disc_2", "Fraction_2"], + ["Categorical_1", "Num_disc_1"], + ["Categorical_1", "Categorical_2", "Categorical_1_subset"], + ], + ids=["numerical", "mixed", "categorical"], +) +@pytest.mark.parametrize( + "constraint_names", + [ + [], + ["DiscreteSumConstraint"], + ["DiscreteExcludeConstraint"], + ], + ids=["no_constraint", "sum", "exclude"], +) +def test_product_candidates_creation(parameters, constraints): + """ProductCandidates can be created with valid parameters and constraints.""" + candidates = ProductCandidates(parameters=parameters, constraints=constraints) + lazy_candidates = candidates.to_lazy_candidates() + assert isinstance(lazy_candidates, nw.LazyFrame) + for p in parameters: + assert p.name in lazy_candidates.columns + assert candidates.is_finite + assert len(lazy_candidates.collect()) + + ProductCandidates(parameters=parameters, constraints=None) + + +@pytest.mark.parametrize( + ("parameters", "constraints"), + [ + pytest.param([1], ["DiscreteSumConstraint"], id="invalid_parameter_input"), + pytest.param( + NumericalContinuousParameter( + name="Conti_finite1", + bounds=Interval(0, 1), + ), + ["DiscreteSumConstraint"], + id="invalid_parameter_type", + ), + pytest.param([], ["DiscreteSumConstraint"], id="empty_parameter"), + pytest.param(None, ["DiscreteSumConstraint"], id="none_parameter"), + ], +) +def test_product_candidates_invalid_input(parameters, constraints): + """Invalid parameter and constraint inputs raise appropriate errors.""" + with pytest.raises((TypeError, ValueError)): + ProductCandidates(parameters=parameters, constraints=constraints) + + +@pytest.mark.parametrize( + "parameters,expected", + [ + pytest.param( + ( + NumericalDiscreteParameter(name="x", values=(1, 2)), + NumericalDiscreteParameter(name="y", values=(10, 20, 30)), + ), + {(x, y) for x in (1, 2) for y in (10, 20, 30)}, + id="numerical_numerical", + ), + pytest.param( + ( + NumericalDiscreteParameter(name="x", values=(1, 2)), + CategoricalParameter(name="cat", values=("a", "b")), + ), + {(x, c) for x in (1, 2) for c in ("a", "b")}, + id="numerical_categorical", + ), + pytest.param( + ( + CategoricalParameter(name="cat1", values=("a", "b")), + CategoricalParameter(name="cat2", values=("c", "d")), + CategoricalParameter(name="cat3", values=("e", "f", "g")), + ), + { + (c1, c2, c3) + for c1 in ("a", "b") + for c2 in ("c", "d") + for c3 in ("e", "f", "g") + }, + id="categorical_categorical", + ), + ], +) +def test_product_candidates_cartesian_product(parameters, expected): + """ProductCandidates builds the correct cartesian product.""" + candidates = ProductCandidates(parameters=parameters) + df = candidates.to_lazy_candidates().collect() + assert df.shape[0] == len(expected) + actual = {tuple(row) for row in df[[p.name for p in parameters]].to_numpy()} + assert actual == expected + + +@pytest.mark.parametrize( + ("parameters", "constraints", "expected_combinations"), + [ + pytest.param( + ( + NumericalDiscreteParameter(name="A", values=(1, 2)), + NumericalDiscreteParameter(name="B", values=(1, 2)), + ), + [ + DiscreteSumConstraint( + parameters=["A", "B"], + condition=ThresholdCondition(threshold=3, operator="="), + ) + ], + {(1, 2), (2, 1)}, + id="sum_equals_3", + ), + pytest.param( + ( + NumericalDiscreteParameter(name="A", values=(1, 2, 3)), + NumericalDiscreteParameter(name="B", values=(1, 2, 3)), + NumericalDiscreteParameter(name="C", values=(1, 2, 3)), + ), + [ + DiscreteSumConstraint( + parameters=["A", "B", "C"], + condition=ThresholdCondition(threshold=6, operator="<"), + ), + DiscreteSumConstraint( + parameters=["A", "B", "C"], + condition=ThresholdCondition(threshold=4, operator=">="), + ), + ], + { + (1, 2, 1), + (2, 1, 1), + (1, 1, 2), + (2, 2, 1), + (2, 1, 2), + (1, 2, 2), + (3, 1, 1), + (1, 3, 1), + (1, 1, 3), + }, + id="sum_between_4_and_6", + ), + ], +) +def test_constraints_product_candidates(parameters, constraints, expected_combinations): + """The constraints are applied correctly in ProductCandidates.to_lazy_candidates.""" + p_names = [p.name for p in parameters] + candidates = ProductCandidates(parameters=parameters, constraints=constraints) + df = candidates.to_lazy_candidates().collect() + assert {tuple(row) for row in df[p_names].to_numpy()} == expected_combinations + assert df.shape[0] == len(expected_combinations) From 22aa41b6506331ca598a6b64488946edf4bf0d8b Mon Sep 17 00:00:00 2001 From: Myra Zmarsly Date: Wed, 13 May 2026 20:44:59 +0200 Subject: [PATCH 06/30] Add parameter name check to Candidates --- baybe/searchspace/candidates.py | 11 ++++++++--- tests/test_candidates.py | 4 ++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py index aee1f3ef1d..7e956c135e 100644 --- a/baybe/searchspace/candidates.py +++ b/baybe/searchspace/candidates.py @@ -11,7 +11,9 @@ from baybe.constraints.base import DiscreteConstraint from baybe.exceptions import InfiniteSpaceError from baybe.parameters.base import DiscreteParameter +from baybe.parameters.utils import sort_parameters from baybe.searchspace.utils import build_constrained_product +from baybe.searchspace.validation import validate_parameter_names from baybe.utils.basic import to_tuple from baybe.utils.dataframe import to_lazy_narwhals from baybe.utils.validation import validate_parameter_input @@ -21,8 +23,9 @@ class CandidatesProtocol(Protocol): """Type protocol specifying the interface for Candidates to implement.""" parameters: tuple[DiscreteParameter, ...] = field( - converter=to_tuple, + converter=sort_parameters, validator=[ + lambda _, __, x: validate_parameter_names(x), min_len(1), deep_iterable( member_validator=instance_of(DiscreteParameter), @@ -58,8 +61,9 @@ class ProductCandidates(CandidatesProtocol): """ parameters: tuple[DiscreteParameter, ...] = field( - converter=to_tuple, + converter=sort_parameters, validator=[ + lambda _, __, x: validate_parameter_names(x), min_len(1), deep_iterable( member_validator=instance_of(DiscreteParameter), @@ -125,8 +129,9 @@ class TableCandidates(CandidatesProtocol): """Class for managing candidates provided as a table directly.""" parameters: tuple[DiscreteParameter, ...] = field( - converter=to_tuple, + converter=sort_parameters, validator=[ + lambda _, __, x: validate_parameter_names(x), min_len(1), deep_iterable( member_validator=instance_of(DiscreteParameter), diff --git a/tests/test_candidates.py b/tests/test_candidates.py index 01b26d020b..d580aa45f8 100644 --- a/tests/test_candidates.py +++ b/tests/test_candidates.py @@ -110,7 +110,7 @@ def test_table_candidates_creation(parameters, dataframe_factory, fake_measureme ) def test_table_candidates_invalid_input(parameters, dataframe): """Invalid parameter and dataframe inputs raise appropriate errors.""" - with pytest.raises((TypeError, ValueError)): + with pytest.raises((TypeError, ValueError, AttributeError)): TableCandidates(parameters=parameters, dataframe=dataframe) @@ -163,7 +163,7 @@ def test_product_candidates_creation(parameters, constraints): ) def test_product_candidates_invalid_input(parameters, constraints): """Invalid parameter and constraint inputs raise appropriate errors.""" - with pytest.raises((TypeError, ValueError)): + with pytest.raises((TypeError, ValueError, AttributeError)): ProductCandidates(parameters=parameters, constraints=constraints) From 92af0ca4acc800e137e25b87c71ec00effa1d07d Mon Sep 17 00:00:00 2001 From: Myra Zmarsly Date: Wed, 13 May 2026 20:53:14 +0200 Subject: [PATCH 07/30] Update CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e669f82e0..67e22a31c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Coding convention instructions for agentic developers (`AGENTS.md`, `CLAUDE.md`) - `has_polars_implementation` property on `DiscreteConstraint` - `allow_missing` flag on `DiscreteConstraint.get_invalid` and `get_valid` +- `polars[pyarrow]` and `narwhals` as hard dependencies +- `CandidateProtocol` as a base protocol for candidates handling +- `TableCandidates` and `ProductCandidates` classes implementing `CandidateProtocol` ### Breaking Changes - `parameter_cartesian_prod_pandas` and `parameter_cartesian_prod_polars` moved From fc19bc49865d3b299e6025f1e72ed78ea7559e67 Mon Sep 17 00:00:00 2001 From: Myra Zmarsly Date: Wed, 13 May 2026 21:01:27 +0200 Subject: [PATCH 08/30] Fix typo --- baybe/searchspace/candidates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py index 7e956c135e..56774e6d56 100644 --- a/baybe/searchspace/candidates.py +++ b/baybe/searchspace/candidates.py @@ -146,7 +146,7 @@ class TableCandidates(CandidatesProtocol): """The dataframe containing the candidates.""" def __attrs_post_init__(self): - # TODO: Remove .collect().to_pandas() once we supports validation on lazy frames + # TODO: Remove collect().to_pandas() once validation on lazy frames is supported validate_parameter_input(self.dataframe.collect().to_pandas(), self.parameters) @override From 4a43d8b673607eb8e7b5a752961b6c8fbf8a5cb6 Mon Sep 17 00:00:00 2001 From: Myra Date: Wed, 13 May 2026 21:10:03 +0200 Subject: [PATCH 09/30] Remove attribute validation from CandidateProtocol Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- baybe/searchspace/candidates.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py index 56774e6d56..a8899ffa76 100644 --- a/baybe/searchspace/candidates.py +++ b/baybe/searchspace/candidates.py @@ -22,16 +22,7 @@ class CandidatesProtocol(Protocol): """Type protocol specifying the interface for Candidates to implement.""" - parameters: tuple[DiscreteParameter, ...] = field( - converter=sort_parameters, - validator=[ - lambda _, __, x: validate_parameter_names(x), - min_len(1), - deep_iterable( - member_validator=instance_of(DiscreteParameter), - ), - ], - ) + parameters: tuple[DiscreteParameter, ...] """ The parameters that define the search space for which candidates are generated. """ From eacd434b87439d7d00ed5049a99cd0f6bfac6d57 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 10:02:41 +0200 Subject: [PATCH 10/30] Fix terminology: enumerable -> finite --- baybe/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/exceptions.py b/baybe/exceptions.py index 5221ba479a..c049d414b8 100644 --- a/baybe/exceptions.py +++ b/baybe/exceptions.py @@ -180,7 +180,7 @@ class UnsupportedEarlyFilteringError(Exception): class InfiniteSpaceError(Exception): - """An operation requires an enumerable search space, but the space is infinite.""" + """An operation requires a finite search space but the space is infinite.""" # Collect leftover original slotted classes processed by `attrs.define` From 6cb0e34188c6c75225fd232afa586ffc3d5095a3 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 10:06:03 +0200 Subject: [PATCH 11/30] Remove polars as hard dependency --- CHANGELOG.md | 2 +- pyproject.toml | 1 - uv.lock | 4 +--- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 67e22a31c5..e5e2bc7a2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,7 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Coding convention instructions for agentic developers (`AGENTS.md`, `CLAUDE.md`) - `has_polars_implementation` property on `DiscreteConstraint` - `allow_missing` flag on `DiscreteConstraint.get_invalid` and `get_valid` -- `polars[pyarrow]` and `narwhals` as hard dependencies +- `narwhals` as hard dependencies - `CandidateProtocol` as a base protocol for candidates handling - `TableCandidates` and `ProductCandidates` classes implementing `CandidateProtocol` diff --git a/pyproject.toml b/pyproject.toml index 8e5729caeb..8e9f42f744 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,6 @@ dependencies = [ "narwhals>2.15.0", "numpy>=1.24.1,<3", "pandas>=1.4.2,<3", - "polars[pyarrow]>=0.19.19,<2", "scikit-learn>=1.1.1,<2", "scipy>=1.10.1", "torch>=1.13.1,<3", diff --git a/uv.lock b/uv.lock index 97ae803590..42160d77a9 100644 --- a/uv.lock +++ b/uv.lock @@ -10,7 +10,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-05-06T14:42:46.862591Z" +exclude-newer = "2026-05-14T08:04:57.929429Z" exclude-newer-span = "P7D" [[package]] @@ -211,7 +211,6 @@ dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pandas" }, - { name = "polars", extra = ["pyarrow"] }, { name = "scikit-learn", version = "1.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scikit-learn", version = "1.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, @@ -442,7 +441,6 @@ requires-dist = [ { name = "pillow", marker = "extra == 'examples'", specifier = ">=10.0.1" }, { name = "pip-audit", marker = "extra == 'dev'", specifier = ">=2.5.5" }, { name = "plotly", marker = "extra == 'examples'", specifier = ">=5.10.0" }, - { name = "polars", extras = ["pyarrow"], specifier = ">=0.19.19,<2" }, { name = "polars", extras = ["pyarrow"], marker = "extra == 'polars'", specifier = ">=0.19.19,<2" }, { name = "pre-commit", marker = "extra == 'lint'", specifier = "==4.2.0" }, { name = "psutil", marker = "extra == 'benchmarking'", specifier = ">=7.0.0" }, From 8be86473ea47a14c97e2a77fcf21f0ba68303775 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 10:07:54 +0200 Subject: [PATCH 12/30] Adjust narwhals version constraints --- pyproject.toml | 2 +- uv.lock | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8e9f42f744..01f2826e5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "exceptiongroup", "gpytorch>=1.9.1,<2", "joblib>1.4.0,<2", - "narwhals>2.15.0", + "narwhals>=2,<3", "numpy>=1.24.1,<3", "pandas>=1.4.2,<3", "scikit-learn>=1.1.1,<2", diff --git a/uv.lock b/uv.lock index 42160d77a9..fec55f622c 100644 --- a/uv.lock +++ b/uv.lock @@ -10,7 +10,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-05-14T08:04:57.929429Z" +exclude-newer = "2026-05-14T08:07:50.351631Z" exclude-newer-span = "P7D" [[package]] @@ -429,7 +429,7 @@ requires-dist = [ { name = "matplotlib", marker = "extra == 'examples'", specifier = ">=3.7.3,!=3.9.1" }, { name = "mypy", marker = "extra == 'mypy'", specifier = ">=1.19.1" }, { name = "myst-parser", marker = "extra == 'docs'", specifier = ">=4.0.0" }, - { name = "narwhals", specifier = ">2.15.0" }, + { name = "narwhals", specifier = ">=2,<3" }, { name = "ngboost", marker = "extra == 'extras'", specifier = ">=0.3.12,<1" }, { name = "numpy", specifier = ">=1.24.1,<3" }, { name = "onnx", marker = "extra == 'onnx'", specifier = ">=1.16.0" }, From 2eacc751db5956c82ec3db941640c46d82603d2a Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 10:10:59 +0200 Subject: [PATCH 13/30] Use narwhals stable.v2 namespace --- baybe/searchspace/candidates.py | 2 +- baybe/utils/dataframe.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py index a8899ffa76..2b2f414886 100644 --- a/baybe/searchspace/candidates.py +++ b/baybe/searchspace/candidates.py @@ -2,7 +2,7 @@ from typing import Protocol -import narwhals as nw +import narwhals.stable.v2 as nw from attr.validators import deep_iterable, instance_of, min_len from attrs import define, field from typing_extensions import override diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py index 4a21ebd4c6..88a9efa016 100644 --- a/baybe/utils/dataframe.py +++ b/baybe/utils/dataframe.py @@ -7,10 +7,10 @@ from collections.abc import Callable, Collection, Iterable, Sequence from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload -import narwhals as nw +import narwhals.stable.v2 as nw import numpy as np import pandas as pd -from narwhals.typing import IntoDataFrame +from narwhals.stable.v2.typing import IntoDataFrame from typing_extensions import assert_never from baybe.exceptions import InputDataTypeWarning, SearchSpaceMatchWarning From 053c16a00fb3a57b88641dd33b537408d5a1b73f Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 10:18:43 +0200 Subject: [PATCH 14/30] Turn protocol (class) attribute into property --- baybe/searchspace/candidates.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py index 2b2f414886..caec012f98 100644 --- a/baybe/searchspace/candidates.py +++ b/baybe/searchspace/candidates.py @@ -22,10 +22,9 @@ class CandidatesProtocol(Protocol): """Type protocol specifying the interface for Candidates to implement.""" - parameters: tuple[DiscreteParameter, ...] - """ - The parameters that define the search space for which candidates are generated. - """ + @property + def parameters(self) -> tuple[DiscreteParameter, ...]: + """The parameters that define the search space for which candidates are generated.""" # noqa: E501 @property def is_finite(self) -> bool: From 05b3ad3fd22f2d7e0f9c7850a37e4e817a1f3068 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 10:57:02 +0200 Subject: [PATCH 15/30] Refine attrs coding conventions in AGENTS.md --- AGENTS.md | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index d44c1cc9ad..5e11c557bd 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -58,10 +58,9 @@ More specific conventions for subdirectories: ### attrs Only All domain classes use `attrs` `@define`. No dataclasses, no Pydantic. - Immutable value objects (parameters, kernels, priors, transformations, objectives, - targets): `@define(frozen=True, slots=False)`. + targets): `@define(frozen=True)`. - Mutable stateful objects (campaign, surrogates, recommenders): `@define`. -- `slots=False` required with `frozen=True` when `cached_property` is needed. See - `attrs` issue #164 +- `slots=False` required `cached_property` is needed. See `attrs` issue #164 - Also use `slots=False` when monkeypatching is needed (e.g., `register_hooks`) ### Inheritance: ABC + SerialMixin + Protocol @@ -71,14 +70,25 @@ All domain classes use `attrs` `@define`. No dataclasses, no Pydantic. 3. Concrete classes: Inherit from ABC. ### Fields and Methods -- Use `field()` with `validator=`, `converter=`, `default=`, `factory=`, `alias=`. +- Use `field()` with arguments in this order: 1) `alias=` (if needed), 2) `init=` + (if needed), 3) `default=` / `factory=`, 4) `converter=`, 5) `validator=`. - Private fields: `_` prefix, typically `init=False`. - Store each piece of information once — no data duplication. - Use `attrs.evolve()` for modified copies of frozen objects. - Use `on_setattr` hooks for cache invalidation on mutable objects. +- Use `kw_only=True` deliberately: only when positional construction would be + ambiguous or error-prone (e.g., multiple fields of the same type, or + optional/secondary fields that should not be passed positionally). Do not + apply `kw_only` to all fields by default. - `ClassVar[bool]` for capability flags (`supports_transfer_learning`, etc.). -- Order class content like this: 1) Attributes, 2) validators and post_init, 3) - properties, 4) methods. Within each group use alphabetical order. +- Order class content like this: 1) Attributes, 2) default and validator methods, + 3) `__attrs_post_init__`, 4) properties, 5) methods. + - Attributes are ordered by functionality/importance (primary identity fields + first, optional/secondary fields last), not alphabetically. + - Default and validator methods mirror the attribute order. For a given + attribute, the default method (`_default_`) comes before its validator + (`_validate_`). + - Regular methods are ordered alphabetically. ### Attribute Docstrings String literals immediately below field declarations, blank lines between attributes. @@ -225,11 +235,24 @@ Three tiers: ## 11. Validation Patterns - Inline validators: `field(validator=(instance_of(str), min_len(1)))`, `in_()`, - `deep_iterable()`, custom `finite_float`, `gt()`. + `deep_iterable()`, custom `finite_float`, `gt()`. Order validators from simplest + to most complex: cheap structural checks (e.g., `min_len`, `instance_of`) before + expensive semantic checks (e.g., cross-field consistency, name uniqueness). - Method validators: `@_field.validator` with `# noqa: DOC101, DOC103` for validators needing `self` access. -- Cross-field: `__attrs_post_init__` when validation involves multiple fields. +- Cross-field: `__attrs_post_init__` is a last resort. Method validators + (`@field.validator`) already receive `self` and can read other already-set + attributes, so most cross-field checks belong there instead. When one field + must be compatible with another, attach the validator to the later field — + attrs sets fields in declaration order, so earlier fields are always available + via `self` at that point. When one attribute's value must be adjusted after + all fields are set — which is typically a workaround and should itself be + questioned — `__attrs_post_init__` is acceptable. - Converters: `field(converter=to_searchspace)` for automatic type coercion. +- If a converter already guarantees a specific type (e.g., `converter=list` + always produces a `list`, a custom converter always returns a known type), + omit any `instance_of(...)` validator for that same type — the check is + redundant. - Reusable validators in `baybe/utils/validation.py`: `finite_float`, `non_nan_float`, `non_inf_float`, `validate_not_nan`, `validate_target_input`, `validate_parameter_input`, `validate_object_names`. From 831aafec7477f51e1bcea6c7dd2bc4d27f0cb868 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 11:02:45 +0200 Subject: [PATCH 16/30] Add missing garbage collection step --- baybe/searchspace/candidates.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py index caec012f98..29485daa6f 100644 --- a/baybe/searchspace/candidates.py +++ b/baybe/searchspace/candidates.py @@ -1,5 +1,6 @@ """Candidates module for managing the lazy candidate generation.""" +import gc from typing import Protocol import narwhals.stable.v2 as nw @@ -152,3 +153,7 @@ def to_lazy_candidates(self) -> nw.LazyFrame: The finite set of candidates as a lazy dataframe. """ return self.dataframe + + +# Collect leftover original slotted classes processed by `attrs.define` +gc.collect() From a4abf8fe7d47d1926ba84c263d67c152e9aea47b Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 11:04:31 +0200 Subject: [PATCH 17/30] Fix attribute definitions --- baybe/searchspace/candidates.py | 36 +++++++++------------------------ tests/test_candidates.py | 2 +- 2 files changed, 11 insertions(+), 27 deletions(-) diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py index 29485daa6f..32cd6ffa3d 100644 --- a/baybe/searchspace/candidates.py +++ b/baybe/searchspace/candidates.py @@ -54,31 +54,20 @@ class ProductCandidates(CandidatesProtocol): parameters: tuple[DiscreteParameter, ...] = field( converter=sort_parameters, validator=[ - lambda _, __, x: validate_parameter_names(x), min_len(1), - deep_iterable( - member_validator=instance_of(DiscreteParameter), - ), + deep_iterable(member_validator=instance_of(DiscreteParameter)), + lambda _, __, x: validate_parameter_names(x), ], ) constraints: tuple[DiscreteConstraint, ...] = field( - converter=lambda x: ( - to_tuple( - sorted( - x, - key=lambda c: DISCRETE_CONSTRAINTS_FILTERING_ORDER.index( - c.__class__ - ), - ) + default=(), + converter=lambda x: to_tuple( + sorted( + x, key=lambda c: DISCRETE_CONSTRAINTS_FILTERING_ORDER.index(c.__class__) ) - if x is not None - else () - ), - factory=tuple, - validator=deep_iterable( - member_validator=instance_of(DiscreteConstraint), ), + validator=deep_iterable(member_validator=instance_of(DiscreteConstraint)), ) """The constraints to apply to the cartesian product of the parameter values.""" @@ -122,18 +111,13 @@ class TableCandidates(CandidatesProtocol): parameters: tuple[DiscreteParameter, ...] = field( converter=sort_parameters, validator=[ - lambda _, __, x: validate_parameter_names(x), min_len(1), - deep_iterable( - member_validator=instance_of(DiscreteParameter), - ), + deep_iterable(member_validator=instance_of(DiscreteParameter)), + lambda _, __, x: validate_parameter_names(x), ], ) - dataframe: nw.LazyFrame = field( - validator=instance_of(nw.LazyFrame), - converter=to_lazy_narwhals, - ) + dataframe: nw.LazyFrame = field(converter=to_lazy_narwhals) """The dataframe containing the candidates.""" def __attrs_post_init__(self): diff --git a/tests/test_candidates.py b/tests/test_candidates.py index d580aa45f8..b34ad3c9a6 100644 --- a/tests/test_candidates.py +++ b/tests/test_candidates.py @@ -142,7 +142,7 @@ def test_product_candidates_creation(parameters, constraints): assert candidates.is_finite assert len(lazy_candidates.collect()) - ProductCandidates(parameters=parameters, constraints=None) + ProductCandidates(parameters=parameters) @pytest.mark.parametrize( From dda34fff20fcfc9bd57937cc29de67db529d971e Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 11:15:56 +0200 Subject: [PATCH 18/30] Drop unnecessary __attrs_post_init__ --- baybe/searchspace/candidates.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py index 32cd6ffa3d..49eb220e9b 100644 --- a/baybe/searchspace/candidates.py +++ b/baybe/searchspace/candidates.py @@ -5,7 +5,7 @@ import narwhals.stable.v2 as nw from attr.validators import deep_iterable, instance_of, min_len -from attrs import define, field +from attrs import Attribute, define, field from typing_extensions import override from baybe.constraints import DISCRETE_CONSTRAINTS_FILTERING_ORDER, validate_constraints @@ -120,13 +120,15 @@ class TableCandidates(CandidatesProtocol): dataframe: nw.LazyFrame = field(converter=to_lazy_narwhals) """The dataframe containing the candidates.""" - def __attrs_post_init__(self): + @dataframe.validator + def _validate_dataframe(self, _: Attribute, value: nw.LazyFrame) -> None: # noqa: DOC101, DOC103 # TODO: Remove collect().to_pandas() once validation on lazy frames is supported - validate_parameter_input(self.dataframe.collect().to_pandas(), self.parameters) + validate_parameter_input(value.collect().to_pandas(), self.parameters) @override @property def is_finite(self) -> bool: + """Whether the candidate set is finite.""" return True @override From adabdd4077da544dff1267c8a6ce585af0792b2c Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 11:26:21 +0200 Subject: [PATCH 19/30] Rework candidate module docstrings --- baybe/searchspace/candidates.py | 49 ++++++++------------------------- 1 file changed, 11 insertions(+), 38 deletions(-) diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py index 49eb220e9b..500d37a6b7 100644 --- a/baybe/searchspace/candidates.py +++ b/baybe/searchspace/candidates.py @@ -1,4 +1,4 @@ -"""Candidates module for managing the lazy candidate generation.""" +"""Candidates module for managing lazy candidate generation.""" import gc from typing import Protocol @@ -21,35 +21,23 @@ class CandidatesProtocol(Protocol): - """Type protocol specifying the interface for Candidates to implement.""" + """Type protocol specifying the interface candidate generators need to implement.""" @property def parameters(self) -> tuple[DiscreteParameter, ...]: - """The parameters that define the search space for which candidates are generated.""" # noqa: E501 + """The parameters spanning the space from which candidates are generated.""" @property def is_finite(self) -> bool: - """Define whether the candidate set is finite or infinite. - - Returns: - Whether the candidate set is finite. - """ + """Indicates whether the candidate set is finite or infinite.""" def to_lazy_candidates(self) -> nw.LazyFrame: - """Generate the candidates from the given parameters and constraints. - - Returns: - The candidates as a lazy dataframe. - """ + """Generate all candidates.""" @define(frozen=True) class ProductCandidates(CandidatesProtocol): - """Class for managing product candidates. - - The candidates are generated by calculating the cartesian product of the parameter - values and applying constraints if available. - """ + """Class for managing candidates from (filtered) Cartesian product spaces.""" parameters: tuple[DiscreteParameter, ...] = field( converter=sort_parameters, @@ -59,6 +47,7 @@ class ProductCandidates(CandidatesProtocol): lambda _, __, x: validate_parameter_names(x), ], ) + """See :attr:`CandidatesProtocol.parameters`.""" constraints: tuple[DiscreteConstraint, ...] = field( default=(), @@ -69,7 +58,7 @@ class ProductCandidates(CandidatesProtocol): ), validator=deep_iterable(member_validator=instance_of(DiscreteConstraint)), ) - """The constraints to apply to the cartesian product of the parameter values.""" + """Constraints to filter the Cartesian product of parameter values.""" @override @property @@ -80,20 +69,9 @@ def is_finite(self) -> bool: @override def to_lazy_candidates(self) -> nw.LazyFrame: - """Create a lazy data frame to represent candidates. - - Calculate the cartesian product of the parameter values and apply the - constraints. - - Raises: - InfiniteSpaceError: If the search space is infinite. - - Returns: - The finite set of candidates as a lazy dataframe. - """ if not self.is_finite: raise InfiniteSpaceError( - "Cannot create candidates for an infinite search space." + "Cannot generate all candidates from an infinite space." ) if len(self.constraints) >= 1: @@ -106,7 +84,7 @@ def to_lazy_candidates(self) -> nw.LazyFrame: @define(frozen=True) class TableCandidates(CandidatesProtocol): - """Class for managing candidates provided as a table directly.""" + """Class for managing candidates provided in a tabular format.""" parameters: tuple[DiscreteParameter, ...] = field( converter=sort_parameters, @@ -116,6 +94,7 @@ class TableCandidates(CandidatesProtocol): lambda _, __, x: validate_parameter_names(x), ], ) + """See :attr:`CandidatesProtocol.parameters`.""" dataframe: nw.LazyFrame = field(converter=to_lazy_narwhals) """The dataframe containing the candidates.""" @@ -128,16 +107,10 @@ def _validate_dataframe(self, _: Attribute, value: nw.LazyFrame) -> None: # noq @override @property def is_finite(self) -> bool: - """Whether the candidate set is finite.""" return True @override def to_lazy_candidates(self) -> nw.LazyFrame: - """Return the candidates as a lazy dataframe. - - Returns: - The finite set of candidates as a lazy dataframe. - """ return self.dataframe From 2a7f4b9c81202f541716711ea7b97030d724b52d Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 11:34:32 +0200 Subject: [PATCH 20/30] Turn delayed validation into eager validation --- baybe/searchspace/candidates.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py index 500d37a6b7..f24eef7f8b 100644 --- a/baybe/searchspace/candidates.py +++ b/baybe/searchspace/candidates.py @@ -60,6 +60,12 @@ class ProductCandidates(CandidatesProtocol): ) """Constraints to filter the Cartesian product of parameter values.""" + @constraints.validator + def _validate_constraints( + self, _: Attribute, value: tuple[DiscreteConstraint, ...] + ): # noqa: DOC101, DOC103 + validate_constraints(value, self.parameters) + @override @property def is_finite(self) -> bool: @@ -74,9 +80,6 @@ def to_lazy_candidates(self) -> nw.LazyFrame: "Cannot generate all candidates from an infinite space." ) - if len(self.constraints) >= 1: - validate_constraints(self.constraints, self.parameters) - candidates_df = build_constrained_product(self.parameters, self.constraints) # TODO: Remove to lazy once build_constrained_product returns a nw.LazyFrame return to_lazy_narwhals(candidates_df) From c5ebbba0a013e629cbb3fa31211fea391c78cb8e Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 11:41:46 +0200 Subject: [PATCH 21/30] Add DiscreteParameter.is_finite --- CHANGELOG.md | 1 + baybe/parameters/base.py | 6 ++++++ baybe/searchspace/candidates.py | 4 +--- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e5e2bc7a2e..a85e33bc54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `narwhals` as hard dependencies - `CandidateProtocol` as a base protocol for candidates handling - `TableCandidates` and `ProductCandidates` classes implementing `CandidateProtocol` +- `DiscreteParameter.is_finite` property ### Breaking Changes - `parameter_cartesian_prod_pandas` and `parameter_cartesian_prod_polars` moved diff --git a/baybe/parameters/base.py b/baybe/parameters/base.py index 2d4df2bc77..8b09a90435 100644 --- a/baybe/parameters/base.py +++ b/baybe/parameters/base.py @@ -116,6 +116,12 @@ class DiscreteParameter(Parameter, ABC): def values(self) -> tuple: """The values the parameter can take.""" + @property + def is_finite(self) -> bool: + """Indicates whether the parameter has a finite number of values.""" + len(self.values) # <-- raises an error if the parameter is infinite + return True + @property def active_values(self) -> tuple: """The values that are considered for recommendation.""" diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py index f24eef7f8b..ac76b1c759 100644 --- a/baybe/searchspace/candidates.py +++ b/baybe/searchspace/candidates.py @@ -69,9 +69,7 @@ def _validate_constraints( @override @property def is_finite(self) -> bool: - # TODO: Need a property for each DiscreteParameter to check if it's finite then - # replace here with: return all(p.is_finite for p in self.parameters) - return True + return all(p.is_finite for p in self.parameters) @override def to_lazy_candidates(self) -> nw.LazyFrame: From 4605151fa34a21b23ac3985ac2354bcf2e8e54ef Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 13:15:31 +0200 Subject: [PATCH 22/30] Drop unused helper function --- baybe/utils/dataframe.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py index 88a9efa016..263dadf808 100644 --- a/baybe/utils/dataframe.py +++ b/baybe/utils/dataframe.py @@ -795,17 +795,3 @@ def to_lazy_narwhals( A lazy dataframe in narwhals format. """ return nw.from_native(df).lazy() - - -def from_lazy_narwhals( - ldf: nw.LazyFrame, -) -> IntoDataFrame: - """Convert a lazy dataframe to its native dataframe. - - Args: - ldf: A lazy dataframe - - Returns: - A dataframe in native format (e.g. pandas or polars) - """ - return ldf.collect().to_native() From 0a4b9b89b6778af72b07b41b2f7dbe921a1bad36 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 13:55:54 +0200 Subject: [PATCH 23/30] Adjust lazyframe conversion utility --- baybe/searchspace/candidates.py | 8 +++++--- baybe/utils/dataframe.py | 14 ++------------ docs/conf.py | 9 +++++---- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py index ac76b1c759..2685e972a8 100644 --- a/baybe/searchspace/candidates.py +++ b/baybe/searchspace/candidates.py @@ -16,7 +16,7 @@ from baybe.searchspace.utils import build_constrained_product from baybe.searchspace.validation import validate_parameter_names from baybe.utils.basic import to_tuple -from baybe.utils.dataframe import to_lazy_narwhals +from baybe.utils.dataframe import to_lazy from baybe.utils.validation import validate_parameter_input @@ -79,8 +79,10 @@ def to_lazy_candidates(self) -> nw.LazyFrame: ) candidates_df = build_constrained_product(self.parameters, self.constraints) + # TODO: Remove to lazy once build_constrained_product returns a nw.LazyFrame - return to_lazy_narwhals(candidates_df) + assert not isinstance(candidates_df, nw.LazyFrame) + return to_lazy(candidates_df) @define(frozen=True) @@ -97,7 +99,7 @@ class TableCandidates(CandidatesProtocol): ) """See :attr:`CandidatesProtocol.parameters`.""" - dataframe: nw.LazyFrame = field(converter=to_lazy_narwhals) + dataframe: nw.LazyFrame = field(converter=to_lazy) """The dataframe containing the candidates.""" @dataframe.validator diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py index 263dadf808..221e3942dd 100644 --- a/baybe/utils/dataframe.py +++ b/baybe/utils/dataframe.py @@ -782,16 +782,6 @@ def needs_float_dtype(obj) -> bool: return df -def to_lazy_narwhals( - df: IntoDataFrame, -) -> nw.LazyFrame: - """Convert a native dataframe to a lazyframe, if it is not already a lazyframe. - - Args: - df: A dataframe in native format (e.g. pandas or polars) or already in narwhals - lazy format. - - Returns: - A lazy dataframe in narwhals format. - """ +def to_lazy(df: IntoDataFrame, /) -> nw.LazyFrame: + """Convert any dataframe to a :class:`~narwhals.LazyFrame`.""" return nw.from_native(df).lazy() diff --git a/docs/conf.py b/docs/conf.py index dceb732509..747467e7d6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -274,15 +274,16 @@ # Mappings to all external packages that we want to have clickable links to intersphinx_mapping = { "botorch": ("https://botorch.readthedocs.io/en/latest", None), - "python": ("https://docs.python.org/3", None), + "narwhals": ("https://narwhals-dev.github.io/narwhals/", None), + "numpy": ("https://numpy.org/doc/stable/", None), "pandas": ("https://pandas.pydata.org/docs/", None), "polars": ("https://docs.pola.rs/api/python/stable/", None), + "python": ("https://docs.python.org/3", None), + "rdkit": ("https://rdkit.org/docs/", None), + "shap": ("https://shap.readthedocs.io/en/stable/", None), "skfp": ("https://scikit-fingerprints.readthedocs.io/latest/", None), "sklearn": ("https://scikit-learn.org/stable/", None), - "numpy": ("https://numpy.org/doc/stable/", None), "torch": ("https://pytorch.org/docs/main/", None), - "rdkit": ("https://rdkit.org/docs/", None), - "shap": ("https://shap.readthedocs.io/en/stable/", None), "xyzpy": ("https://xyzpy.readthedocs.io/en/latest/", None), } From 10f8118acbe66bb5a2b3bef35f1faaee862ca67e Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 13:56:54 +0200 Subject: [PATCH 24/30] Rename CandidatesProtocol.to_lazy_candidates to to_lazy --- baybe/searchspace/candidates.py | 6 +++--- tests/test_candidates.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py index 2685e972a8..3df07b3aad 100644 --- a/baybe/searchspace/candidates.py +++ b/baybe/searchspace/candidates.py @@ -31,7 +31,7 @@ def parameters(self) -> tuple[DiscreteParameter, ...]: def is_finite(self) -> bool: """Indicates whether the candidate set is finite or infinite.""" - def to_lazy_candidates(self) -> nw.LazyFrame: + def to_lazy(self) -> nw.LazyFrame: """Generate all candidates.""" @@ -72,7 +72,7 @@ def is_finite(self) -> bool: return all(p.is_finite for p in self.parameters) @override - def to_lazy_candidates(self) -> nw.LazyFrame: + def to_lazy(self) -> nw.LazyFrame: if not self.is_finite: raise InfiniteSpaceError( "Cannot generate all candidates from an infinite space." @@ -113,7 +113,7 @@ def is_finite(self) -> bool: return True @override - def to_lazy_candidates(self) -> nw.LazyFrame: + def to_lazy(self) -> nw.LazyFrame: return self.dataframe diff --git a/tests/test_candidates.py b/tests/test_candidates.py index b34ad3c9a6..d180c442b1 100644 --- a/tests/test_candidates.py +++ b/tests/test_candidates.py @@ -52,7 +52,7 @@ def test_table_candidates_creation(parameters, dataframe_factory, fake_measureme """TableCandidates can be created with valid parameters and the compatible data.""" df = dataframe_factory(fake_measurements) candidates = TableCandidates(parameters=tuple(parameters), dataframe=df) - candidates_ldf = candidates.to_lazy_candidates() + candidates_ldf = candidates.to_lazy() if isinstance(df, pl.LazyFrame) or isinstance(df, nw.LazyFrame): df_shape = df.collect().shape @@ -135,7 +135,7 @@ def test_table_candidates_invalid_input(parameters, dataframe): def test_product_candidates_creation(parameters, constraints): """ProductCandidates can be created with valid parameters and constraints.""" candidates = ProductCandidates(parameters=parameters, constraints=constraints) - lazy_candidates = candidates.to_lazy_candidates() + lazy_candidates = candidates.to_lazy() assert isinstance(lazy_candidates, nw.LazyFrame) for p in parameters: assert p.name in lazy_candidates.columns @@ -205,7 +205,7 @@ def test_product_candidates_invalid_input(parameters, constraints): def test_product_candidates_cartesian_product(parameters, expected): """ProductCandidates builds the correct cartesian product.""" candidates = ProductCandidates(parameters=parameters) - df = candidates.to_lazy_candidates().collect() + df = candidates.to_lazy().collect() assert df.shape[0] == len(expected) actual = {tuple(row) for row in df[[p.name for p in parameters]].to_numpy()} assert actual == expected @@ -263,6 +263,6 @@ def test_constraints_product_candidates(parameters, constraints, expected_combin """The constraints are applied correctly in ProductCandidates.to_lazy_candidates.""" p_names = [p.name for p in parameters] candidates = ProductCandidates(parameters=parameters, constraints=constraints) - df = candidates.to_lazy_candidates().collect() + df = candidates.to_lazy().collect() assert {tuple(row) for row in df[p_names].to_numpy()} == expected_combinations assert df.shape[0] == len(expected_combinations) From 17fc0991b050f4fa67ddb350e038c75ecd76d45e Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 14:00:27 +0200 Subject: [PATCH 25/30] Fix changelog --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a85e33bc54..c9c76b9dcd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,8 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `has_polars_implementation` property on `DiscreteConstraint` - `allow_missing` flag on `DiscreteConstraint.get_invalid` and `get_valid` - `narwhals` as hard dependencies -- `CandidateProtocol` as a base protocol for candidates handling -- `TableCandidates` and `ProductCandidates` classes implementing `CandidateProtocol` +- `CandidatesProtocol` as an interface for candidates generation +- `TableCandidates` and `ProductCandidates` classes implementing `CandidatesProtocol` - `DiscreteParameter.is_finite` property ### Breaking Changes From 41580f0d19aa111dcff1b877e1b0ab332441d52e Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 14:47:24 +0200 Subject: [PATCH 26/30] Refactor table candidates tests --- tests/test_candidates.py | 122 ++++++++++++++------------------------- 1 file changed, 42 insertions(+), 80 deletions(-) diff --git a/tests/test_candidates.py b/tests/test_candidates.py index d180c442b1..61929e75f2 100644 --- a/tests/test_candidates.py +++ b/tests/test_candidates.py @@ -1,9 +1,10 @@ -"""Tests for the Candidate classes.""" +"""Tests for candidate generators.""" import narwhals as nw import pandas as pd import polars as pl import pytest +from pandas.testing import assert_frame_equal from baybe.constraints import DiscreteSumConstraint, ThresholdCondition from baybe.parameters import ( @@ -12,106 +13,67 @@ NumericalDiscreteParameter, ) from baybe.searchspace.candidates import ProductCandidates, TableCandidates +from baybe.utils.dataframe import create_fake_input from baybe.utils.interval import Interval +p_disc = NumericalDiscreteParameter("disc", (1, 2, 7)) +p_cat = CategoricalParameter("cat", ("a", "b", "c")) +p_cont = NumericalContinuousParameter("cont", (3, 8)) +edf = pd.DataFrame() + -@pytest.mark.parametrize( - "parameter_names", - [ - ["Num_disc_1", "Num_disc_2"], - ["Categorical_1", "Num_disc_1"], - ["Categorical_1", "Categorical_2"], - ], - ids=["numerical", "mixed", "categorical"], -) @pytest.mark.parametrize( "dataframe_factory", [ pytest.param(lambda pd_df: pd_df, id="pandas_eager"), pytest.param(pl.DataFrame, id="polars_eager"), - pytest.param( - pl.LazyFrame, - id="polars_lazy", - ), - pytest.param( - nw.from_native, - id="narwhals_eager", - ), - pytest.param( - lambda pd_df: nw.from_native(pd_df).lazy(), - id="narwhals_lazy", - ), + pytest.param(pl.LazyFrame, id="polars_lazy"), + pytest.param(lambda x: nw.from_native(x, eager_only=True), id="narwhals_eager"), + pytest.param(lambda x: nw.from_native(x).lazy(), id="narwhals_lazy"), ], ) -@pytest.mark.parametrize( - "batch_size", - [16], - ids=["b16"], -) -def test_table_candidates_creation(parameters, dataframe_factory, fake_measurements): - """TableCandidates can be created with valid parameters and the compatible data.""" - df = dataframe_factory(fake_measurements) - candidates = TableCandidates(parameters=tuple(parameters), dataframe=df) +def test_table_candidates_generation(dataframe_factory): + """TableCandidates generates the expected lazy dataframe.""" + parameters = [p_disc, p_cat] + data = create_fake_input(parameters, [], n_rows=4) + df = dataframe_factory(data) + candidates = TableCandidates(parameters, df) candidates_ldf = candidates.to_lazy() + candidates_df = candidates_ldf.collect() - if isinstance(df, pl.LazyFrame) or isinstance(df, nw.LazyFrame): - df_shape = df.collect().shape - else: - df_shape = df.shape + assert candidates.is_finite assert isinstance(candidates_ldf, nw.LazyFrame) - assert all([p.name in candidates_ldf.collect().columns for p in parameters]) - assert candidates_ldf.collect().shape == df_shape + assert set(candidates_df.columns) == {p.name for p in parameters} + assert candidates_df.shape == data.shape + assert_frame_equal(candidates_df.to_pandas(), data) @pytest.mark.parametrize( - ("parameters", "dataframe"), + ("parameters", "dataframe", "error"), [ - pytest.param([1], pd.DataFrame({"x": [1, 2, 3]}), id="invalid_parameter_input"), + pytest.param([], edf, ValueError(">= 1"), id="empty_param"), + pytest.param(None, edf, TypeError("not iterable"), id="none_param"), + pytest.param([p_cont], edf, TypeError("be Date: Thu, 21 May 2026 14:53:48 +0200 Subject: [PATCH 27/30] Enable check for extra dataframe columns --- baybe/searchspace/candidates.py | 4 +++- baybe/utils/validation.py | 14 ++++++++++++++ tests/test_candidates.py | 2 +- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py index 3df07b3aad..2c80fb135c 100644 --- a/baybe/searchspace/candidates.py +++ b/baybe/searchspace/candidates.py @@ -105,7 +105,9 @@ class TableCandidates(CandidatesProtocol): @dataframe.validator def _validate_dataframe(self, _: Attribute, value: nw.LazyFrame) -> None: # noqa: DOC101, DOC103 # TODO: Remove collect().to_pandas() once validation on lazy frames is supported - validate_parameter_input(value.collect().to_pandas(), self.parameters) + validate_parameter_input( + value.collect().to_pandas(), self.parameters, allow_extra=False + ) @override @property diff --git a/baybe/utils/validation.py b/baybe/utils/validation.py index 93c87ab316..41bd2ed25f 100644 --- a/baybe/utils/validation.py +++ b/baybe/utils/validation.py @@ -149,6 +149,8 @@ def validate_parameter_input( data: pd.DataFrame, parameters: Iterable[Parameter], numerical_measurements_must_be_within_tolerance: bool = False, + *, + allow_extra: bool = True, ) -> None: """Validate input dataframe columns corresponding to parameters. @@ -158,10 +160,14 @@ def validate_parameter_input( numerical_measurements_must_be_within_tolerance: If ``True``, numerical parameter values must match to parameter values within the parameter-specific tolerance. + allow_extra: If ``False``, the dataframe is not allowed to contain columns that + do not correspond to any parameter. Raises: ValueError: If the data is empty. ValueError: If the data misses columns for a parameter. + ValueError: If the data contains columns that do not correspond to any parameter + and the corresponding check is enabled. ValueError: If a parameter contains NaN. TypeError: If a parameter contains non-numeric values. """ @@ -174,6 +180,14 @@ def validate_parameter_input( f"{missing}" ) + if not allow_extra and ( + extra := set(data.columns).difference({p.name for p in parameters}) + ): + raise ValueError( + f"The input dataframe contains columns that do not correspond to any " + f"parameter: {extra}" + ) + for p in parameters: if data[p.name].isna().any(): raise ValueError( diff --git a/tests/test_candidates.py b/tests/test_candidates.py index 61929e75f2..4bad0d5b47 100644 --- a/tests/test_candidates.py +++ b/tests/test_candidates.py @@ -65,7 +65,7 @@ def test_table_candidates_generation(dataframe_factory): pytest.param( [p_disc], pd.DataFrame({"disc": [1], "extra": [2]}), - ValueError("extra columns"), + ValueError("not correspond"), id="extra_cols", ), ], From 8017b6546aa75d1e8b2471b1418d610acf90888a Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 21 May 2026 15:40:49 +0200 Subject: [PATCH 28/30] Refactor product candidates tests --- tests/test_candidates.py | 177 ++++++++------------------------------- 1 file changed, 36 insertions(+), 141 deletions(-) diff --git a/tests/test_candidates.py b/tests/test_candidates.py index 4bad0d5b47..3e6a71e98a 100644 --- a/tests/test_candidates.py +++ b/tests/test_candidates.py @@ -7,6 +7,8 @@ from pandas.testing import assert_frame_equal from baybe.constraints import DiscreteSumConstraint, ThresholdCondition +from baybe.constraints.conditions import SubSelectionCondition +from baybe.constraints.discrete import DiscreteExcludeConstraint from baybe.parameters import ( CategoricalParameter, NumericalContinuousParameter, @@ -14,11 +16,13 @@ ) from baybe.searchspace.candidates import ProductCandidates, TableCandidates from baybe.utils.dataframe import create_fake_input -from baybe.utils.interval import Interval -p_disc = NumericalDiscreteParameter("disc", (1, 2, 7)) +p_disc = NumericalDiscreteParameter("disc", (1, 2)) +p_disc2 = NumericalDiscreteParameter("disc2", (0, 10)) p_cat = CategoricalParameter("cat", ("a", "b", "c")) p_cont = NumericalContinuousParameter("cont", (3, 8)) +c_sum = DiscreteSumConstraint(["disc", "disc2"], ThresholdCondition(2, "<=")) +c_sub = DiscreteExcludeConstraint(["disc"], [SubSelectionCondition([1])]) edf = pd.DataFrame() @@ -77,154 +81,45 @@ def test_table_candidates_validation(parameters, dataframe, error): @pytest.mark.parametrize( - "parameter_names", + ("constraints", "expected"), [ - ["Num_disc_1", "Num_disc_2", "Fraction_2"], - ["Categorical_1", "Num_disc_1"], - ["Categorical_1", "Categorical_2", "Categorical_1_subset"], + ([], [[1, 0], [1, 10], [2, 0], [2, 10]]), + ([c_sum], [[1, 0], [2, 0]]), + ([c_sub], [[2, 0], [2, 10]]), ], - ids=["numerical", "mixed", "categorical"], + ids=["no_constraint", "sum", "subselection"], ) -@pytest.mark.parametrize( - "constraint_names", - [ - [], - ["DiscreteSumConstraint"], - ["DiscreteExcludeConstraint"], - ], - ids=["no_constraint", "sum", "exclude"], -) -def test_product_candidates_creation(parameters, constraints): - """ProductCandidates can be created with valid parameters and constraints.""" - candidates = ProductCandidates(parameters=parameters, constraints=constraints) - lazy_candidates = candidates.to_lazy() - assert isinstance(lazy_candidates, nw.LazyFrame) - for p in parameters: - assert p.name in lazy_candidates.columns - assert candidates.is_finite - assert len(lazy_candidates.collect()) - - ProductCandidates(parameters=parameters) - - -@pytest.mark.parametrize( - ("parameters", "constraints"), - [ - pytest.param([1], ["DiscreteSumConstraint"], id="invalid_parameter_input"), - pytest.param( - NumericalContinuousParameter( - name="Conti_finite1", - bounds=Interval(0, 1), - ), - ["DiscreteSumConstraint"], - id="invalid_parameter_type", - ), - pytest.param([], ["DiscreteSumConstraint"], id="empty_parameter"), - pytest.param(None, ["DiscreteSumConstraint"], id="none_parameter"), - ], -) -def test_product_candidates_invalid_input(parameters, constraints): - """Invalid parameter and constraint inputs raise appropriate errors.""" - with pytest.raises((TypeError, ValueError, AttributeError)): - ProductCandidates(parameters=parameters, constraints=constraints) - +def test_product_candidates_generation(constraints, expected): + """ProductCandidates generates the expected lazy dataframe.""" + parameters = [p_disc, p_disc2] + candidates = ProductCandidates(parameters, constraints) + candidates_ldf = candidates.to_lazy() + candidates_df = candidates_ldf.collect() -@pytest.mark.parametrize( - "parameters,expected", - [ - pytest.param( - ( - NumericalDiscreteParameter(name="x", values=(1, 2)), - NumericalDiscreteParameter(name="y", values=(10, 20, 30)), - ), - {(x, y) for x in (1, 2) for y in (10, 20, 30)}, - id="numerical_numerical", - ), - pytest.param( - ( - NumericalDiscreteParameter(name="x", values=(1, 2)), - CategoricalParameter(name="cat", values=("a", "b")), - ), - {(x, c) for x in (1, 2) for c in ("a", "b")}, - id="numerical_categorical", - ), - pytest.param( - ( - CategoricalParameter(name="cat1", values=("a", "b")), - CategoricalParameter(name="cat2", values=("c", "d")), - CategoricalParameter(name="cat3", values=("e", "f", "g")), - ), - { - (c1, c2, c3) - for c1 in ("a", "b") - for c2 in ("c", "d") - for c3 in ("e", "f", "g") - }, - id="categorical_categorical", - ), - ], -) -def test_product_candidates_cartesian_product(parameters, expected): - """ProductCandidates builds the correct cartesian product.""" - candidates = ProductCandidates(parameters=parameters) - df = candidates.to_lazy_candidates().collect() - assert df.shape[0] == len(expected) - actual = {tuple(row) for row in df[[p.name for p in parameters]].to_numpy()} - assert actual == expected + assert candidates.is_finite + assert isinstance(candidates_ldf, nw.LazyFrame) + assert set(candidates_df.columns) == {p.name for p in parameters} + assert_frame_equal( + candidates_df.to_pandas(), + pd.DataFrame(expected, columns=[p.name for p in parameters]), + check_dtype=False, + ) @pytest.mark.parametrize( - ("parameters", "constraints", "expected_combinations"), + ("parameters", "constraints", "error"), [ + pytest.param([], (), ValueError(">= 1"), id="empty_param"), + pytest.param(None, (), TypeError("not iterable"), id="none_param"), + pytest.param([p_cont], (), TypeError("be Date: Thu, 21 May 2026 16:39:27 +0200 Subject: [PATCH 29/30] Expose candidates classes via namespace --- baybe/searchspace/__init__.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/baybe/searchspace/__init__.py b/baybe/searchspace/__init__.py index d78f7fafee..42d39d9493 100644 --- a/baybe/searchspace/__init__.py +++ b/baybe/searchspace/__init__.py @@ -1,5 +1,10 @@ """BayBE search spaces.""" +from baybe.searchspace.candidates import ( + CandidatesProtocol, + ProductCandidates, + TableCandidates, +) from baybe.searchspace.continuous import SubspaceContinuous from baybe.searchspace.core import ( SearchSpace, @@ -9,9 +14,15 @@ from baybe.searchspace.discrete import SubspaceDiscrete __all__ = [ + # Search space "validate_searchspace_from_config", "SearchSpace", "SearchSpaceType", + # Discrete + "CandidatesProtocol", + "ProductCandidates", + "TableCandidates", "SubspaceDiscrete", + # Continuous "SubspaceContinuous", ] From f6b516b8e354466121cb731466093178ee377331 Mon Sep 17 00:00:00 2001 From: Fabian Liebig Date: Sat, 23 May 2026 18:44:18 +0200 Subject: [PATCH 30/30] Fix CI: Add polars as an optional dependency for testing --- pyproject.toml | 1 + uv.lock | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 01f2826e5c..fa85bad40a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -166,6 +166,7 @@ benchmarking = [ ] test = [ + "baybe[polars]", "hypothesis[pandas]>=6.88.4", "tenacity>=8.5.0", "pytest>=7.2.0", diff --git a/uv.lock b/uv.lock index fec55f622c..0e3d4a109f 100644 --- a/uv.lock +++ b/uv.lock @@ -10,7 +10,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-05-14T08:07:50.351631Z" +exclude-newer = "2026-05-16T16:42:53.860732Z" exclude-newer-span = "P7D" [[package]] @@ -388,6 +388,7 @@ simulation = [ ] test = [ { name = "hypothesis", extra = ["pandas"] }, + { name = "polars", extra = ["pyarrow"] }, { name = "pytest" }, { name = "pytest-cov" }, { name = "tenacity" }, @@ -410,6 +411,7 @@ requires-dist = [ { name = "baybe", extras = ["onnx"], marker = "extra == 'benchmarking'" }, { name = "baybe", extras = ["onnx"], marker = "extra == 'extras'" }, { name = "baybe", extras = ["polars"], marker = "extra == 'extras'" }, + { name = "baybe", extras = ["polars"], marker = "extra == 'test'" }, { name = "baybe", extras = ["simulation"], marker = "extra == 'benchmarking'" }, { name = "baybe", extras = ["simulation"], marker = "extra == 'extras'" }, { name = "baybe", extras = ["test"], marker = "extra == 'dev'" },