From 22c2d1723ef1b74574a12dc6166d73d281b2a28e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Nordstr=C3=B6m?= Date: Tue, 21 Apr 2026 16:15:27 +0300 Subject: [PATCH 1/3] feat(linfa-nn): Added Earth Mover's Distance --- algorithms/linfa-nn/src/distance.rs | 65 ++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/algorithms/linfa-nn/src/distance.rs b/algorithms/linfa-nn/src/distance.rs index de4af9531..0272730c0 100644 --- a/algorithms/linfa-nn/src/distance.rs +++ b/algorithms/linfa-nn/src/distance.rs @@ -118,6 +118,27 @@ impl Distance for LpDist { } } +/// Wasserstein or [Earth Mover's](https://en.wikipedia.org/wiki/Earth_mover%27s_distance) distance +#[cfg_attr( + feature = "serde", + derive(Serialize, Deserialize), + serde(crate = "serde_crate") +)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EarthMoverDist; +impl Distance for EarthMoverDist { + #[inline] + fn distance(&self, a: ArrayView, b: ArrayView) -> F { + let mut cumulative_diff = F::zero(); + let mut emd = F::zero(); + Zip::from(&a).and(&b).for_each(|&a, &b| { + cumulative_diff += a - b; + emd += cumulative_diff.abs() + }); + emd + } +} + /// Computes a similarity matrix with gaussian kernel and scaling parameter `eps` /// /// The generated matrix is a upper triangular matrix with dimension NxN (number of observations) and contains the similarity between all permutations of observations @@ -146,7 +167,7 @@ pub fn to_gaussian_similarity( #[cfg(test)] mod test { use approx::assert_abs_diff_eq; - use ndarray::arr1; + use ndarray::{arr1, arr2}; use super::*; @@ -157,6 +178,7 @@ mod test { has_autotraits::(); has_autotraits::(); has_autotraits::>(); + has_autotraits::(); } fn dist_test>(dist: D, result: f64) { @@ -204,4 +226,45 @@ mod test { fn lp_dist() { dist_test(LpDist(3.3), 4.635); } + + #[test] + fn emd_dist() { + dist_test(EarthMoverDist, 4.2); + + let dist = EarthMoverDist; + let a = arr1(&[0.5, 0.5]); + let b = arr1(&[0.3, 0.7]); + let ab = dist.distance(a.view(), b.view()); + assert_abs_diff_eq!(ab, 0.2, epsilon = 1e-5); + assert_abs_diff_eq!(dist.rdist_to_dist(dist.dist_to_rdist(ab)), ab); + + let a = arr1(&[0.3, 0.2, 0.1, 0.15, 0.25]); + let b = arr1(&[0.3, 0.2, 0.1, 0.15, 0.25]); + let ab = dist.distance(a.view(), b.view()); + assert_abs_diff_eq!(ab, 0.0, epsilon = 1e-5); + + let a = arr1(&[0.3, 0.2, 0.1, 0.15, 0.25]); + let b = arr1(&[0.1, 0.2, 0.1, 0.15, 0.45]); + let ab = dist.distance(a.view(), b.view()); + assert_abs_diff_eq!(ab, 0.8, epsilon = 1e-5); + + let a = arr1(&[0.3, 0.2, 0.15, 0.10, 0.25]); + let b = arr1(&[0.1, 0.2, 0.05, 0.20, 0.45]); + let ab = dist.distance(a.view(), b.view()); + assert_abs_diff_eq!(ab, 0.9, epsilon = 1e-5); + + let a = arr1(&[0.35, 0.15, 0.15, 0.10, 0.25]); + let b = arr1(&[0.1, 0.20, 0.05, 0.20, 0.45]); + let ab = dist.distance(a.view(), b.view()); + assert_abs_diff_eq!(ab, 0.95, epsilon = 1e-5); + + let a = arr2(&[[0.3, 0.2, 0.15, 0.10, 0.25], [0.35, 0.15, 0.15, 0.10, 0.25]]); + let b = arr2(&[[0.1, 0.2, 0.05, 0.20, 0.45], [0.1, 0.20, 0.05, 0.20, 0.45]]); + let ab = dist.distance(a.view(), b.view()); + assert_abs_diff_eq!(ab, 0.9 + 0.95, epsilon = 1e-5); + + let a = arr1(&[f64::INFINITY, 6.6]); + let b = arr1(&[4.4, f64::NEG_INFINITY]); + assert!(dist.distance(a.view(), b.view()).is_infinite()); + } } From b52e4dc14012cfc77cea458a850eeb79914bf232 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Nordstr=C3=B6m?= Date: Fri, 22 May 2026 12:34:30 +0300 Subject: [PATCH 2/3] Changed EarthMoverDist as WassersteinDist and use tests from SciPy --- algorithms/linfa-nn/src/distance.rs | 136 ++++++++++++++++++++-------- 1 file changed, 97 insertions(+), 39 deletions(-) diff --git a/algorithms/linfa-nn/src/distance.rs b/algorithms/linfa-nn/src/distance.rs index 0272730c0..95847e041 100644 --- a/algorithms/linfa-nn/src/distance.rs +++ b/algorithms/linfa-nn/src/distance.rs @@ -118,24 +118,29 @@ impl Distance for LpDist { } } -/// Wasserstein or [Earth Mover's](https://en.wikipedia.org/wiki/Earth_mover%27s_distance) distance +/// [Wasserstein](https://en.wikipedia.org/wiki/Wasserstein_metric) or +/// [Earth Mover's](https://en.wikipedia.org/wiki/Earth_mover%27s_distance) distance. +/// +/// The function accepts histograms where each array element is the probability mass at that index. +/// This differs from SciPy's `wasserstein_distance` which instead accepts support values and weights, +/// then builds the histograms internally. #[cfg_attr( feature = "serde", derive(Serialize, Deserialize), serde(crate = "serde_crate") )] #[derive(Debug, Clone, PartialEq, Eq)] -pub struct EarthMoverDist; -impl Distance for EarthMoverDist { +pub struct WassersteinDist; +impl Distance for WassersteinDist { #[inline] fn distance(&self, a: ArrayView, b: ArrayView) -> F { let mut cumulative_diff = F::zero(); - let mut emd = F::zero(); + let mut dist = F::zero(); Zip::from(&a).and(&b).for_each(|&a, &b| { cumulative_diff += a - b; - emd += cumulative_diff.abs() + dist += cumulative_diff.abs() }); - emd + dist } } @@ -167,7 +172,7 @@ pub fn to_gaussian_similarity( #[cfg(test)] mod test { use approx::assert_abs_diff_eq; - use ndarray::{arr1, arr2}; + use ndarray::arr1; use super::*; @@ -178,7 +183,7 @@ mod test { has_autotraits::(); has_autotraits::(); has_autotraits::>(); - has_autotraits::(); + has_autotraits::(); } fn dist_test>(dist: D, result: f64) { @@ -228,43 +233,96 @@ mod test { } #[test] - fn emd_dist() { - dist_test(EarthMoverDist, 4.2); + fn wasserstein_dist() { + dist_test(WassersteinDist, 4.2); + } - let dist = EarthMoverDist; - let a = arr1(&[0.5, 0.5]); - let b = arr1(&[0.3, 0.7]); - let ab = dist.distance(a.view(), b.view()); - assert_abs_diff_eq!(ab, 0.2, epsilon = 1e-5); - assert_abs_diff_eq!(dist.rdist_to_dist(dist.dist_to_rdist(ab)), ab); + // The following Wasserstein tests are from SciPy. + // However, since SciPy Wasserstein distance has different API as ours, + // we need to first transform the SciPy parameters into histograms that our API accepts. + // + // For example, SciPy values `[1, 3]` with weights `[1, 9]`: + // At index 1 we have weight 1 out of total weight 10 => 0.1. + // At index 3 we have weight 9 out of total weight 10 => 0.9. + // Thus we get a histogram [0.0, 0.1, 0.0, 0.9] - let a = arr1(&[0.3, 0.2, 0.1, 0.15, 0.25]); - let b = arr1(&[0.3, 0.2, 0.1, 0.15, 0.25]); - let ab = dist.distance(a.view(), b.view()); - assert_abs_diff_eq!(ab, 0.0, epsilon = 1e-5); + #[test] + /// For basic distributions, the value of the Wasserstein distance is straightforward. + fn wasserstein_simple() { + let dist = WassersteinDist; + + // SciPy: u_values=[0, 1], v_values=[0], u_weights=[1, 1], v_weights=[1] + let u = arr1(&[0.5, 0.5]); + let v = arr1(&[1.0, 0.0]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 0.5); + + // SciPy: u_values=[0, 1], v_values=[0], u_weights=[3, 1], v_weights=[1] + let u = arr1(&[0.75, 0.25]); + let v = arr1(&[1.0, 0.0]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 0.25); + + // SciPy: u_values=[0, 2], v_values=[0], u_weights=[1, 1], v_weights=[1] + let u = arr1(&[0.5, 0.0, 0.5]); + let v = arr1(&[1.0, 0.0, 0.0]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 1.0); + + // SciPy: u_values=[0, 1, 2], v_values=[1, 2, 3] + let u = arr1(&[1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, 0.0]); + let v = arr1(&[0.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 1.0); + } - let a = arr1(&[0.3, 0.2, 0.1, 0.15, 0.25]); - let b = arr1(&[0.1, 0.2, 0.1, 0.15, 0.45]); - let ab = dist.distance(a.view(), b.view()); - assert_abs_diff_eq!(ab, 0.8, epsilon = 1e-5); + #[test] + /// Any distribution moved to itself should have a Wasserstein distance of zero. + fn wasserstein_same_distribution() { + let dist = WassersteinDist; + + // SciPy: u_values=[1, 2, 3], v_values=[2, 1, 3] + let u = arr1(&[0.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]); + let v = arr1(&[0.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 0.0); + + // SciPy: u_values=[1, 1, 1, 4], v_values=[4, 1], u_weights=[1, 1, 1, 1], v_weights=[1, 3] + let u = arr1(&[0.0, 0.75, 0.0, 0.0, 0.25]); + let v = arr1(&[0.0, 0.75, 0.0, 0.0, 0.25]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 0.0); + } - let a = arr1(&[0.3, 0.2, 0.15, 0.10, 0.25]); - let b = arr1(&[0.1, 0.2, 0.05, 0.20, 0.45]); - let ab = dist.distance(a.view(), b.view()); - assert_abs_diff_eq!(ab, 0.9, epsilon = 1e-5); + #[test] + /// If the whole distribution is shifted by x, then the Wasserstein distance should be the norm of x. + fn wasserstein_shift() { + let dist = WassersteinDist; + + // SciPy: u_values=[0], v_values=[1] + let u = arr1(&[1.0, 0.0]); + let v = arr1(&[0.0, 1.0]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 1.0); + + // SciPy: u_values=[-5], v_values=[5] + let u = arr1(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); + let v = arr1(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 10.0); + + // SciPy: u_values=[1, 2, 3, 4, 5], v_values=[11, 12, 13, 14, 15] + let u = arr1(&[ + 0.2, 0.2, 0.2, 0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + ]); + let v = arr1(&[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.2, 0.2, 0.2, 0.2, + ]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 10.0); + } - let a = arr1(&[0.35, 0.15, 0.15, 0.10, 0.25]); - let b = arr1(&[0.1, 0.20, 0.05, 0.20, 0.45]); - let ab = dist.distance(a.view(), b.view()); - assert_abs_diff_eq!(ab, 0.95, epsilon = 1e-5); + #[test] + fn wasserstein_inf_values() { + let dist = WassersteinDist; - let a = arr2(&[[0.3, 0.2, 0.15, 0.10, 0.25], [0.35, 0.15, 0.15, 0.10, 0.25]]); - let b = arr2(&[[0.1, 0.2, 0.05, 0.20, 0.45], [0.1, 0.20, 0.05, 0.20, 0.45]]); - let ab = dist.distance(a.view(), b.view()); - assert_abs_diff_eq!(ab, 0.9 + 0.95, epsilon = 1e-5); + let u = arr1(&[1.0, f64::INFINITY]); + let v = arr1(&[1.0, 0.0]); + assert!(dist.distance(u.view(), v.view()).is_infinite()); - let a = arr1(&[f64::INFINITY, 6.6]); - let b = arr1(&[4.4, f64::NEG_INFINITY]); - assert!(dist.distance(a.view(), b.view()).is_infinite()); + let u = arr1(&[1.0, f64::INFINITY]); + let v = arr1(&[1.0, f64::INFINITY]); + assert!(dist.distance(u.view(), v.view()).is_nan()); } } From 4ee664343f4fe958481432db62e75768da72e33e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Nordstr=C3=B6m?= Date: Mon, 25 May 2026 13:35:10 +0300 Subject: [PATCH 3/3] Added SciPy version that was used for comparison to Wasserstein distance documentation --- algorithms/linfa-nn/src/distance.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithms/linfa-nn/src/distance.rs b/algorithms/linfa-nn/src/distance.rs index 95847e041..d941ab705 100644 --- a/algorithms/linfa-nn/src/distance.rs +++ b/algorithms/linfa-nn/src/distance.rs @@ -122,7 +122,7 @@ impl Distance for LpDist { /// [Earth Mover's](https://en.wikipedia.org/wiki/Earth_mover%27s_distance) distance. /// /// The function accepts histograms where each array element is the probability mass at that index. -/// This differs from SciPy's `wasserstein_distance` which instead accepts support values and weights, +/// This differs from SciPy's (v1.17.0) `wasserstein_distance` which instead accepts support values and weights, /// then builds the histograms internally. #[cfg_attr( feature = "serde", @@ -237,7 +237,7 @@ mod test { dist_test(WassersteinDist, 4.2); } - // The following Wasserstein tests are from SciPy. + // The following Wasserstein tests are from SciPy (v1.17.0). // However, since SciPy Wasserstein distance has different API as ours, // we need to first transform the SciPy parameters into histograms that our API accepts. //