diff --git a/AGENTS.md b/AGENTS.md index d44c1cc9ad..5e11c557bd 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -58,10 +58,9 @@ More specific conventions for subdirectories: ### attrs Only All domain classes use `attrs` `@define`. No dataclasses, no Pydantic. - Immutable value objects (parameters, kernels, priors, transformations, objectives, - targets): `@define(frozen=True, slots=False)`. + targets): `@define(frozen=True)`. - Mutable stateful objects (campaign, surrogates, recommenders): `@define`. -- `slots=False` required with `frozen=True` when `cached_property` is needed. See - `attrs` issue #164 +- `slots=False` required when `cached_property` is needed. See `attrs` issue #164 - Also use `slots=False` when monkeypatching is needed (e.g., `register_hooks`) ### Inheritance: ABC + SerialMixin + Protocol @@ -71,14 +70,25 @@ All domain classes use `attrs` `@define`. No dataclasses, no Pydantic. 3. Concrete classes: Inherit from ABC. ### Fields and Methods -- Use `field()` with `validator=`, `converter=`, `default=`, `factory=`, `alias=`. +- Use `field()` with arguments in this order: 1) `alias=` (if needed), 2) `init=` + (if needed), 3) `default=` / `factory=`, 4) `converter=`, 5) `validator=`. - Private fields: `_` prefix, typically `init=False`. - Store each piece of information once — no data duplication. - Use `attrs.evolve()` for modified copies of frozen objects. - Use `on_setattr` hooks for cache invalidation on mutable objects. +- Use `kw_only=True` deliberately: only when positional construction would be + ambiguous or error-prone (e.g., multiple fields of the same type, or + optional/secondary fields that should not be passed positionally). Do not + apply `kw_only` to all fields by default. - `ClassVar[bool]` for capability flags (`supports_transfer_learning`, etc.). -- Order class content like this: 1) Attributes, 2) validators and post_init, 3) - properties, 4) methods. Within each group use alphabetical order. 
+- Order class content like this: 1) Attributes, 2) default and validator methods, + 3) `__attrs_post_init__`, 4) properties, 5) methods. + - Attributes are ordered by functionality/importance (primary identity fields + first, optional/secondary fields last), not alphabetically. + - Default and validator methods mirror the attribute order. For a given + attribute, the default method (`_default_<attribute>`) comes before its validator + (`_validate_<attribute>`). + - Regular methods are ordered alphabetically. ### Attribute Docstrings String literals immediately below field declarations, blank lines between attributes. @@ -225,11 +235,24 @@ Three tiers: ## 11. Validation Patterns - Inline validators: `field(validator=(instance_of(str), min_len(1)))`, `in_()`, - `deep_iterable()`, custom `finite_float`, `gt()`. + `deep_iterable()`, custom `finite_float`, `gt()`. Order validators from simplest + to most complex: cheap structural checks (e.g., `min_len`, `instance_of`) before + expensive semantic checks (e.g., cross-field consistency, name uniqueness). - Method validators: `@_field.validator` with `# noqa: DOC101, DOC103` for validators needing `self` access. -- Cross-field: `__attrs_post_init__` when validation involves multiple fields. +- Cross-field: `__attrs_post_init__` is a last resort. Method validators + (`@field.validator`) already receive `self` and can read other already-set + attributes, so most cross-field checks belong there instead. When one field + must be compatible with another, attach the validator to the later field — + attrs sets fields in declaration order, so earlier fields are always available + via `self` at that point. When one attribute's value must be adjusted after + all fields are set — which is typically a workaround and should itself be + questioned — `__attrs_post_init__` is acceptable. - Converters: `field(converter=to_searchspace)` for automatic type coercion. 
+- If a converter already guarantees a specific type (e.g., `converter=list` + always produces a `list`, a custom converter always returns a known type), + omit any `instance_of(...)` validator for that same type — the check is + redundant. - Reusable validators in `baybe/utils/validation.py`: `finite_float`, `non_nan_float`, `non_inf_float`, `validate_not_nan`, `validate_target_input`, `validate_parameter_input`, `validate_object_names`. diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e669f82e0..c9c76b9dcd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Coding convention instructions for agentic developers (`AGENTS.md`, `CLAUDE.md`) - `has_polars_implementation` property on `DiscreteConstraint` - `allow_missing` flag on `DiscreteConstraint.get_invalid` and `get_valid` +- `narwhals` as a hard dependency +- `CandidatesProtocol` as an interface for candidate generation +- `TableCandidates` and `ProductCandidates` classes implementing `CandidatesProtocol` +- `DiscreteParameter.is_finite` property ### Breaking Changes - `parameter_cartesian_prod_pandas` and `parameter_cartesian_prod_polars` moved diff --git a/baybe/exceptions.py b/baybe/exceptions.py index 0be2273341..c049d414b8 100644 --- a/baybe/exceptions.py +++ b/baybe/exceptions.py @@ -179,5 +179,9 @@ class UnsupportedEarlyFilteringError(Exception): """A constraint does not support early filtering with the given parameters.""" +class InfiniteSpaceError(Exception): + """An operation requires a finite search space but the space is infinite.""" + + # Collect leftover original slotted classes processed by `attrs.define` gc.collect() diff --git a/baybe/parameters/base.py b/baybe/parameters/base.py index 2d4df2bc77..8b09a90435 100644 --- a/baybe/parameters/base.py +++ b/baybe/parameters/base.py @@ -116,6 +116,12 @@ class DiscreteParameter(Parameter, ABC): def values(self) -> tuple: """The values the parameter can take.""" 
+ @property + def is_finite(self) -> bool: + """Indicates whether the parameter has a finite number of values.""" + len(self.values) # <-- raises an error if the parameter is infinite + return True + @property def active_values(self) -> tuple: """The values that are considered for recommendation.""" diff --git a/baybe/searchspace/__init__.py b/baybe/searchspace/__init__.py index d78f7fafee..42d39d9493 100644 --- a/baybe/searchspace/__init__.py +++ b/baybe/searchspace/__init__.py @@ -1,5 +1,10 @@ """BayBE search spaces.""" +from baybe.searchspace.candidates import ( + CandidatesProtocol, + ProductCandidates, + TableCandidates, +) from baybe.searchspace.continuous import SubspaceContinuous from baybe.searchspace.core import ( SearchSpace, @@ -9,9 +14,15 @@ from baybe.searchspace.discrete import SubspaceDiscrete __all__ = [ + # Search space "validate_searchspace_from_config", "SearchSpace", "SearchSpaceType", + # Discrete + "CandidatesProtocol", + "ProductCandidates", + "TableCandidates", "SubspaceDiscrete", + # Continuous "SubspaceContinuous", ] diff --git a/baybe/searchspace/candidates.py b/baybe/searchspace/candidates.py new file mode 100644 index 0000000000..2c80fb135c --- /dev/null +++ b/baybe/searchspace/candidates.py @@ -0,0 +1,123 @@ +"""Candidates module for managing lazy candidate generation.""" + +import gc +from typing import Protocol + +import narwhals.stable.v2 as nw +from attr.validators import deep_iterable, instance_of, min_len +from attrs import Attribute, define, field +from typing_extensions import override + +from baybe.constraints import DISCRETE_CONSTRAINTS_FILTERING_ORDER, validate_constraints +from baybe.constraints.base import DiscreteConstraint +from baybe.exceptions import InfiniteSpaceError +from baybe.parameters.base import DiscreteParameter +from baybe.parameters.utils import sort_parameters +from baybe.searchspace.utils import build_constrained_product +from baybe.searchspace.validation import validate_parameter_names +from 
baybe.utils.basic import to_tuple +from baybe.utils.dataframe import to_lazy +from baybe.utils.validation import validate_parameter_input + + +class CandidatesProtocol(Protocol): + """Type protocol specifying the interface candidate generators need to implement.""" + + @property + def parameters(self) -> tuple[DiscreteParameter, ...]: + """The parameters spanning the space from which candidates are generated.""" + + @property + def is_finite(self) -> bool: + """Indicates whether the candidate set is finite or infinite.""" + + def to_lazy(self) -> nw.LazyFrame: + """Generate all candidates.""" + + +@define(frozen=True) +class ProductCandidates(CandidatesProtocol): + """Class for managing candidates from (filtered) Cartesian product spaces.""" + + parameters: tuple[DiscreteParameter, ...] = field( + converter=sort_parameters, + validator=[ + min_len(1), + deep_iterable(member_validator=instance_of(DiscreteParameter)), + lambda _, __, x: validate_parameter_names(x), + ], + ) + """See :attr:`CandidatesProtocol.parameters`.""" + + constraints: tuple[DiscreteConstraint, ...] = field( + default=(), + converter=lambda x: to_tuple( + sorted( + x, key=lambda c: DISCRETE_CONSTRAINTS_FILTERING_ORDER.index(c.__class__) + ) + ), + validator=deep_iterable(member_validator=instance_of(DiscreteConstraint)), + ) + """Constraints to filter the Cartesian product of parameter values.""" + + @constraints.validator + def _validate_constraints( + self, _: Attribute, value: tuple[DiscreteConstraint, ...] + ): # noqa: DOC101, DOC103 + validate_constraints(value, self.parameters) + + @override + @property + def is_finite(self) -> bool: + return all(p.is_finite for p in self.parameters) + + @override + def to_lazy(self) -> nw.LazyFrame: + if not self.is_finite: + raise InfiniteSpaceError( + "Cannot generate all candidates from an infinite space." 
+ ) + + candidates_df = build_constrained_product(self.parameters, self.constraints) + + # TODO: Remove to lazy once build_constrained_product returns a nw.LazyFrame + assert not isinstance(candidates_df, nw.LazyFrame) + return to_lazy(candidates_df) + + +@define(frozen=True) +class TableCandidates(CandidatesProtocol): + """Class for managing candidates provided in a tabular format.""" + + parameters: tuple[DiscreteParameter, ...] = field( + converter=sort_parameters, + validator=[ + min_len(1), + deep_iterable(member_validator=instance_of(DiscreteParameter)), + lambda _, __, x: validate_parameter_names(x), + ], + ) + """See :attr:`CandidatesProtocol.parameters`.""" + + dataframe: nw.LazyFrame = field(converter=to_lazy) + """The dataframe containing the candidates.""" + + @dataframe.validator + def _validate_dataframe(self, _: Attribute, value: nw.LazyFrame) -> None: # noqa: DOC101, DOC103 + # TODO: Remove collect().to_pandas() once validation on lazy frames is supported + validate_parameter_input( + value.collect().to_pandas(), self.parameters, allow_extra=False + ) + + @override + @property + def is_finite(self) -> bool: + return True + + @override + def to_lazy(self) -> nw.LazyFrame: + return self.dataframe + + +# Collect leftover original slotted classes processed by `attrs.define` +gc.collect() diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py index 84b831a406..221e3942dd 100644 --- a/baybe/utils/dataframe.py +++ b/baybe/utils/dataframe.py @@ -7,8 +7,10 @@ from collections.abc import Callable, Collection, Iterable, Sequence from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload +import narwhals.stable.v2 as nw import numpy as np import pandas as pd +from narwhals.stable.v2.typing import IntoDataFrame from typing_extensions import assert_never from baybe.exceptions import InputDataTypeWarning, SearchSpaceMatchWarning @@ -778,3 +780,8 @@ def needs_float_dtype(obj) -> bool: for col in cols_to_convert: df[col] = 
df[col].astype(active_settings.DTypeFloatNumpy) return df + + +def to_lazy(df: IntoDataFrame, /) -> nw.LazyFrame: + """Convert any dataframe to a :class:`~narwhals.LazyFrame`.""" + return nw.from_native(df).lazy() diff --git a/baybe/utils/validation.py b/baybe/utils/validation.py index 93c87ab316..41bd2ed25f 100644 --- a/baybe/utils/validation.py +++ b/baybe/utils/validation.py @@ -149,6 +149,8 @@ def validate_parameter_input( data: pd.DataFrame, parameters: Iterable[Parameter], numerical_measurements_must_be_within_tolerance: bool = False, + *, + allow_extra: bool = True, ) -> None: """Validate input dataframe columns corresponding to parameters. @@ -158,10 +160,14 @@ def validate_parameter_input( numerical_measurements_must_be_within_tolerance: If ``True``, numerical parameter values must match to parameter values within the parameter-specific tolerance. + allow_extra: If ``False``, the dataframe is not allowed to contain columns that + do not correspond to any parameter. Raises: ValueError: If the data is empty. ValueError: If the data misses columns for a parameter. + ValueError: If the data contains columns that do not correspond to any parameter + and the corresponding check is enabled. ValueError: If a parameter contains NaN. TypeError: If a parameter contains non-numeric values. 
""" @@ -174,6 +180,14 @@ def validate_parameter_input( f"{missing}" ) + if not allow_extra and ( + extra := set(data.columns).difference({p.name for p in parameters}) + ): + raise ValueError( + f"The input dataframe contains columns that do not correspond to any " + f"parameter: {extra}" + ) + for p in parameters: if data[p.name].isna().any(): raise ValueError( diff --git a/docs/conf.py b/docs/conf.py index dceb732509..747467e7d6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -274,15 +274,16 @@ # Mappings to all external packages that we want to have clickable links to intersphinx_mapping = { "botorch": ("https://botorch.readthedocs.io/en/latest", None), - "python": ("https://docs.python.org/3", None), + "narwhals": ("https://narwhals-dev.github.io/narwhals/", None), + "numpy": ("https://numpy.org/doc/stable/", None), "pandas": ("https://pandas.pydata.org/docs/", None), "polars": ("https://docs.pola.rs/api/python/stable/", None), + "python": ("https://docs.python.org/3", None), + "rdkit": ("https://rdkit.org/docs/", None), + "shap": ("https://shap.readthedocs.io/en/stable/", None), "skfp": ("https://scikit-fingerprints.readthedocs.io/latest/", None), "sklearn": ("https://scikit-learn.org/stable/", None), - "numpy": ("https://numpy.org/doc/stable/", None), "torch": ("https://pytorch.org/docs/main/", None), - "rdkit": ("https://rdkit.org/docs/", None), - "shap": ("https://shap.readthedocs.io/en/stable/", None), "xyzpy": ("https://xyzpy.readthedocs.io/en/latest/", None), } diff --git a/pyproject.toml b/pyproject.toml index e5a21e70c1..fa85bad40a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "exceptiongroup", "gpytorch>=1.9.1,<2", "joblib>1.4.0,<2", + "narwhals>=2,<3", "numpy>=1.24.1,<3", "pandas>=1.4.2,<3", "scikit-learn>=1.1.1,<2", @@ -165,6 +166,7 @@ benchmarking = [ ] test = [ + "baybe[polars]", "hypothesis[pandas]>=6.88.4", "tenacity>=8.5.0", "pytest>=7.2.0", diff --git a/tests/test_candidates.py b/tests/test_candidates.py 
new file mode 100644 index 0000000000..3e6a71e98a --- /dev/null +++ b/tests/test_candidates.py @@ -0,0 +1,125 @@ +"""Tests for candidate generators.""" + +import narwhals as nw +import pandas as pd +import polars as pl +import pytest +from pandas.testing import assert_frame_equal + +from baybe.constraints import DiscreteSumConstraint, ThresholdCondition +from baybe.constraints.conditions import SubSelectionCondition +from baybe.constraints.discrete import DiscreteExcludeConstraint +from baybe.parameters import ( + CategoricalParameter, + NumericalContinuousParameter, + NumericalDiscreteParameter, +) +from baybe.searchspace.candidates import ProductCandidates, TableCandidates +from baybe.utils.dataframe import create_fake_input + +p_disc = NumericalDiscreteParameter("disc", (1, 2)) +p_disc2 = NumericalDiscreteParameter("disc2", (0, 10)) +p_cat = CategoricalParameter("cat", ("a", "b", "c")) +p_cont = NumericalContinuousParameter("cont", (3, 8)) +c_sum = DiscreteSumConstraint(["disc", "disc2"], ThresholdCondition(2, "<=")) +c_sub = DiscreteExcludeConstraint(["disc"], [SubSelectionCondition([1])]) +edf = pd.DataFrame() + + +@pytest.mark.parametrize( + "dataframe_factory", + [ + pytest.param(lambda pd_df: pd_df, id="pandas_eager"), + pytest.param(pl.DataFrame, id="polars_eager"), + pytest.param(pl.LazyFrame, id="polars_lazy"), + pytest.param(lambda x: nw.from_native(x, eager_only=True), id="narwhals_eager"), + pytest.param(lambda x: nw.from_native(x).lazy(), id="narwhals_lazy"), + ], +) +def test_table_candidates_generation(dataframe_factory): + """TableCandidates generates the expected lazy dataframe.""" + parameters = [p_disc, p_cat] + data = create_fake_input(parameters, [], n_rows=4) + df = dataframe_factory(data) + candidates = TableCandidates(parameters, df) + candidates_ldf = candidates.to_lazy() + candidates_df = candidates_ldf.collect() + + assert candidates.is_finite + assert isinstance(candidates_ldf, nw.LazyFrame) + assert set(candidates_df.columns) == 
{p.name for p in parameters} + assert candidates_df.shape == data.shape + assert_frame_equal(candidates_df.to_pandas(), data) + + +@pytest.mark.parametrize( + ("parameters", "dataframe", "error"), + [ + pytest.param([], edf, ValueError(">= 1"), id="empty_param"), + pytest.param(None, edf, TypeError("not iterable"), id="none_param"), + pytest.param([p_cont], edf, TypeError("be = 1"), id="empty_param"), + pytest.param(None, (), TypeError("not iterable"), id="none_param"), + pytest.param([p_cont], (), TypeError("be