diff --git a/src/Furnace.Core/Util.fs b/src/Furnace.Core/Util.fs index 9de80d3a..cd265c74 100644 --- a/src/Furnace.Core/Util.fs +++ b/src/Furnace.Core/Util.fs @@ -34,9 +34,13 @@ type GlobalNestingLevel() = /// Contains operations relating to pseudo-random number generation. type Random() = static let mutable rnd = System.Random() + static let mutable cachedNormal = 0.0 + static let mutable hasCachedNormal = false /// Sets the random seed. - static member Seed(seed) = rnd <- System.Random(seed) + static member Seed(seed) = + rnd <- System.Random(seed) + hasCachedNormal <- false // Clear cache when seeding /// Samples a random value from the standard uniform distribution over the interval [0,1). static member Uniform() = rnd.NextDouble() @@ -46,13 +50,25 @@ type Random() = /// Samples a random value from the standard normal distribution with mean 0 and standard deviation 1. static member Normal() = - // Marsaglia polar method - // TODO: this is discarding one of the two samples that can be generated. For efficiency, we can keep the second sample around to return it in the next call. - let rec normal() = - let x, y = (rnd.NextDouble()) * 2.0 - 1.0, (rnd.NextDouble()) * 2.0 - 1.0 - let s = x * x + y * y - if s > 1.0 then normal() else x * sqrt (-2.0 * (log s) / s) - normal() + // Return cached sample if available + if hasCachedNormal then + hasCachedNormal <- false + cachedNormal + else + // Marsaglia polar method - generates two samples, cache the second one + let rec generatePair() = + let x, y = (rnd.NextDouble()) * 2.0 - 1.0, (rnd.NextDouble()) * 2.0 - 1.0 + let s = x * x + y * y + if s > 1.0 then + generatePair() + else + let multiplier = sqrt (-2.0 * (log s) / s) + let sample1 = x * multiplier + let sample2 = y * multiplier + cachedNormal <- sample2 + hasCachedNormal <- true + sample1 + generatePair() /// Samples a random value from the normal distribution with the given mean and standard deviation. static member Normal(mean, stddev) = mean + Random.Normal() * stddev