From 40a656d73356ec0d74f7f97d9916857c63b39bf3 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 15 Jan 2026 22:00:02 +0100 Subject: [PATCH 1/2] PRF: optim GB linesearch with by idx calculations for regression losses --- sklearn/_loss/loss.py | 103 +++++++++++++++++- sklearn/ensemble/_gb.py | 24 ++--- sklearn/utils/stats.py | 60 +++++++++++ sklearn/utils/tests/test_stats.py | 169 +++++++++++++++++++++++++++++- 4 files changed, 340 insertions(+), 16 deletions(-) diff --git a/sklearn/_loss/loss.py b/sklearn/_loss/loss.py index 9cbaa5284d3a2..11a059aed61aa 100644 --- a/sklearn/_loss/loss.py +++ b/sklearn/_loss/loss.py @@ -46,7 +46,7 @@ MultinomialLogit, ) from sklearn.utils import check_scalar -from sklearn.utils.stats import _weighted_percentile +from sklearn.utils.stats import _weighted_percentile, weighted_quantile_by_idx # Note: The shape of raw_prediction for multiclass classifications are @@ -590,6 +590,32 @@ def fit_intercept_only(self, y_true, sample_weight=None): else: return _weighted_percentile(y_true, sample_weight, 50) + def fit_intercept_only_by_idx(self, idx, y_true, sample_weight=None): + """Compute raw_prediction of an intercept-only model, by idx. + + This is the weighted median of the target, computed separately for each + group defined by idx. + + Parameters + ---------- + idx : array-like of shape (n_samples,) + Group indices. Quantiles are computed separately for each unique idx value. + y_true : array-like of shape (n_samples,) + Observed, true target values. + sample_weight : None or array of shape (n_samples,), default=None + Sample weights. + + Returns + ------- + raw_prediction : ndarray of shape (idx.max() + 1,) + Array containing the weighted median for each group. + Groups not present in idx will have NaN values. + """ + if sample_weight is None: + sample_weight = np.ones_like(y_true) + + return weighted_quantile_by_idx(y_true, sample_weight, idx, quantile=0.5) + class PinballLoss(BaseLoss): """Quantile loss aka pinball loss, for regression. @@ -643,7 +669,7 @@ def __init__(self, sample_weight=None, quantile=0.5): def fit_intercept_only(self, y_true, sample_weight=None): """Compute raw_prediction of an intercept-only model. - This is the weighted median of the target, i.e. over the samples + This is the weighted quantile of the target, i.e. over the samples axis=0. """ if sample_weight is None: @@ -653,6 +679,34 @@ def fit_intercept_only(self, y_true, sample_weight=None): y_true, sample_weight, 100 * self.closs.quantile ) + def fit_intercept_only_by_idx(self, idx, y_true, sample_weight=None): + """Compute raw_prediction of an intercept-only model, by idx. + + This is the weighted quantile of the target, computed separately for each + group defined by idx. + + Parameters + ---------- + idx : array-like of shape (n_samples,) + Group indices. Quantiles are computed separately for each unique idx value. + y_true : array-like of shape (n_samples,) + Observed, true target values. + sample_weight : None or array of shape (n_samples,), default=None + Sample weights. + + Returns + ------- + raw_prediction : ndarray of shape (idx.max() + 1,) + Array containing the weighted quantile for each group. + Groups not present in idx will have NaN values. + """ + if sample_weight is None: + sample_weight = np.ones_like(y_true) + + return weighted_quantile_by_idx( + y_true, sample_weight, idx, quantile=self.closs.quantile + ) + class HuberLoss(BaseLoss): """Huber loss, for regression. @@ -726,6 +780,51 @@ def fit_intercept_only(self, y_true, sample_weight=None): term = np.sign(diff) * np.minimum(self.closs.delta, np.abs(diff)) return median + np.average(term, weights=sample_weight) + def fit_intercept_only_by_idx(self, idx, y_true, sample_weight=None): + """Compute raw_prediction of an intercept-only model, by idx. + + This is the weighted median of the target, computed separately for each + group defined by idx, with Huber adjustment. + + Parameters + ---------- + idx : array-like of shape (n_samples,) + Group indices. Quantiles are computed separately for each unique idx value. + y_true : array-like of shape (n_samples,) + Observed, true target values. + sample_weight : None or array of shape (n_samples,), default=None + Sample weights. + + Returns + ------- + raw_prediction : ndarray of shape (idx.max() + 1,) + Array containing the weighted median plus Huber adjustment for each group. + Groups not present in idx will have NaN values. + """ + if sample_weight is None: + sample_weight = np.ones_like(y_true) + + # Compute weighted median per group (quantile=0.5) + median = weighted_quantile_by_idx(y_true, sample_weight, idx, quantile=0.5) + + # Compute Huber adjustment term per group + diff = y_true - median[idx] + term = np.sign(diff) * np.minimum(self.closs.delta, np.abs(diff)) + + # Weighted sum of terms per group + weighted_sum_by_idx = np.bincount(idx, weights=term * sample_weight) + # Sum of weights per group + weight_sum_by_idx = np.bincount(idx, weights=sample_weight) + + # Compute weighted average per group + avg_term = np.zeros(weight_sum_by_idx.size) + not_empty = weight_sum_by_idx > 0 + avg_term[not_empty] = ( + weighted_sum_by_idx[not_empty] / weight_sum_by_idx[not_empty] + ) + + return median + avg_term + class HalfPoissonLoss(BaseLoss): """Half Poisson deviance loss with log-link, for regression. diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index ee163b764a875..6522d8a2e9fa3 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -49,7 +49,7 @@ from sklearn.model_selection import train_test_split from sklearn.preprocessing import LabelEncoder from sklearn.tree import DecisionTreeRegressor -from sklearn.tree._tree import DOUBLE, DTYPE, TREE_LEAF +from sklearn.tree._tree import DOUBLE, DTYPE from sklearn.utils import check_array, check_random_state, column_or_1d from sklearn.utils._param_validation import HasMethods, Hidden, Interval, StrOptions from sklearn.utils.multiclass import check_classification_targets @@ -226,20 +226,18 @@ def _update_terminal_regions( else: # regression losses other than the squared error. # As of now: absolute error, pinball loss, huber loss. - - # mask all which are not in sample mask. - masked_terminal_regions = terminal_regions.copy() - masked_terminal_regions[~sample_mask] = -1 - # update each leaf (= perform line search) - for leaf in np.nonzero(tree.children_left == TREE_LEAF)[0]: - (indices,) = np.nonzero(masked_terminal_regions == leaf) - sw = None if sample_weight is None else sample_weight[indices] - update = loss.fit_intercept_only( - y_true=y[indices] - raw_prediction[indices, k], - sample_weight=sw, + if sample_mask.all(): + idx = terminal_regions + residual = y - raw_prediction[:, k] + else: + sample_weight = ( + None if sample_weight is None else sample_weight[sample_mask] ) + residual = y[sample_mask] - raw_prediction[sample_mask, k] + idx = terminal_regions[sample_mask] - tree.value[leaf, 0, 0] = update + update = loss.fit_intercept_only_by_idx(idx, residual, sample_weight) + tree.value[:, 0, 0] = update # update predictions (both in-bag and out-of-bag) raw_prediction[:, k] += learning_rate * tree.value[:, 0, 0].take( diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 71fa1418e235e..f0ff2e316a1bd 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -1,5 +1,6 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause +import numpy as np from sklearn.utils._array_api import ( _find_matching_floating_dtype, @@ -214,3 +215,62 @@ def _weighted_percentile( result = result[..., 0] return result[0, ...] if n_dim == 1 else result + + +def weighted_quantile_by_idx(y, w, idx, quantile=0.5): + """Compute weighted quantile for groups defined by idx. + + Parameters + ---------- + y : array-like + Values to compute quantiles from. + w : array-like + Sample weights for each value in y. + idx : array-like + Group indices. Quantiles are computed separately for each unique idx value. + quantile : float, default=0.5 + Quantile level to compute, must be between 0 and 1. + Default is 0.5 (median). + + Returns + ------- + result : ndarray + Array of length (idx.max() + 1) containing the weighted quantile + for each group. Groups not present in idx will have NaN values. + """ + # 1. Sort by (idx, y) + order = np.lexsort((y, idx)) + sorted_idx = idx[order] + sorted_y = y[order] + sorted_w = w[order] + + # 2. Cumulative weights per group + cumsum_w = np.cumsum(sorted_w) + + # 3. Total weight per group + group_ends = np.r_[np.nonzero(np.diff(sorted_idx))[0], len(sorted_idx) - 1] + group_starts = np.r_[0, group_ends[:-1] + 1] + + total_w = cumsum_w[group_ends] - np.r_[0, cumsum_w[group_ends[:-1]]] + quantile_w = total_w * quantile + + # 4. Cumulative weight within each group + cumsum_w_group = cumsum_w - np.repeat( + np.r_[0, cumsum_w[group_ends[:-1]]], group_ends - group_starts + 1 + ) + + # 5. Find first y where cumulative weight >= quantile threshold + is_quantile = cumsum_w_group >= np.repeat(quantile_w, group_ends - group_starts + 1) + + quantile_positions = np.zeros(len(group_ends), dtype=int) + quantile_positions[:] = group_ends # default fallback + + first_true = np.flatnonzero(is_quantile) + _, first_per_group = np.unique(sorted_idx[first_true], return_index=True) + quantile_positions = first_true[first_per_group] + + # 6. Output array (one quantile per idx) + result = np.full(idx.max() + 1, np.nan) + result[sorted_idx[quantile_positions]] = sorted_y[quantile_positions] + + return result diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index 60e1c2acc0945..865abc42f9b8d 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -12,7 +12,7 @@ from sklearn.utils._array_api import device as array_device from sklearn.utils.estimator_checks import _array_api_for_tests from sklearn.utils.fixes import np_version, parse_version -from sklearn.utils.stats import _weighted_percentile +from sklearn.utils.stats import _weighted_percentile, weighted_quantile_by_idx @pytest.mark.parametrize("average", [True, False]) @@ -483,3 +483,170 @@ def test_weighted_percentile_like_numpy_nanquantile( ) assert_array_equal(percentile_weighted_percentile, percentile_numpy_nanquantile) + + +def test_weighted_quantile_by_idx_basic(): + """Test basic functionality of weighted_quantile_by_idx with median.""" + y = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + w = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + idx = np.array([0, 0, 0, 1, 1, 1]) + + result = weighted_quantile_by_idx(y, w, idx, quantile=0.5) + + # Using "inverted_cdf" method + # Group 0: [1, 2, 3] -> cumsum [1, 2, 3], threshold=1.5, + # first >= 1.5 is at cumsum=2 -> y=2 + # Group 1: [4, 5, 6] -> cumsum [1, 2, 3], threshold=1.5, + # first >= 1.5 is at cumsum=2 -> y=5 + assert result[0] == approx(2.0) + assert result[1] == approx(5.0) + + +def test_weighted_quantile_by_idx_different_quantiles(): + """Test weighted_quantile_by_idx with different quantile levels.""" + y = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + w = np.array([1.0, 1.0, 1.0, 1.0, 1.0]) + idx = np.array([0, 0, 0, 0, 0]) + + # Test different quantiles + result_q25 = weighted_quantile_by_idx(y, w, idx, quantile=0.25) + result_q50 = weighted_quantile_by_idx(y, w, idx, quantile=0.5) + result_q75 = weighted_quantile_by_idx(y, w, idx, quantile=0.75) + + # For uniform weights [1,2,3,4,5]: cumulative weights are [1,2,3,4,5] + # q=0.25: 0.25*5=1.25, first >= 1.25 is at cumsum=2 -> y=2 + # q=0.50: 0.50*5=2.5, first >= 2.5 is at cumsum=3 -> y=3 + # q=0.75: 0.75*5=3.75, first >= 3.75 is at cumsum=4 -> y=4 + assert result_q25[0] == approx(2.0) + assert result_q50[0] == approx(3.0) + assert result_q75[0] == approx(4.0) + + +def test_weighted_quantile_by_idx_weighted(): + """Test weighted_quantile_by_idx with non-uniform weights.""" + y = np.array([1.0, 2.0, 3.0]) + w = np.array([1.0, 2.0, 1.0]) # Total weight = 4 + idx = np.array([0, 0, 0]) + + result = weighted_quantile_by_idx(y, w, idx, quantile=0.5) + + # Cumulative weights: [1, 3, 4] + # Median at 0.5*4=2.0, first cumsum >= 2.0 is 3 -> y=2 + assert result[0] == approx(2.0) + + +def test_weighted_quantile_by_idx_multiple_groups(): + """Test weighted_quantile_by_idx with multiple groups.""" + y = np.array([5.0, 1.0, 3.0, 2.0, 4.0, 6.0]) + w = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + idx = np.array([1, 0, 0, 1, 2, 2]) + + result = weighted_quantile_by_idx(y, w, idx, quantile=0.5) + + # Using "inverted_cdf" method (like _weighted_percentile with average=False) + # Group 0: [1, 3] -> cumsum [1, 2], threshold=1.0, + # first >= 1.0 is at cumsum=1 -> y=1 + # Group 1: [2, 5] -> cumsum [1, 2], threshold=1.0, + # first >= 1.0 is at cumsum=1 -> y=2 + # Group 2: [4, 6] -> cumsum [1, 2], threshold=1.0, + # first >= 1.0 is at cumsum=1 -> y=4 + assert result[0] == approx(1.0) + assert result[1] == approx(2.0) + assert result[2] == approx(4.0) + + +def test_weighted_quantile_by_idx_missing_groups(): + """Test that missing group indices return NaN.""" + y = np.array([1.0, 2.0, 3.0]) + w = np.array([1.0, 1.0, 1.0]) + idx = np.array([0, 0, 2]) # Group 1 is missing + + result = weighted_quantile_by_idx(y, w, idx, quantile=0.5) + + # Group 0: [1, 2] -> cumsum [1, 2], threshold=1.0, + # first >= 1.0 is at cumsum=1 -> y=1 + assert result[0] == approx(1.0) + assert np.isnan(result[1]) # Group 1: missing + assert result[2] == approx(3.0) # Group 2: only [3] + + +def test_weighted_quantile_by_idx_single_value_per_group(): + """Test with single value per group.""" + y = np.array([10.0, 20.0, 30.0]) + w = np.array([1.0, 2.0, 3.0]) + idx = np.array([0, 1, 2]) + + result = weighted_quantile_by_idx(y, w, idx, quantile=0.5) + + # Each group has only one value, so that value is the quantile + assert result[0] == approx(10.0) + assert result[1] == approx(20.0) + assert result[2] == approx(30.0) + + +def test_weighted_quantile_by_idx_extreme_quantiles(): + """Test with extreme quantile values (0 and 1).""" + y = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + w = np.array([1.0, 1.0, 1.0, 1.0, 1.0]) + idx = np.array([0, 0, 0, 0, 0]) + + # Quantile 0 should give first value when cumsum > 0 + result_q0 = weighted_quantile_by_idx(y, w, idx, quantile=0.0) + # For q=0, threshold is 0, first cumsum >= 0 is cumsum=1 -> y=1 + assert result_q0[0] == approx(1.0) + + # Quantile 1 should give last value + result_q1 = weighted_quantile_by_idx(y, w, idx, quantile=1.0) + # For q=1, threshold is 5, first cumsum >= 5 is cumsum=5 -> y=5 + assert result_q1[0] == approx(5.0) + + +def test_weighted_quantile_by_idx_unsorted_input(): + """Test that the function handles unsorted input correctly.""" + # Input is intentionally unsorted + y = np.array([3.0, 1.0, 4.0, 2.0, 5.0]) + w = np.array([1.0, 1.0, 1.0, 1.0, 1.0]) + idx = np.array([1, 0, 1, 0, 1]) + + result = weighted_quantile_by_idx(y, w, idx, quantile=0.5) + + # Group 0: [1, 2] -> cumsum [1, 2], threshold=1.0, + # first >= 1.0 is at cumsum=1 -> y=1 + # Group 1: [3, 4, 5] -> cumsum [1, 2, 3], threshold=1.5, + # first >= 1.5 is at cumsum=2 -> y=4 + assert result[0] == approx(1.0) + assert result[1] == approx(4.0) + + +def test_weighted_quantile_by_idx_consistency_with_weighted_percentile(): + """Test consistency with _weighted_percentile for single group.""" + rng = np.random.RandomState(42) + y = rng.rand(20) + w = rng.rand(20) + idx = np.zeros(20, dtype=int) # All in group 0 + + for quantile in [0.25, 0.5, 0.75]: + result_by_idx = weighted_quantile_by_idx(y, w, idx, quantile=quantile) + # _weighted_percentile expects percentile_rank in [0, 100] + result_percentile = _weighted_percentile( + y, w, percentile_rank=quantile * 100, average=False + ) + + assert result_by_idx[0] == approx(result_percentile) + + +@pytest.mark.parametrize("quantile", [0.1, 0.25, 0.5, 0.75, 0.9]) +def test_weighted_quantile_by_idx_parametrized_quantiles(quantile): + """Test various quantile levels.""" + y = np.arange(1, 11, dtype=float) # [1, 2, ..., 10] + w = np.ones(10) + idx = np.zeros(10, dtype=int) + + result = weighted_quantile_by_idx(y, w, idx, quantile=quantile) + + # Verify result is within the data range + assert 1.0 <= result[0] <= 10.0 + # Verify it's monotonic with quantile level + if quantile > 0.5: + result_lower = weighted_quantile_by_idx(y, w, idx, quantile=0.5) + assert result[0] >= result_lower[0] From 9787150e3cce559c5d430ece873e408667e698b1 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 15 Jan 2026 22:40:04 +0100 Subject: [PATCH 2/2] renaming --- sklearn/_loss/loss.py | 8 ++++---- sklearn/utils/stats.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/_loss/loss.py b/sklearn/_loss/loss.py index 11a059aed61aa..3a3ab6df0248b 100644 --- a/sklearn/_loss/loss.py +++ b/sklearn/_loss/loss.py @@ -46,7 +46,7 @@ MultinomialLogit, ) from sklearn.utils import check_scalar -from sklearn.utils.stats import _weighted_percentile, weighted_quantile_by_idx +from sklearn.utils.stats import _weighted_percentile, _weighted_quantile_by_idx # Note: The shape of raw_prediction for multiclass classifications are @@ -614,7 +614,7 @@ def fit_intercept_only_by_idx(self, idx, y_true, sample_weight=None): if sample_weight is None: sample_weight = np.ones_like(y_true) - return weighted_quantile_by_idx(y_true, sample_weight, idx, quantile=0.5) + return _weighted_quantile_by_idx(y_true, sample_weight, idx, quantile=0.5) class PinballLoss(BaseLoss): @@ -703,7 +703,7 @@ def fit_intercept_only_by_idx(self, idx, y_true, sample_weight=None): if sample_weight is None: sample_weight = np.ones_like(y_true) - return weighted_quantile_by_idx( + return _weighted_quantile_by_idx( y_true, sample_weight, idx, quantile=self.closs.quantile ) @@ -805,7 +805,7 @@ def fit_intercept_only_by_idx(self, idx, y_true, sample_weight=None): sample_weight = np.ones_like(y_true) # Compute weighted median per group (quantile=0.5) - median = weighted_quantile_by_idx(y_true, sample_weight, idx, quantile=0.5) + median = _weighted_quantile_by_idx(y_true, sample_weight, idx, quantile=0.5) # Compute Huber adjustment term per group diff = y_true - median[idx] diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index f0ff2e316a1bd..b387ca091efb1 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -217,7 +217,7 @@ def _weighted_percentile( return result[0, ...] if n_dim == 1 else result -def weighted_quantile_by_idx(y, w, idx, quantile=0.5): +def _weighted_quantile_by_idx(y, w, idx, quantile=0.5): """Compute weighted quantile for groups defined by idx. Parameters