🚀 Feature
Add an optional `normalize` parameter (with companion `normalize_mode`) to `ot.sliced_wasserstein_distance` and `ot.max_sliced_wasserstein_distance` that normalizes inputs before computing the distance. Supports `'standard'`, `'minmax'`, and `'l2'` normalization, with `'joint'`, `'source'`, or `'target'` as the reference for fitting statistics. Default is `normalize=None`, preserving fully backward-compatible behavior.
Motivation
Sliced Wasserstein Distance projects distributions onto random 1D directions and averages the 1D Wasserstein distances. Because projection is a dot product with a random unit vector, features with larger numerical ranges dominate the projection, regardless of their actual importance.
Example: comparing two user populations on annual income ($20k–$500k) and interaction rate (0–1). If incomes are similar between the populations but interaction rates differ meaningfully, the projection will still be dominated by income, and SWD will report a small distance, missing the real signal.
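This failure mode is easy to reproduce. The sketch below uses a minimal hand-rolled Monte-Carlo sliced Wasserstein-2 (not POT's implementation) on synthetic data matching the example; the populations have statistically identical incomes but clearly different interaction rates:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000

# Incomes drawn from the *same* distribution; interaction rates from
# clearly different ones (low vs. high engagement).
X_s = np.column_stack([rng.normal(60_000, 15_000, n), rng.beta(2, 8, n)])
X_t = np.column_stack([rng.normal(60_000, 15_000, n), rng.beta(8, 2, n)])

def swd(Xs, Xt, n_proj=200, seed=1):
    """Monte-Carlo sliced W2: average 1D W2 over random unit directions."""
    r = np.random.default_rng(seed)
    d, total = Xs.shape[1], 0.0
    for _ in range(n_proj):
        theta = r.normal(size=d)
        theta /= np.linalg.norm(theta)
        ps, pt = np.sort(Xs @ theta), np.sort(Xt @ theta)
        total += np.mean((ps - pt) ** 2)
    return np.sqrt(total / n_proj)

raw = swd(X_s, X_t)

# Joint standardization: fit mean/std on the concatenation of both sets.
XX = np.concatenate([X_s, X_t])
mu, sigma = XX.mean(0), XX.std(0)
norm = swd((X_s - mu) / sigma, (X_t - mu) / sigma)

print(raw)   # dominated by income sampling noise; the rate shift is invisible
print(norm)  # the genuine interaction-rate difference now shows up
```

The raw distance is on the scale of income fluctuations (hundreds of dollars) even though the income distributions are identical, while the real interaction-rate shift of ~0.6 contributes essentially nothing; after joint standardization the rate shift dominates.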
Three compounding problems:
- **Users often won't realize this is happening.** Unlike Euclidean distance, SWD's scale-sensitivity isn't obvious from its definition. Users may interpret a small SWD as similarity, not realizing one feature drowned out the rest.
- **Manual normalization is friction and easy to get wrong.** Today, users must add boilerplate before every call:

  ```python
  import numpy as np
  import ot
  from sklearn.preprocessing import StandardScaler

  # Fit on the concatenation so both sets share the same statistics.
  scaler = StandardScaler().fit(np.concatenate([X_s, X_t]))
  swd = ot.sliced_wasserstein_distance(scaler.transform(X_s), scaler.transform(X_t))
  ```
- **Independent normalization is a common mistake.** Fitting each distribution independently silently corrupts the distance with no warning:

  ```python
  from sklearn.preprocessing import StandardScaler

  X_s_norm = StandardScaler().fit_transform(X_s)
  X_t_norm = StandardScaler().fit_transform(X_t)  # WRONG: independent fit
  ```

  This makes genuinely different distributions appear identical, because each is independently re-centered to mean zero.
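A minimal NumPy demonstration of the pitfall, with a hand-rolled 1D Wasserstein-2 on sorted samples (not POT's API) and synthetic data: two clearly shifted distributions look identical under independent fitting, while joint fitting preserves the shift.

```python
import numpy as np

rng = np.random.default_rng(42)

# Two genuinely different 1D distributions: same shape, mean shifted by 10.
X_s = rng.normal(0.0, 1.0, size=(500, 1))
X_t = rng.normal(10.0, 1.0, size=(500, 1))

def standardize(X, mu, sigma):
    return (X - mu) / sigma

# WRONG: fit statistics independently -- each set is re-centered to mean 0.
Xs_bad = standardize(X_s, X_s.mean(0), X_s.std(0))
Xt_bad = standardize(X_t, X_t.mean(0), X_t.std(0))

# RIGHT ('joint' mode): fit on the concatenation -- the shift survives.
XX = np.concatenate([X_s, X_t])
Xs_ok = standardize(X_s, XX.mean(0), XX.std(0))
Xt_ok = standardize(X_t, XX.mean(0), XX.std(0))

def w2(a, b):
    """1D Wasserstein-2 via sorted samples (equal sample sizes)."""
    return np.sqrt(np.mean((np.sort(a, 0) - np.sort(b, 0)) ** 2))

w_bad = w2(Xs_bad, Xt_bad)  # near 0: the shift was erased
w_ok = w2(Xs_ok, Xt_ok)     # near 2: joint statistics preserve it
print(w_bad, w_ok)
```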
Pitch
Add two optional keyword arguments to both `sliced_wasserstein_distance` and `max_sliced_wasserstein_distance`:
```python
ot.sliced_wasserstein_distance(
    X_s, X_t,
    ...,
    normalize=None,          # None | 'standard' | 'minmax' | 'l2'
    normalize_mode='joint',  # 'joint' | 'source' | 'target'
)
```
`normalize` controls the method:

- `None` — no normalization (default, identical to current behavior)
- `'standard'` — zero mean, unit variance per feature: `(x - μ) / σ`
- `'minmax'` — scale each feature to `[0, 1]`: `(x - min) / (max - min)`
- `'l2'` — unit L2-norm per sample (row-wise): `x / ‖x‖₂`

`normalize_mode` controls which samples are used to compute the statistics (relevant for `'standard'` and `'minmax'`; ignored for `'l2'` and `None`):

- `'joint'` (default) — fit on `concat(X_s, X_t)`. Preserves symmetry: `SWD(X_s, X_t) == SWD(X_t, X_s)`. Recommended for general distribution comparison.
- `'source'` — fit on `X_s` only. Useful for drift detection where `X_s` is a reference.
- `'target'` — fit on `X_t` only. The reverse drift case.
Implementation:

- A private helper `_normalize_inputs(X_s, X_t, normalize, normalize_mode, nx)` performs the transformation using POT's backend abstraction (`nx`), so it works with NumPy, PyTorch, JAX, and TensorFlow and remains differentiable.
- Edge cases handled with warnings: zero-variance columns under `'standard'`, zero-range columns under `'minmax'`, zero-norm rows under `'l2'`.
- Comprehensive tests covering backward compatibility, all method × mode combinations, edge cases, symmetry under `'joint'`, and at least one non-NumPy backend.
- An example script in `examples/sliced-wasserstein/` demonstrating the scale-sensitivity problem and how the new parameter solves it.
Additional context
I'm actively working on a skeleton PR for this and will open it shortly with the [WIP] tag to gather early feedback on the API design before completing the implementation. Happy to iterate on the proposed API based on maintainers' input.