Add normalize parameter to sliced Wasserstein distance functions #807

@Harguna

🚀 Feature

Add an optional normalize parameter (with companion normalize_mode) to ot.sliced_wasserstein_distance and ot.max_sliced_wasserstein_distance that standardizes inputs before computing the distance. Supports 'standard', 'minmax', and 'l2' normalization, with 'joint', 'source', or 'target' as the reference for fitting statistics. Default is normalize=None, preserving fully backward-compatible behavior.

Motivation

Sliced Wasserstein Distance projects distributions onto random 1D directions and averages the 1D Wasserstein distances. Because projection is a dot product with a random unit vector, features with larger numerical ranges dominate the projection, regardless of their actual importance.

Example: comparing two user populations on annual income ($20k–$500k) and interaction rate (0–1). If incomes are similar between the populations but interaction rates differ meaningfully, every projection is still dominated by income, so SWD mostly reflects the small income discrepancy and the interaction-rate difference barely registers. The real signal is missed.
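To make the income example concrete, here is a minimal NumPy sketch. It does not call POT: `sliced_wasserstein` is a hypothetical stand-in that assumes equal sample sizes and uniform weights (so sorting gives the 1D optimal matching), `joint_standardize` is likewise a made-up helper, and the synthetic income and interaction-rate data are invented for illustration.

```python
import numpy as np

def sliced_wasserstein(X_s, X_t, n_projections=200, seed=0):
    """Monte Carlo sliced 2-Wasserstein distance (equal sample sizes only)."""
    rng = np.random.default_rng(seed)
    # Random unit directions, one per column.
    theta = rng.normal(size=(X_s.shape[1], n_projections))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)
    # In 1D with uniform weights, the optimal coupling matches sorted samples.
    proj_s = np.sort(X_s @ theta, axis=0)
    proj_t = np.sort(X_t @ theta, axis=0)
    return float(np.sqrt(np.mean((proj_s - proj_t) ** 2)))

def joint_standardize(X_s, X_t):
    """Standardize both samples with statistics fitted on their union."""
    Z = np.concatenate([X_s, X_t])
    mu, sigma = Z.mean(axis=0), Z.std(axis=0)
    return (X_s - mu) / sigma, (X_t - mu) / sigma

rng = np.random.default_rng(42)
n = 2000
income_s = rng.uniform(20_000, 500_000, n)
income_t = rng.uniform(20_000, 500_000, n)   # same income distribution
rate_s = rng.beta(2, 8, n)                   # interaction rate, mean 0.2

X_s = np.column_stack([income_s, rate_s])
X_t_same = np.column_stack([income_t, rng.beta(2, 8, n)])     # same rates
X_t_shifted = np.column_stack([income_t, rng.beta(8, 2, n)])  # mean 0.8

raw_same = sliced_wasserstein(X_s, X_t_same)
raw_shifted = sliced_wasserstein(X_s, X_t_shifted)
scaled_same = sliced_wasserstein(*joint_standardize(X_s, X_t_same))
scaled_shifted = sliced_wasserstein(*joint_standardize(X_s, X_t_shifted))

print(f"raw:    same={raw_same:.3f}  shifted={raw_shifted:.3f}")
print(f"scaled: same={scaled_same:.3f}  shifted={scaled_shifted:.3f}")
```

On raw inputs the two values are almost indistinguishable, because the rate column can move any projected point by at most 1 while income noise is measured in thousands. After joint standardization the shifted case stands out clearly.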

Three compounding problems:

  1. Users often won't realize this is happening: Unlike Euclidean distance, SWD's scale-sensitivity isn't obvious from its definition. Users may interpret a small SWD as similarity, not realizing one feature drowned out the rest.

  2. Manual normalization is friction and easy to get wrong: Today, users must add boilerplate like the following before every call:

    scaler = StandardScaler().fit(np.concatenate([X_s, X_t]))
    swd = ot.sliced_wasserstein_distance(scaler.transform(X_s), scaler.transform(X_t))

  3. Independent normalization: A common mistake is fitting a separate scaler to each distribution, which silently corrupts the distance with no warning:

    X_s_norm = StandardScaler().fit_transform(X_s)
    X_t_norm = StandardScaler().fit_transform(X_t)  # WRONG: independent fit

    This makes genuinely different distributions look far more alike than they are, because each is independently re-centered to mean zero and rescaled to unit variance, erasing any difference in location or scale between them.
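The effect of an independent fit can be checked without any OT machinery. The sketch below uses plain NumPy and a hypothetical `standardize(X, ref)` helper (not part of POT or scikit-learn) on two synthetic Gaussians that differ in both mean and spread.

```python
import numpy as np

rng = np.random.default_rng(0)
X_s = rng.normal(loc=0.0, scale=1.0, size=(5000, 2))
X_t = rng.normal(loc=3.0, scale=2.0, size=(5000, 2))  # clearly different

def standardize(X, ref):
    """Standardize X with the per-feature mean/std of the reference sample."""
    return (X - ref.mean(axis=0)) / ref.std(axis=0)

# Independent fit: each sample is its own reference (the mistake).
ind_s, ind_t = standardize(X_s, X_s), standardize(X_t, X_t)

# Joint fit: both samples share statistics fitted on their union.
joint = np.concatenate([X_s, X_t])
jnt_s, jnt_t = standardize(X_s, joint), standardize(X_t, joint)

gap_independent = np.abs(ind_s.mean(axis=0) - ind_t.mean(axis=0)).max()
gap_joint = np.abs(jnt_s.mean(axis=0) - jnt_t.mean(axis=0)).max()

print("mean gap, independent fit:", gap_independent)
print("mean gap, joint fit:      ", gap_joint)
```

With the independent fit the mean gap collapses to floating-point noise and both samples end up with unit variance, so the location and scale differences are gone before any distance is computed. The joint fit keeps them.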

Pitch

Add two optional keyword arguments to both sliced_wasserstein_distance and max_sliced_wasserstein_distance:

ot.sliced_wasserstein_distance(
    X_s, X_t,
    ...,
    normalize=None,           # None | 'standard' | 'minmax' | 'l2'
    normalize_mode='joint',   # 'joint' | 'source' | 'target'
)

normalize controls the method:

  • None — no normalization (default, identical to current behavior)
  • 'standard' — zero mean, unit variance per feature: (x - μ) / σ
  • 'minmax' — scale each feature to [0, 1]: (x - min) / (max - min)
  • 'l2' — unit L2-norm per sample (row-wise): x / ‖x‖₂

normalize_mode controls which samples are used to compute the statistics (relevant for 'standard' and 'minmax'; ignored for 'l2' and None):

  • 'joint' (default) — fit on concat(X_s, X_t). Preserves symmetry: SWD(X_s, X_t) == SWD(X_t, X_s). Recommended for general distribution comparison.
  • 'source' — fit on X_s only. Useful for drift detection where X_s is a reference.
  • 'target' — fit on X_t only. Reverse drift case.
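For discussion, the method and mode semantics listed above could look roughly like this in plain NumPy. This is only a sketch: `normalize_inputs` is a hypothetical name, it skips POT's backend abstraction entirely, and it does no edge-case handling (zero variance, zero range, and zero norm would divide by zero here).

```python
import numpy as np

def normalize_inputs(X_s, X_t, normalize=None, normalize_mode="joint"):
    """Hypothetical NumPy-only sketch of the proposed normalization."""
    if normalize is None:
        return X_s, X_t  # current behavior, inputs untouched
    if normalize == "l2":
        # Row-wise scaling: no fitted statistics, so the mode is ignored.
        return (X_s / np.linalg.norm(X_s, axis=1, keepdims=True),
                X_t / np.linalg.norm(X_t, axis=1, keepdims=True))

    # Pick the reference sample used to fit per-feature statistics.
    if normalize_mode == "joint":
        ref = np.concatenate([X_s, X_t])
    elif normalize_mode == "source":
        ref = X_s
    elif normalize_mode == "target":
        ref = X_t
    else:
        raise ValueError(f"unknown normalize_mode: {normalize_mode!r}")

    if normalize == "standard":
        shift, scale = ref.mean(axis=0), ref.std(axis=0)
    elif normalize == "minmax":
        shift = ref.min(axis=0)
        scale = ref.max(axis=0) - shift
    else:
        raise ValueError(f"unknown normalize: {normalize!r}")

    # The same shift and scale are applied to both samples.
    return (X_s - shift) / scale, (X_t - shift) / scale

rng = np.random.default_rng(1)
X_s = rng.normal(size=(100, 3)) * np.array([1.0, 10.0, 100.0])
X_t = rng.normal(size=(80, 3)) * np.array([1.0, 10.0, 100.0]) + 5.0

Z_s, Z_t = normalize_inputs(X_s, X_t, normalize="standard", normalize_mode="source")
M_s, M_t = normalize_inputs(X_s, X_t, normalize="minmax", normalize_mode="joint")
U_s, U_t = normalize_inputs(X_s, X_t, normalize="l2")
```

Note that under 'source' or 'target' only the reference sample is guaranteed to land at zero mean and unit variance (or in [0, 1] for 'minmax'); the other sample is expressed in the reference's units, which is the point for drift detection.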

Implementation:

  • A private helper _normalize_inputs(X_s, X_t, normalize, normalize_mode, nx) performs the transformation using POT's backend abstraction (nx), so it works with NumPy, PyTorch, JAX, and TensorFlow and remains differentiable.
  • Edge cases handled with warnings: zero-variance columns under 'standard', zero-range columns under 'minmax', zero-norm rows under 'l2'.
  • Comprehensive tests covering backward compatibility, all method × mode combinations, edge cases, symmetry under 'joint', and at least one non-NumPy backend.
  • An example script in examples/sliced-wasserstein/ demonstrating the scale-sensitivity problem and how the parameter solves it.
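One possible shape for the edge-case handling, sketched in NumPy. The policy shown here (warn, then substitute a scale of 1 so the affected column or row passes through without a division by zero) is an assumption made for this sketch, not a settled design, and `safe_scale` is a hypothetical helper name.

```python
import warnings
import numpy as np

def safe_scale(scale, what):
    """Warn about zero entries in `scale` and replace them with 1.0."""
    zero = scale == 0
    if np.any(zero):
        warnings.warn(
            f"{int(zero.sum())} {what} with zero scale left unscaled",
            RuntimeWarning,
        )
        scale = np.where(zero, 1.0, scale)
    return scale

# Second column is constant, so its standard deviation is exactly zero.
X = np.array([[1.0, 5.0], [2.0, 5.0], [3.0, 5.0]])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sigma = safe_scale(X.std(axis=0), "column(s)")

Z = (X - X.mean(axis=0)) / sigma
print(Z)
print([str(w.message) for w in caught])
```

The same helper would cover zero-range columns under 'minmax' (pass `max - min`) and zero-norm rows under 'l2' (pass the row norms), so all three warnings could share one code path.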

Additional context

I'm actively working on a skeleton PR for this and will open it shortly with the [WIP] tag to gather early feedback on the API design before completing the implementation. Happy to iterate on the proposed API based on maintainers' input.
