Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions src/Furnace.Core/Util.fs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,13 @@ type GlobalNestingLevel() =
/// Contains operations relating to pseudo-random number generation.
type Random() =
static let mutable rnd = System.Random()
static let mutable cachedNormal = 0.0
static let mutable hasCachedNormal = false

/// Sets the random seed.
static member Seed(seed) = rnd <- System.Random(seed)
static member Seed(seed) =
rnd <- System.Random(seed)
hasCachedNormal <- false // Clear cache when seeding

/// Samples a random value from the standard uniform distribution over the interval [0,1).
static member Uniform() = rnd.NextDouble()
Expand All @@ -46,13 +50,25 @@ type Random() =

/// Samples a random value from the standard normal distribution with mean 0 and standard deviation 1.
static member Normal() =
// Marsaglia polar method
// TODO: this is discarding one of the two samples that can be generated. For efficiency, we can keep the second sample around to return it in the next call.
let rec normal() =
let x, y = (rnd.NextDouble()) * 2.0 - 1.0, (rnd.NextDouble()) * 2.0 - 1.0
let s = x * x + y * y
if s > 1.0 then normal() else x * sqrt (-2.0 * (log s) / s)
normal()
// Return cached sample if available
if hasCachedNormal then
hasCachedNormal <- false
cachedNormal
else
// Marsaglia polar method - generates two samples, cache the second one
let rec generatePair() =
let x, y = (rnd.NextDouble()) * 2.0 - 1.0, (rnd.NextDouble()) * 2.0 - 1.0
let s = x * x + y * y
if s > 1.0 then
generatePair()
else
let multiplier = sqrt (-2.0 * (log s) / s)
let sample1 = x * multiplier
let sample2 = y * multiplier
cachedNormal <- sample2
hasCachedNormal <- true
sample1
generatePair()

/// Samples a random value from the normal distribution with the given mean and standard deviation.
static member Normal(mean, stddev) = mean + Random.Normal() * stddev
Expand Down
Loading