Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- `SHAPInsight` breaking with `numpy>=2.4` due to no longer accepted implicit array to
scalar conversion
- Using `np.isclose` for assessing equality of interval bounds instead of hard equality
check

### Removed
- `parallel_runs` argument from `simulate_scenarios`, since parallelization
Expand Down
4 changes: 3 additions & 1 deletion baybe/utils/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import TYPE_CHECKING, Any, Union

import numpy as np
from attrs import define, field
from attrs import cmp_using, define, field

from baybe.serialization import SerialMixin, converter
from baybe.settings import active_settings
Expand Down Expand Up @@ -39,13 +39,15 @@ class Interval(SerialMixin):
default=float("-inf"),
converter=lambda x: float("-inf") if x is None else float(x),
validator=non_nan_float,
eq=cmp_using(eq=lambda a, b: bool(np.isclose(a, b))),
)
"""The lower end of the interval."""

upper: float = field(
default=float("inf"),
converter=lambda x: float("inf") if x is None else float(x),
validator=non_nan_float,
eq=cmp_using(eq=lambda a, b: bool(np.isclose(a, b))),
)
"""The upper end of the interval."""

Expand Down
15 changes: 15 additions & 0 deletions tests/validation/test_interval_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,18 @@ def test_invalid_range(request, bounds):
return
with pytest.raises(ValueError):
Interval(*bounds[::-1])


@pytest.mark.parametrize(
("other", "expected"),
[
param(Interval(0, 1), True, id="exact_match"),
param(Interval(0, 0.9999999999999999), True, id="upper_float_imprecision"),
param(Interval(1e-16, 1 - 1e-16), True, id="both_float_imprecision"),
param(Interval(0, 0.5), False, id="different_upper"),
param(Interval(0.5, 1), False, id="different_lower"),
],
)
def test_close_interval_bounds(other, expected):
"""Intervals that are close up to floating-point precision are detected."""
assert (Interval(0, 1) == other) == expected
Loading