Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions ot/sliced.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,72 @@
)


def _normalize_inputs(X_s, X_t, normalize, normalize_mode, nx):
"""Normalize input distributions before computing sliced Wasserstein distance.

Parameters
----------
X_s : array-like, shape (n_s, d)
Source samples
X_t : array-like, shape (n_t, d)
Target samples
normalize : str or None
Normalization method. One of {None, 'standard', 'minmax', 'l2'}.
normalize_mode : str
Reference for computing statistics. One of {'joint', 'source', 'target'}.
Ignored when normalize is None or 'l2'.
nx : backend
POT backend instance (from ot.backend.get_backend)

Returns
-------
X_s_out : array-like, shape (n_s, d)
Normalized source samples
X_t_out : array-like, shape (n_t, d)
Normalized target samples
"""
if normalize is None:
return X_s, X_t

if normalize_mode not in ("joint", "source", "target"):
raise ValueError(
f"Invalid normalize_mode '{normalize_mode}'. "
"Expected one of: 'joint', 'source', 'target'."
)

if normalize == "standard":
# TODO: full implementation
# - compute mean/std using nx ops based on normalize_mode
# - apply to both X_s and X_t
# - handle zero-variance columns with warnings.warn
raise NotImplementedError(
"normalize='standard' will be implemented in a follow-up commit."
)

elif normalize == "minmax":
# TODO: full implementation
# - compute min/max using nx ops based on normalize_mode
# - apply to both X_s and X_t
# - handle zero-range columns with warnings.warn
raise NotImplementedError(
"normalize='minmax' will be implemented in a follow-up commit."
)

elif normalize == "l2":
# TODO: full implementation
# - row-wise L2 normalization (normalize_mode is ignored)
# - handle zero-norm rows with warnings.warn
raise NotImplementedError(
"normalize='l2' will be implemented in a follow-up commit."
)

else:
raise ValueError(
f"Invalid normalize value '{normalize}'. "
"Expected one of: None, 'standard', 'minmax', 'l2'."
)


def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None):
r"""
Generates n_projections samples from the uniform on the unit sphere of dimension :math:`d-1`: :math:`\mathcal{U}(\mathcal{S}^{d-1})`
Expand Down Expand Up @@ -76,6 +142,8 @@ def sliced_wasserstein_distance(
projections=None,
seed=None,
log=False,
normalize=None,
normalize_mode="joint",
):
r"""
Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance
Expand Down Expand Up @@ -109,6 +177,24 @@ def sliced_wasserstein_distance(
Seed used for random number generator
log: bool, optional
if True, sliced_wasserstein_distance returns the projections used and their associated EMD.
normalize : str or None, optional
Normalization applied to X_s and X_t before computing the distance.
Useful when features have different scales. Options:

- ``None`` : no normalization (default, preserves existing behavior)
- ``'standard'`` : zero mean, unit variance per feature dimension
- ``'minmax'`` : scale each feature to [0, 1]
- ``'l2'`` : normalize each sample to unit L2 norm (row-wise)

normalize_mode : str, optional
Determines which samples are used to compute normalization statistics.
Ignored when ``normalize`` is ``None`` or ``'l2'``. Options:

- ``'joint'`` : statistics from ``concat(X_s, X_t)`` (default).
Preserves symmetry: SWD(X_s, X_t) == SWD(X_t, X_s).
- ``'source'`` : statistics from ``X_s`` only. Useful for drift
detection where X_s is the reference distribution.
- ``'target'`` : statistics from ``X_t`` only.

Returns
-------
Expand Down Expand Up @@ -136,6 +222,8 @@ def sliced_wasserstein_distance(

nx = get_backend(X_s, X_t, a, b, projections)

X_s, X_t = _normalize_inputs(X_s, X_t, normalize, normalize_mode, nx)

n = X_s.shape[0]
m = X_t.shape[0]

Expand Down Expand Up @@ -181,6 +269,8 @@ def max_sliced_wasserstein_distance(
projections=None,
seed=None,
log=False,
normalize=None,
normalize_mode="joint",
):
r"""
Computes a Monte-Carlo approximation of the max p-Sliced Wasserstein distance
Expand Down Expand Up @@ -215,6 +305,24 @@ def max_sliced_wasserstein_distance(
Seed used for random number generator
log: bool, optional
if True, sliced_wasserstein_distance returns the projections used and their associated EMD.
normalize : str or None, optional
Normalization applied to X_s and X_t before computing the distance.
Useful when features have different scales. Options:

- ``None`` : no normalization (default, preserves existing behavior)
- ``'standard'`` : zero mean, unit variance per feature dimension
- ``'minmax'`` : scale each feature to [0, 1]
- ``'l2'`` : normalize each sample to unit L2 norm (row-wise)

normalize_mode : str, optional
Determines which samples are used to compute normalization statistics.
Ignored when ``normalize`` is ``None`` or ``'l2'``. Options:

- ``'joint'`` : statistics from ``concat(X_s, X_t)`` (default).
Preserves symmetry: SWD(X_s, X_t) == SWD(X_t, X_s).
- ``'source'`` : statistics from ``X_s`` only. Useful for drift
detection where X_s is the reference distribution.
- ``'target'`` : statistics from ``X_t`` only.

Returns
-------
Expand Down Expand Up @@ -242,6 +350,8 @@ def max_sliced_wasserstein_distance(

nx = get_backend(X_s, X_t, a, b, projections)

X_s, X_t = _normalize_inputs(X_s, X_t, normalize, normalize_mode, nx)

n = X_s.shape[0]
m = X_t.shape[0]

Expand Down
Loading