Skip to content

Commit 538bca2

Browse files
author
martinspetlik
committed
correlated field
1 parent 0edaf14 commit 538bca2

1 file changed

Lines changed: 58 additions & 26 deletions

File tree

mlmc/random/correlated_field.py

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -401,68 +401,100 @@ def _sample(self):
401401

402402
class GSToolsSpatialCorrelatedField(RandomFieldBase):
403403
"""
404-
Spatially correlated random field using the GSTools library.
404+
Spatially correlated random field generator using GSTools.
405405
406-
Uses Fourier modes to generate spatial random fields efficiently.
406+
This class acts as an adapter between :mod:`gstools` and the MLMC
407+
random field interface (:class:`mlmc.random.random_field_base.RandomFieldBase`).
408+
It supports 1D, 2D, and 3D random fields with optional logarithmic transformation,
409+
and can generate fields on both structured and unstructured grids.
407410
"""
408411

409-
def __init__(self, model, mode_no=1000, log=False, sigma=1, seed=None):
412+
def __init__(self, model, mode_no=1000, log=False, sigma=1, seed=None, mode=None, structured=False):
410413
"""
411-
Initialize GSTools-based spatial random field.
414+
Initialize a spatially correlated random field generator.
412415
413-
:param model: gstools covariance model (subclass of gstools.CovModel)
414-
:param mode_no: Number of Fourier modes (default 1000)
415-
:param log: If True, output field is exponentiated
416-
:param sigma: Standard deviation
417-
:param seed: Optional random seed for reproducibility
416+
:param model: Covariance model instance (subclass of ``gstools.covmodel.CovModel``)
417+
defining the spatial correlation structure.
418+
:param mode_no: Number of Fourier modes used in the random field generation.
419+
Default is 1000.
420+
:param log: If True, applies an exponential transformation to obtain
421+
a lognormal field. Default is False.
422+
:param sigma: Standard deviation scaling factor applied to the generated field.
423+
Default is 1.
424+
:param seed: Random seed for reproducibility. Default is None.
425+
:param mode: Sampling mode for GSTools SRF. Use "fft" for structured grids or
426+
None for unstructured. Default is None.
427+
:param structured: If True, assumes a structured grid for field evaluation.
428+
Default is False.
418429
"""
419430
self.model = model
420431
self.mode_no = mode_no
421-
self.srf = gstools.SRF(model, mode_no=mode_no, seed=seed)
432+
if mode == "fft":
433+
self.srf = gstools.SRF(model, mode="fft", seed=seed)
434+
else:
435+
self.srf = gstools.SRF(model, mode_no=mode_no, seed=seed)
422436
self.mu = self.srf.mean
423437
self.sigma = sigma
424438
self.dim = model.dim
425439
self.log = log
440+
self.structured = structured
426441

427442
def change_srf(self, seed):
428443
"""
429-
Generate a new spatial random field with a different seed.
444+
Reinitialize the GSTools random field with a new random seed.
430445
431-
:param seed: Random seed
446+
:param seed: Random seed used to reinitialize the underlying
447+
:class:`gstools.SRF` instance.
448+
:return: None
432449
"""
433450
self.srf = gstools.SRF(self.model, seed=seed, mode_no=self.mode_no)
434451

435-
def random_field(self):
452+
def random_field(self, seed=None):
436453
"""
437-
Evaluate the spatial random field at the current points.
454+
Generate a raw random field realization (without scaling or transformation).
438455
439-
:return: Field values (np.ndarray)
456+
:param seed: Optional random seed for reproducibility. Default is None.
457+
:return: numpy.ndarray
458+
Field values evaluated at the points defined by :meth:`set_points`.
440459
"""
441460
if self.dim == 1:
442461
x = self.points
443-
x = x.reshape(len(x),)
444-
return self.srf((x,))
462+
x.reshape(len(x))
463+
field = self.srf((x,))
445464
elif self.dim == 2:
446465
x, y = self.points.T
447466
x = x.reshape(len(x), 1)
448467
y = y.reshape(len(y), 1)
449-
return self.srf((x, y))
450-
else: # dim == 3
468+
field = self.srf((x, y))
469+
else:
451470
x, y, z = self.points.T
452471
x = x.reshape(len(x), 1)
453472
y = y.reshape(len(y), 1)
454473
z = z.reshape(len(z), 1)
455-
return self.srf((x, y, z))
456474

457-
def sample(self):
475+
if self.structured:
476+
field = self.srf([np.squeeze(x), np.squeeze(y), np.squeeze(z)], seed=seed)
477+
field = field.flatten()
478+
else:
479+
if seed is not None:
480+
field = self.srf(self.points.T, seed=seed)
481+
else:
482+
field = self.srf(self.points.T)
483+
return field
484+
485+
def sample(self, seed=None):
458486
"""
459-
Generate a realization of the GSTools spatial random field.
487+
Evaluate the scaled random field at the defined points.
460488
461-
:return: Field values (np.ndarray)
489+
:param seed: Optional random seed for reproducibility. Default is None.
490+
:return: numpy.ndarray
491+
Field values evaluated at the defined points, scaled by ``sigma``
492+
and shifted by ``mu``. If ``log=True``, returns
493+
``exp(sigma * field + mu)`` instead.
462494
"""
463-
field = self.random_field()
464-
field = self.sigma * field + self.mu
465-
return np.exp(field) if self.log else field
495+
if not self.log:
496+
return self.sigma * self.random_field(seed) + self.mu
497+
return np.exp(self.sigma * self.random_field(seed) + self.mu)
466498

467499

468500
class FourierSpatialCorrelatedField(RandomFieldBase):

0 commit comments

Comments
 (0)