diff --git a/Cargo.toml b/Cargo.toml index adb990f..16e90d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ maintenance = { status = "passively-maintained" } [dependencies] rand = { version = "0.8.0", default-features = false } +num = "0.4" [dev-dependencies] rand = "0.8.0" diff --git a/src/lib.rs b/src/lib.rs index 8cca575..9a019a0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,8 +17,8 @@ //! use rand::distributions::Distribution; //! //! let mut rng = rand::thread_rng(); -//! let mut zipf = zipf::ZipfDistribution::new(1000, 1.03).unwrap(); -//! let sample = zipf.sample(&mut rng); +//! let mut zipf = zipf::ZipfDistribution::::new(1000, 1.03).unwrap(); +//! let sample: usize = zipf.sample(&mut rng); //! ``` //! //! This implementation is effectively a direct port of Apache Common's @@ -31,12 +31,13 @@ #![warn(rust_2018_idioms)] +use num::FromPrimitive; use rand::Rng; /// Random number generator that generates Zipf-distributed random numbers using rejection /// inversion. #[derive(Clone, Copy)] -pub struct ZipfDistribution { +pub struct ZipfDistribution { /// Number of elements num_elements: f64, /// Exponent parameter of the distribution @@ -47,9 +48,10 @@ pub struct ZipfDistribution { h_integral_num_elements: f64, /// `2 - hIntegralInverse(hIntegral(2.5) - h(2)}` s: f64, + marker: PhantomData, } -impl ZipfDistribution { +impl ZipfDistribution { /// Creates a new [Zipf-distributed](https://en.wikipedia.org/wiki/Zipf's_law) /// random number generator. /// @@ -65,17 +67,14 @@ impl ZipfDistribution { let z = ZipfDistribution { num_elements: num_elements as f64, exponent, - h_integral_x1: ZipfDistribution::h_integral(1.5, exponent) - 1f64, - h_integral_num_elements: ZipfDistribution::h_integral( - num_elements as f64 + 0.5, - exponent, - ), + h_integral_x1: Self::h_integral(1.5, exponent) - 1f64, + h_integral_num_elements: Self::h_integral(num_elements as f64 + 0.5, exponent), s: 2f64 - - ZipfDistribution::h_integral_inv( - ZipfDistribution::h_integral(2.5, exponent) - - ZipfDistribution::h(2f64, exponent), + - Self::h_integral_inv( + Self::h_integral(2.5, exponent) - Self::h(2f64, exponent), exponent, ), + marker: PhantomData, }; // populate cache @@ -84,8 +83,8 @@ impl ZipfDistribution { } } -impl ZipfDistribution { - fn next(&self, rng: &mut R) -> usize { +impl ZipfDistribution { + fn next(&self, rng: &mut R) -> T { // The paper describes an algorithm for exponents larger than 1 (Algorithm ZRI). // // The original method uses @@ -108,17 +107,16 @@ impl ZipfDistribution { let hnum = self.h_integral_num_elements; loop { - use std::cmp; let u: f64 = hnum + rng.gen::() * (self.h_integral_x1 - hnum); // u is uniformly distributed in (h_integral_x1, h_integral_num_elements] - let x: f64 = ZipfDistribution::h_integral_inv(u, self.exponent); + let x: f64 = Self::h_integral_inv(u, self.exponent); // Limit k to the range [1, num_elements] if it would be outside // due to numerical inaccuracies. let k64 = x.max(1.0).min(self.num_elements); // float -> integer rounds towards zero - let k = cmp::max(1, k64 as usize); + let k = T::from_f64(k64.max(1.0)).unwrap(); // Here, the distribution of k is given by: // @@ -127,8 +125,7 @@ impl ZipfDistribution { // // where C = 1 / (h_integral_num_elements - h_integral_x1) if k64 - x <= self.s - || u >= ZipfDistribution::h_integral(k64 + 0.5, self.exponent) - - ZipfDistribution::h(k64, self.exponent) + || u >= Self::h_integral(k64 + 0.5, self.exponent) - Self::h(k64, self.exponent) { // Case k = 1: // @@ -173,14 +170,14 @@ impl ZipfDistribution { } } -impl rand::distributions::Distribution for ZipfDistribution { - fn sample(&self, rng: &mut R) -> usize { +impl rand::distributions::Distribution for ZipfDistribution { + fn sample(&self, rng: &mut R) -> T { self.next(rng) } } -use std::fmt; -impl fmt::Debug for ZipfDistribution { +use std::{fmt, marker::PhantomData}; +impl fmt::Debug for ZipfDistribution { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { f.debug_struct("ZipfDistribution") .field("e", &self.exponent) @@ -189,7 +186,7 @@ impl fmt::Debug for ZipfDistribution { } } -impl ZipfDistribution { +impl ZipfDistribution { /// Computes `H(x)`, defined as /// /// - `(x^(1 - exponent) - 1) / (1 - exponent)`, if `exponent != 1` @@ -262,7 +259,7 @@ mod test { // sample a bunch let mut buckets = vec![0; N]; for _ in 0..samples { - let sample = zipf.sample(&mut rng); + let sample: usize = zipf.sample(&mut rng); buckets[sample - 1] += 1; } @@ -323,13 +320,13 @@ mod test { #[test] fn debug() { - eprintln!("{:?}", ZipfDistribution::new(100, 1.0).unwrap()); + eprintln!("{:?}", ZipfDistribution::::new(100, 1.0).unwrap()); } #[test] fn errs() { - ZipfDistribution::new(0, 1.0).unwrap_err(); - ZipfDistribution::new(100, 0.0).unwrap_err(); - ZipfDistribution::new(100, -1.0).unwrap_err(); + ZipfDistribution::::new(0, 1.0).unwrap_err(); + ZipfDistribution::::new(100, 0.0).unwrap_err(); + ZipfDistribution::::new(100, -1.0).unwrap_err(); } }