diff --git a/algorithms/linfa-nn/src/distance.rs b/algorithms/linfa-nn/src/distance.rs index de4af9531..d941ab705 100644 --- a/algorithms/linfa-nn/src/distance.rs +++ b/algorithms/linfa-nn/src/distance.rs @@ -118,6 +118,32 @@ impl Distance for LpDist { } } +/// [Wasserstein](https://en.wikipedia.org/wiki/Wasserstein_metric) or +/// [Earth Mover's](https://en.wikipedia.org/wiki/Earth_mover%27s_distance) distance. +/// +/// The function accepts histograms where each array element is the probability mass at that index. +/// This differs from SciPy's (v1.17.0) `wasserstein_distance` which instead accepts support values and weights, +/// then builds the histograms internally. +#[cfg_attr( + feature = "serde", + derive(Serialize, Deserialize), + serde(crate = "serde_crate") +)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct WassersteinDist; +impl Distance for WassersteinDist { + #[inline] + fn distance(&self, a: ArrayView, b: ArrayView) -> F { + let mut cumulative_diff = F::zero(); + let mut dist = F::zero(); + Zip::from(&a).and(&b).for_each(|&a, &b| { + cumulative_diff += a - b; + dist += cumulative_diff.abs() + }); + dist + } +} + /// Computes a similarity matrix with gaussian kernel and scaling parameter `eps` /// /// The generated matrix is a upper triangular matrix with dimension NxN (number of observations) and contains the similarity between all permutations of observations @@ -157,6 +183,7 @@ mod test { has_autotraits::(); has_autotraits::(); has_autotraits::>(); + has_autotraits::(); } fn dist_test>(dist: D, result: f64) { @@ -204,4 +231,98 @@ mod test { fn lp_dist() { dist_test(LpDist(3.3), 4.635); } + + #[test] + fn wasserstein_dist() { + dist_test(WassersteinDist, 4.2); + } + + // The following Wasserstein tests are from SciPy (v1.17.0). + // However, since SciPy Wasserstein distance has different API as ours, + // we need to first transform the SciPy parameters into histograms that our API accepts. + // + // For example, SciPy values `[1, 3]` with weights `[1, 9]`: + // At index 1 we have weight 1 out of total weight 10 => 0.1. + // At index 3 we have weight 9 out of total weight 10 => 0.9. + // Thus we get a histogram [0.0, 0.1, 0.0, 0.9] + + #[test] + /// For basic distributions, the value of the Wasserstein distance is straightforward. + fn wasserstein_simple() { + let dist = WassersteinDist; + + // SciPy: u_values=[0, 1], v_values=[0], u_weights=[1, 1], v_weights=[1] + let u = arr1(&[0.5, 0.5]); + let v = arr1(&[1.0, 0.0]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 0.5); + + // SciPy: u_values=[0, 1], v_values=[0], u_weights=[3, 1], v_weights=[1] + let u = arr1(&[0.75, 0.25]); + let v = arr1(&[1.0, 0.0]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 0.25); + + // SciPy: u_values=[0, 2], v_values=[0], u_weights=[1, 1], v_weights=[1] + let u = arr1(&[0.5, 0.0, 0.5]); + let v = arr1(&[1.0, 0.0, 0.0]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 1.0); + + // SciPy: u_values=[0, 1, 2], v_values=[1, 2, 3] + let u = arr1(&[1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, 0.0]); + let v = arr1(&[0.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 1.0); + } + + #[test] + /// Any distribution moved to itself should have a Wasserstein distance of zero. + fn wasserstein_same_distribution() { + let dist = WassersteinDist; + + // SciPy: u_values=[1, 2, 3], v_values=[2, 1, 3] + let u = arr1(&[0.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]); + let v = arr1(&[0.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 0.0); + + // SciPy: u_values=[1, 1, 1, 4], v_values=[4, 1], u_weights=[1, 1, 1, 1], v_weights=[1, 3] + let u = arr1(&[0.0, 0.75, 0.0, 0.0, 0.25]); + let v = arr1(&[0.0, 0.75, 0.0, 0.0, 0.25]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 0.0); + } + + #[test] + /// If the whole distribution is shifted by x, then the Wasserstein distance should be the norm of x. + fn wasserstein_shift() { + let dist = WassersteinDist; + + // SciPy: u_values=[0], v_values=[1] + let u = arr1(&[1.0, 0.0]); + let v = arr1(&[0.0, 1.0]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 1.0); + + // SciPy: u_values=[-5], v_values=[5] + let u = arr1(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); + let v = arr1(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 10.0); + + // SciPy: u_values=[1, 2, 3, 4, 5], v_values=[11, 12, 13, 14, 15] + let u = arr1(&[ + 0.2, 0.2, 0.2, 0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + ]); + let v = arr1(&[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.2, 0.2, 0.2, 0.2, + ]); + assert_abs_diff_eq!(dist.distance(u.view(), v.view()), 10.0); + } + + #[test] + fn wasserstein_inf_values() { + let dist = WassersteinDist; + + let u = arr1(&[1.0, f64::INFINITY]); + let v = arr1(&[1.0, 0.0]); + assert!(dist.distance(u.view(), v.view()).is_infinite()); + + let u = arr1(&[1.0, f64::INFINITY]); + let v = arr1(&[1.0, f64::INFINITY]); + assert!(dist.distance(u.view(), v.view()).is_nan()); + } }