Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions algorithms/linfa-nn/src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,32 @@ impl<F: Float> Distance<F> for LpDist<F> {
}
}

/// One-dimensional [Wasserstein](https://en.wikipedia.org/wiki/Wasserstein_metric)
/// ([Earth Mover's](https://en.wikipedia.org/wiki/Earth_mover%27s_distance)) distance between two histograms.
///
/// Both inputs are read as histograms: the element at position `i` is the probability mass
/// located at bin `i`. The result is the sum, over all bins, of the absolute difference between
/// the two running (cumulative) sums, i.e. the mass-times-bins cost of turning one histogram
/// into the other when neighbouring bins are one unit apart.
///
/// The inputs are used as given: they are neither normalised nor validated, so the value is
/// only a meaningful transport cost when both histograms carry the same total mass.
///
/// This differs from SciPy's (v1.17.0) `wasserstein_distance`, which takes support values and
/// weights and builds the histograms internally.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WassersteinDist;
impl<F: Float> Distance<F> for WassersteinDist {
    /// Walks `a` and `b` in lockstep, tracking the gap between their cumulative sums and
    /// adding up the absolute value of that gap after every element.
    ///
    /// NOTE(review): elements are visited in `Zip`'s traversal order; this is written with
    /// 1-D histograms in mind — confirm before relying on it for higher-dimensional views.
    #[inline]
    fn distance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F {
        // Difference between the cumulative sums of `a` and `b` up to the current bin.
        let mut cdf_gap = F::zero();
        // Accumulated absolute gap, which is the value returned.
        let mut total = F::zero();
        Zip::from(&a).and(&b).for_each(|&mass_a, &mass_b| {
            cdf_gap += mass_a - mass_b;
            total += cdf_gap.abs();
        });
        total
    }
}

/// Computes a similarity matrix with gaussian kernel and scaling parameter `eps`
///
/// The generated matrix is an upper triangular matrix with dimension NxN (number of observations) and contains the similarity between all permutations of observations
Expand Down Expand Up @@ -157,6 +183,7 @@ mod test {
has_autotraits::<L2Dist>();
has_autotraits::<LInfDist>();
has_autotraits::<LpDist<f64>>();
has_autotraits::<WassersteinDist>();
}

fn dist_test<D: Distance<f64>>(dist: D, result: f64) {
Expand Down Expand Up @@ -204,4 +231,98 @@ mod test {
/// Runs the shared `dist_test` check against `LpDist(3.3)`, expecting `4.635`.
fn lp_dist() {
    let metric = LpDist(3.3);
    let expected = 4.635;
    dist_test(metric, expected);
}

#[test]
/// Runs the shared `dist_test` check against `WassersteinDist`, expecting `4.2`.
fn wasserstein_dist() {
    let metric = WassersteinDist;
    let expected = 4.2;
    dist_test(metric, expected);
}

// The following Wasserstein tests are adapted from SciPy (v1.17.0).
// However, since SciPy's Wasserstein distance has a different API from ours,
// we first need to transform the SciPy parameters into histograms that our API accepts.
//
// For example, SciPy values `[1, 3]` with weights `[1, 9]`:
// At index 1 we have weight 1 out of total weight 10 => 0.1.
// At index 3 we have weight 9 out of total weight 10 => 0.9.
// Thus we get a histogram [0.0, 0.1, 0.0, 0.9]

#[test]
/// Small histograms whose Wasserstein distance can be worked out by hand.
fn wasserstein_simple() {
    let metric = WassersteinDist;

    // Each entry is `(u histogram, v histogram, expected distance)`.
    let cases: [(&[f64], &[f64], f64); 4] = [
        // SciPy: u_values=[0, 1], v_values=[0], u_weights=[1, 1], v_weights=[1]
        (&[0.5, 0.5], &[1.0, 0.0], 0.5),
        // SciPy: u_values=[0, 1], v_values=[0], u_weights=[3, 1], v_weights=[1]
        (&[0.75, 0.25], &[1.0, 0.0], 0.25),
        // SciPy: u_values=[0, 2], v_values=[0], u_weights=[1, 1], v_weights=[1]
        (&[0.5, 0.0, 0.5], &[1.0, 0.0, 0.0], 1.0),
        // SciPy: u_values=[0, 1, 2], v_values=[1, 2, 3]
        (
            &[1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, 0.0],
            &[0.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0],
            1.0,
        ),
    ];

    for &(u, v, expected) in cases.iter() {
        let (u, v) = (arr1(u), arr1(v));
        assert_abs_diff_eq!(metric.distance(u.view(), v.view()), expected);
    }
}

#[test]
/// Two identical histograms must be at Wasserstein distance zero.
fn wasserstein_same_distribution() {
    let metric = WassersteinDist;

    // Each entry is the histogram shared by both sides of the comparison.
    let histograms: [&[f64]; 2] = [
        // SciPy: u_values=[1, 2, 3], v_values=[2, 1, 3]
        &[0.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0],
        // SciPy: u_values=[1, 1, 1, 4], v_values=[4, 1], u_weights=[1, 1, 1, 1], v_weights=[1, 3]
        &[0.0, 0.75, 0.0, 0.0, 0.25],
    ];

    for &hist in histograms.iter() {
        // Build two separate arrays so the comparison is between distinct, equal inputs.
        let u = arr1(hist);
        let v = arr1(hist);
        assert_abs_diff_eq!(metric.distance(u.view(), v.view()), 0.0);
    }
}

#[test]
/// Shifting a whole histogram by some number of bins must give a distance equal to that shift.
fn wasserstein_shift() {
    let metric = WassersteinDist;

    // Each entry is `(u histogram, v histogram, expected distance)`.
    let cases: [(&[f64], &[f64], f64); 3] = [
        // SciPy: u_values=[0], v_values=[1]
        (&[1.0, 0.0], &[0.0, 1.0], 1.0),
        // SciPy: u_values=[-5], v_values=[5]
        (
            &[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            &[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
            10.0,
        ),
        // SciPy: u_values=[1, 2, 3, 4, 5], v_values=[11, 12, 13, 14, 15]
        (
            &[
                0.2, 0.2, 0.2, 0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            ],
            &[
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.2, 0.2, 0.2, 0.2,
            ],
            10.0,
        ),
    ];

    for &(u, v, expected) in cases.iter() {
        let (u, v) = (arr1(u), arr1(v));
        assert_abs_diff_eq!(metric.distance(u.view(), v.view()), expected);
    }
}

#[test]
/// Infinite entries propagate: one-sided infinity gives an infinite distance, while infinity
/// on both sides gives NaN.
fn wasserstein_inf_values() {
    let metric = WassersteinDist;

    let with_inf = arr1(&[1.0, f64::INFINITY]);
    let finite = arr1(&[1.0, 0.0]);

    // Only `u` holds an infinite mass, so the cumulative gap becomes infinite.
    let one_sided = metric.distance(with_inf.view(), finite.view());
    assert!(one_sided.is_infinite());

    // Both sides hold an infinite mass in the same bin: `inf - inf` is NaN.
    let both_sides = metric.distance(with_inf.view(), with_inf.view());
    assert!(both_sides.is_nan());
}
}
Loading