From fe2636c576338d7c81d9a2750d0a313fd6f9e6d9 Mon Sep 17 00:00:00 2001 From: Jiarui Xu <39042389+jxudata@users.noreply.github.com> Date: Tue, 26 May 2026 08:47:43 -0700 Subject: [PATCH] Fix uniform sample weight knot placement --- src/splinator/estimators.py | 2 ++ src/splinator/metric_wrappers.py | 1 - src/splinator/monotonic_spline.py | 14 +++++++++++++- tests/test_sample_weight.py | 24 +++++++++++++++++++++++- 4 files changed, 38 insertions(+), 3 deletions(-) diff --git a/src/splinator/estimators.py b/src/splinator/estimators.py index 5f635f0..e0c4768 100644 --- a/src/splinator/estimators.py +++ b/src/splinator/estimators.py @@ -500,6 +500,8 @@ def fit(self, X, y, sample_weight=None): # Validate sample_weight if sample_weight is not None: sample_weight = np.asarray(sample_weight) + if sample_weight.ndim != 1: + raise ValueError("sample_weight must be a 1-D array") if sample_weight.shape[0] != X.shape[0]: raise ValueError( f"sample_weight has {sample_weight.shape[0]} samples, " diff --git a/src/splinator/metric_wrappers.py b/src/splinator/metric_wrappers.py index d7ea312..2765140 100644 --- a/src/splinator/metric_wrappers.py +++ b/src/splinator/metric_wrappers.py @@ -150,4 +150,3 @@ def torch_wrapper(y_true, y_pred, sample_weight=None): f"Unknown framework: {framework}. " f"Supported: 'sklearn', 'xgboost', 'lightgbm', 'pytorch'" ) - diff --git a/src/splinator/monotonic_spline.py b/src/splinator/monotonic_spline.py index 4fced3f..f21f4f4 100644 --- a/src/splinator/monotonic_spline.py +++ b/src/splinator/monotonic_spline.py @@ -149,8 +149,20 @@ def _weighted_quantile(values, quantiles, sample_weight=None): if sample_weight is None: return np.quantile(values, quantiles) - + sample_weight = np.asarray(sample_weight) + if sample_weight.shape != values.shape: + raise ValueError("sample_weight must have the same shape as values") + + positive_weight = sample_weight > 0 + if not np.any(positive_weight): + raise ValueError("sample_weight must contain at least one positive value") + + values = values[positive_weight] + sample_weight = sample_weight[positive_weight] + if np.all(sample_weight == sample_weight[0]): + return np.quantile(values, quantiles) + sorted_indices = np.argsort(values) sorted_values = values[sorted_indices] sorted_weights = sample_weight[sorted_indices] diff --git a/tests/test_sample_weight.py b/tests/test_sample_weight.py index d4e01f6..eb754ba 100644 --- a/tests/test_sample_weight.py +++ b/tests/test_sample_weight.py @@ -32,7 +32,16 @@ def test_uniform_weights_matches_numpy(self): result = _weighted_quantile(X, quantiles, sample_weight=weights) expected = np.quantile(X, quantiles) - np.testing.assert_array_almost_equal(result, expected, decimal=1) + np.testing.assert_array_almost_equal(result, expected) + + def test_uniform_weights_preserve_interpolation(self): + """Uniform weights should preserve NumPy's interpolated quantiles.""" + X = np.array([0.0, 10.0]) + weights = np.ones(2) + + result = _weighted_quantile(X, [0.5], sample_weight=weights) + + np.testing.assert_array_equal(result, np.array([5.0])) def test_doubled_weights_equivalent_to_duplication(self): """Doubling a sample's weight should be like duplicating it.""" @@ -69,6 +78,15 @@ def test_knots_with_weights(self): assert len(knots) == 4 + def test_uniform_weights_match_unweighted_knots(self): + """Uniform weights should not change knot placement.""" + X = np.array([0.0, 10.0]) + weights = np.ones(2) + + knots = _fit_knots(X, num_knots=2, sample_weight=weights) + + np.testing.assert_array_equal(knots, _fit_knots(X, num_knots=2)) + class TestLossGradHessWithWeights: """Test LossGradHess class with sample weights.""" @@ -213,6 +231,10 @@ def test_sample_weight_validation(self): # Wrong length with pytest.raises(ValueError, match="sample_weight has .* samples"): model.fit(X, y, sample_weight=np.ones(50)) + + # Wrong dimensionality + with pytest.raises(ValueError, match="1-D"): + model.fit(X, y, sample_weight=np.ones((100, 2))) # Negative weights with pytest.raises(ValueError, match="non-negative"):