Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 101 additions & 2 deletions sklearn/_loss/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
MultinomialLogit,
)
from sklearn.utils import check_scalar
from sklearn.utils.stats import _weighted_percentile
from sklearn.utils.stats import _weighted_percentile, _weighted_quantile_by_idx


# Note: The shape of raw_prediction for multiclass classifications are
Expand Down Expand Up @@ -590,6 +590,32 @@ def fit_intercept_only(self, y_true, sample_weight=None):
else:
return _weighted_percentile(y_true, sample_weight, 50)

def fit_intercept_only_by_idx(self, idx, y_true, sample_weight=None):
    """Compute raw_prediction of an intercept-only model, by idx.

    This is the weighted median of the target, computed separately for each
    group defined by idx.

    Parameters
    ----------
    idx : array-like of shape (n_samples,)
        Group indices. Quantiles are computed separately for each unique idx value.
    y_true : array-like of shape (n_samples,)
        Observed, true target values.
    sample_weight : None or array of shape (n_samples,), default=None
        Sample weights.

    Returns
    -------
    raw_prediction : ndarray of shape (idx.max() + 1,)
        Array containing the weighted median for each group.
        Groups not present in idx will have NaN values.
    """
    # The grouped quantile helper needs explicit weights, so default to ones.
    weights = np.ones_like(y_true) if sample_weight is None else sample_weight
    return _weighted_quantile_by_idx(y_true, weights, idx, quantile=0.5)


class PinballLoss(BaseLoss):
"""Quantile loss aka pinball loss, for regression.
Expand Down Expand Up @@ -643,7 +669,7 @@ def __init__(self, sample_weight=None, quantile=0.5):
def fit_intercept_only(self, y_true, sample_weight=None):
"""Compute raw_prediction of an intercept-only model.

This is the weighted median of the target, i.e. over the samples
This is the weighted quantile of the target, i.e. over the samples
axis=0.
"""
if sample_weight is None:
Expand All @@ -653,6 +679,34 @@ def fit_intercept_only(self, y_true, sample_weight=None):
y_true, sample_weight, 100 * self.closs.quantile
)

def fit_intercept_only_by_idx(self, idx, y_true, sample_weight=None):
    """Compute raw_prediction of an intercept-only model, by idx.

    This is the weighted quantile of the target, computed separately for each
    group defined by idx.

    Parameters
    ----------
    idx : array-like of shape (n_samples,)
        Group indices. Quantiles are computed separately for each unique idx value.
    y_true : array-like of shape (n_samples,)
        Observed, true target values.
    sample_weight : None or array of shape (n_samples,), default=None
        Sample weights.

    Returns
    -------
    raw_prediction : ndarray of shape (idx.max() + 1,)
        Array containing the weighted quantile for each group.
        Groups not present in idx will have NaN values.
    """
    # The grouped quantile helper needs explicit weights, so default to ones.
    weights = np.ones_like(y_true) if sample_weight is None else sample_weight
    # The quantile level is the one this loss was configured with.
    return _weighted_quantile_by_idx(
        y_true, weights, idx, quantile=self.closs.quantile
    )


class HuberLoss(BaseLoss):
"""Huber loss, for regression.
Expand Down Expand Up @@ -726,6 +780,51 @@ def fit_intercept_only(self, y_true, sample_weight=None):
term = np.sign(diff) * np.minimum(self.closs.delta, np.abs(diff))
return median + np.average(term, weights=sample_weight)

def fit_intercept_only_by_idx(self, idx, y_true, sample_weight=None):
    """Compute raw_prediction of an intercept-only model, by idx.

    This is the weighted median of the target, computed separately for each
    group defined by idx, with Huber adjustment.

    Parameters
    ----------
    idx : array-like of shape (n_samples,)
        Group indices. Quantiles are computed separately for each unique idx value.
    y_true : array-like of shape (n_samples,)
        Observed, true target values.
    sample_weight : None or array of shape (n_samples,), default=None
        Sample weights.

    Returns
    -------
    raw_prediction : ndarray of shape (idx.max() + 1,)
        Array containing the weighted median plus Huber adjustment for each group.
        Groups not present in idx will have NaN values.
    """
    if sample_weight is None:
        sample_weight = np.ones_like(y_true)

    # Per-group weighted median (quantile=0.5).
    median = _weighted_quantile_by_idx(y_true, sample_weight, idx, quantile=0.5)

    # Huber adjustment: residuals around each sample's group median,
    # clipped to +/- delta.
    residual = y_true - median[idx]
    clipped = np.sign(residual) * np.minimum(self.closs.delta, np.abs(residual))

    # Weighted average of the clipped residuals within each group:
    # numerator and denominator accumulated per group index.
    numer = np.bincount(idx, weights=clipped * sample_weight)
    denom = np.bincount(idx, weights=sample_weight)
    adjustment = np.zeros(denom.size)
    nonzero = denom > 0
    adjustment[nonzero] = numer[nonzero] / denom[nonzero]

    return median + adjustment


class HalfPoissonLoss(BaseLoss):
"""Half Poisson deviance loss with log-link, for regression.
Expand Down
24 changes: 11 additions & 13 deletions sklearn/ensemble/_gb.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeRegressor
from sklearn.tree._tree import DOUBLE, DTYPE, TREE_LEAF
from sklearn.tree._tree import DOUBLE, DTYPE
from sklearn.utils import check_array, check_random_state, column_or_1d
from sklearn.utils._param_validation import HasMethods, Hidden, Interval, StrOptions
from sklearn.utils.multiclass import check_classification_targets
Expand Down Expand Up @@ -226,20 +226,18 @@ def _update_terminal_regions(
else:
# regression losses other than the squared error.
# As of now: absolute error, pinball loss, huber loss.

# mask all which are not in sample mask.
masked_terminal_regions = terminal_regions.copy()
masked_terminal_regions[~sample_mask] = -1
# update each leaf (= perform line search)
for leaf in np.nonzero(tree.children_left == TREE_LEAF)[0]:
(indices,) = np.nonzero(masked_terminal_regions == leaf)
sw = None if sample_weight is None else sample_weight[indices]
update = loss.fit_intercept_only(
y_true=y[indices] - raw_prediction[indices, k],
sample_weight=sw,
if sample_mask.all():
idx = terminal_regions
residual = y - raw_prediction[:, k]
else:
sample_weight = (
None if sample_weight is None else sample_weight[sample_mask]
)
residual = y[sample_mask] - raw_prediction[sample_mask, k]
idx = terminal_regions[sample_mask]

tree.value[leaf, 0, 0] = update
update = loss.fit_intercept_only_by_idx(idx, residual, sample_weight)
tree.value[:, 0, 0] = update

# update predictions (both in-bag and out-of-bag)
raw_prediction[:, k] += learning_rate * tree.value[:, 0, 0].take(
Expand Down
60 changes: 60 additions & 0 deletions sklearn/utils/stats.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import numpy as np

from sklearn.utils._array_api import (
_find_matching_floating_dtype,
Expand Down Expand Up @@ -214,3 +215,62 @@ def _weighted_percentile(
result = result[..., 0]

return result[0, ...] if n_dim == 1 else result


def _weighted_quantile_by_idx(y, w, idx, quantile=0.5):
"""Compute weighted quantile for groups defined by idx.

Parameters
----------
y : array-like
Values to compute quantiles from.
w : array-like
Sample weights for each value in y.
idx : array-like
Group indices. Quantiles are computed separately for each unique idx value.
quantile : float, default=0.5
Quantile level to compute, must be between 0 and 1.
Default is 0.5 (median).

Returns
-------
result : ndarray
Array of length (idx.max() + 1) containing the weighted quantile
for each group. Groups not present in idx will have NaN values.
"""
# 1. Sort by (idx, y)
order = np.lexsort((y, idx))
sorted_idx = idx[order]
sorted_y = y[order]
sorted_w = w[order]

# 2. Cumulative weights per group
cumsum_w = np.cumsum(sorted_w)

# 3. Total weight per group
group_ends = np.r_[np.nonzero(np.diff(sorted_idx))[0], len(sorted_idx) - 1]
group_starts = np.r_[0, group_ends[:-1] + 1]

total_w = cumsum_w[group_ends] - np.r_[0, cumsum_w[group_ends[:-1]]]
quantile_w = total_w * quantile

# 4. Cumulative weight within each group
cumsum_w_group = cumsum_w - np.repeat(
np.r_[0, cumsum_w[group_ends[:-1]]], group_ends - group_starts + 1
)

# 5. Find first y where cumulative weight >= quantile threshold
is_quantile = cumsum_w_group >= np.repeat(quantile_w, group_ends - group_starts + 1)

quantile_positions = np.zeros(len(group_ends), dtype=int)
quantile_positions[:] = group_ends # default fallback

first_true = np.flatnonzero(is_quantile)
_, first_per_group = np.unique(sorted_idx[first_true], return_index=True)
quantile_positions = first_true[first_per_group]

# 6. Output array (one quantile per idx)
result = np.full(idx.max() + 1, np.nan)
result[sorted_idx[quantile_positions]] = sorted_y[quantile_positions]

return result
Loading
Loading