Add Wasserstein distance (aka Earth Mover's distance) to linfa-nn#440
Conversation
| assert_abs_diff_eq!(ab, 0.9, epsilon = 1e-5); | ||
|
|
||
| let a = arr1(&[0.35, 0.15, 0.15, 0.10, 0.25]); | ||
| let b = arr1(&[0.1, 0.20, 0.05, 0.20, 0.45]); |
There was a problem hiding this comment.
Slight variations in input data, maybe add a comment to explain what we should see as a result. You can also test against scipy examples
| let a = arr2(&[[0.3, 0.2, 0.15, 0.10, 0.25], [0.35, 0.15, 0.15, 0.10, 0.25]]); | ||
| let b = arr2(&[[0.1, 0.2, 0.05, 0.20, 0.45], [0.1, 0.20, 0.05, 0.20, 0.45]]); | ||
| let ab = dist.distance(a.view(), b.view()); | ||
| assert_abs_diff_eq!(ab, 0.9 + 0.95, epsilon = 1e-5); |
There was a problem hiding this comment.
You could use previous a, b, ab variables
| serde(crate = "serde_crate") | ||
| )] | ||
| #[derive(Debug, Clone, PartialEq, Eq)] | ||
| pub struct EarthMoverDist; |
There was a problem hiding this comment.
I would go with WassersteinDist as you put it first above and it will be consistent with scipy impl.
There was a problem hiding this comment.
I updated it to WassersteinDist and changed the tests to be according to SciPy where relevant since the APIs are a bit different
There was a problem hiding this comment.
Ok I was not aware of the difference of API. Maybe specify once the version of scipy you checked against.
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #440 +/- ##
==========================================
+ Coverage 77.54% 77.56% +0.02%
==========================================
Files 106 106
Lines 7578 7585 +7
==========================================
+ Hits 5876 5883 +7
Misses 1702 1702 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
relf
left a comment
There was a problem hiding this comment.
Could you rebase your changes on current master?
Done, and I added the SciPy version I compared against to comments |
Add Earth Mover's Distance, aka. Wasserstein to linfa-nn