From 1a9fadb653f47c7aa54fcdb2d69ef17d8e730a65 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 15 Aug 2023 17:41:19 -0400 Subject: [PATCH 01/14] DEV: make sure all priors return float when needed --- bilby/core/prior/analytical.py | 95 ++++++++++++++------------------ bilby/core/prior/base.py | 7 ++- bilby/core/prior/dict.py | 45 +++++++++------ bilby/core/prior/interpolated.py | 8 +-- bilby/core/utils/calculus.py | 25 ++++++++- bilby/gw/prior.py | 27 ++------- test/core/prior/prior_test.py | 9 +++ 7 files changed, 113 insertions(+), 103 deletions(-) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index f9f57ec90..804cecc4d 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -169,8 +169,10 @@ def cdf(self, val): _cdf = (np.log(val / self.minimum) / np.log(self.maximum / self.minimum)) else: - _cdf = np.atleast_1d(val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) / \ - (self.maximum ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) + _cdf = ( + (val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) + / (self.maximum ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) + ) _cdf = np.minimum(_cdf, 1) _cdf = np.maximum(_cdf, 0) return _cdf @@ -367,16 +369,16 @@ def ln_prob(self, val): return np.nan_to_num(- np.log(2 * np.abs(val)) - np.log(np.log(self.maximum / self.minimum))) def cdf(self, val): - val = np.atleast_1d(val) norm = 0.5 / np.log(self.maximum / self.minimum) - cdf = np.zeros((len(val))) - lower_indices = np.where(np.logical_and(-self.maximum <= val, val <= -self.minimum))[0] - upper_indices = np.where(np.logical_and(self.minimum <= val, val <= self.maximum))[0] - cdf[lower_indices] = -norm * np.log(-val[lower_indices] / self.maximum) - cdf[np.where(np.logical_and(-self.minimum < val, val < self.minimum))] = 0.5 - cdf[upper_indices] = 0.5 + norm * np.log(val[upper_indices] / self.minimum) - cdf[np.where(self.maximum < val)] = 1 - return cdf + _cdf = ( + -norm * np.log(abs(val) / self.maximum) + * (val <= -self.minimum) * (val >= -self.maximum) + + (0.5 + norm * np.log(abs(val) / self.minimum)) + * (val >= self.minimum) * (val <= self.maximum) + + 0.5 * (val >= -self.minimum) * (val <= self.minimum) + + 1 * (val > self.maximum) + ) + return _cdf class Cosine(Prior): @@ -426,10 +428,12 @@ def prob(self, val): return np.cos(val) / 2 * self.is_in_prior_range(val) def cdf(self, val): - _cdf = np.atleast_1d((np.sin(val) - np.sin(self.minimum)) / - (np.sin(self.maximum) - np.sin(self.minimum))) - _cdf[val > self.maximum] = 1 - _cdf[val < self.minimum] = 0 + _cdf = ( + (np.sin(val) - np.sin(self.minimum)) + / (np.sin(self.maximum) - np.sin(self.minimum)) + * (val >= self.minimum) * (val <= self.maximum) + + 1 * (val > self.maximum) + ) return _cdf @@ -480,10 +484,12 @@ def prob(self, val): return np.sin(val) / 2 * self.is_in_prior_range(val) def cdf(self, val): - _cdf = np.atleast_1d((np.cos(val) - np.cos(self.minimum)) / - (np.cos(self.maximum) - np.cos(self.minimum))) - _cdf[val > self.maximum] = 1 - _cdf[val < self.minimum] = 0 + _cdf = ( + (np.cos(val) - np.cos(self.minimum)) + / (np.cos(self.maximum) - np.cos(self.minimum)) + * (val >= self.minimum) * (val <= self.maximum) + + 1 * (val > self.maximum) + ) return _cdf @@ -625,11 +631,13 @@ def prob(self, val): / self.sigma / self.normalisation * self.is_in_prior_range(val) def cdf(self, val): - val = np.atleast_1d(val) - _cdf = (erf((val - self.mu) / 2 ** 0.5 / self.sigma) - erf( - (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2 / self.normalisation - _cdf[val > self.maximum] = 1 - _cdf[val < self.minimum] = 0 + _cdf = ( + ( + erf((val - self.mu) / 2 ** 0.5 / self.sigma) + - erf((self.minimum - self.mu) / 2 ** 0.5 / self.sigma) + ) / 2 / self.normalisation * (val >= self.minimum) * (val <= self.maximum) + + 1 * (val > self.maximum) + ) return _cdf @@ -1367,6 +1375,8 @@ def __init__(self, sigma, mu=None, r=None, name=None, latex_label=None, raise ValueError("For the Fermi-Dirac prior the values of sigma and r " "must be positive.") + self.expr = np.exp(self.r) + def rescale(self, val): """ 'Rescale' a sample from the unit line element to the appropriate Fermi-Dirac prior. @@ -1384,21 +1394,8 @@ def rescale(self, val): .. [1] M. Pitkin, M. Isi, J. Veitch & G. Woan, `arXiv:1705.08978v1 `_, 2017. """ - inv = (-np.exp(-1. * self.r) + (1. + np.exp(self.r)) ** -val + - np.exp(-1. * self.r) * (1. + np.exp(self.r)) ** -val) - - # if val is 1 this will cause inv to be negative (due to numerical - # issues), so return np.inf - if isinstance(val, (float, int)): - if inv < 0: - return np.inf - else: - return -self.sigma * np.log(inv) - else: - idx = inv >= 0. - tmpinv = np.inf * np.ones(len(np.atleast_1d(val))) - tmpinv[idx] = -self.sigma * np.log(inv[idx]) - return tmpinv + inv = -1 / self.expr + (1 + self.expr)**-val + (1 + self.expr)**-val / self.expr + return -self.sigma * np.log(np.maximum(inv, 0)) def prob(self, val): """Return the prior probability of val. @@ -1411,7 +1408,11 @@ def prob(self, val): ======= float: Prior probability of val """ - return np.exp(self.ln_prob(val)) + return ( + (np.exp((val - self.mu) / self.sigma) + 1)**-1 + / (self.sigma * np.log1p(self.expr)) + * (val >= self.minimum) + ) def ln_prob(self, val): """Return the log prior probability of val. @@ -1424,19 +1425,7 @@ def ln_prob(self, val): ======= Union[float, array_like]: Log prior probability of val """ - - norm = -np.log(self.sigma * np.log(1. + np.exp(self.r))) - if isinstance(val, (float, int)): - if val < self.minimum: - return -np.inf - else: - return norm - np.logaddexp((val / self.sigma) - self.r, 0.) - else: - val = np.atleast_1d(val) - lnp = -np.inf * np.ones(len(val)) - idx = val >= self.minimum - lnp[idx] = norm - np.logaddexp((val[idx] / self.sigma) - self.r, 0.) - return lnp + return np.log(self.prob(val)) class DiscreteValues(Prior): diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py index fb3f00773..5ca28de28 100644 --- a/bilby/core/prior/base.py +++ b/bilby/core/prior/base.py @@ -5,7 +5,6 @@ import numpy as np import scipy.stats -from scipy.interpolate import interp1d from ..utils import ( infer_args_from_method, @@ -13,6 +12,7 @@ decode_bilby_json, logger, get_dict_with_properties, + WrappedInterp1d as interp1d, ) @@ -178,7 +178,10 @@ def cdf(self, val): cdf = cumulative_trapezoid(pdf, x, initial=0) interp = interp1d(x, cdf, assume_sorted=True, bounds_error=False, fill_value=(0, 1)) - return interp(val) + output = interp(val) + if isinstance(val, (int, float)): + output = float(output) + return output def ln_prob(self, val): """Return the prior ln probability of val, this should be overwritten diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 1b3bac5c8..32c3bc780 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -504,7 +504,9 @@ def check_efficiency(n_tested, n_valid): def normalize_constraint_factor( self, keys, min_accept=10000, sampling_chunk=50000, nrepeats=10 ): - if keys in self._cached_normalizations.keys(): + if len(self.constraint_keys) == 0: + return 1 + elif keys in self._cached_normalizations.keys(): return self._cached_normalizations[keys] else: factor_estimates = [ @@ -566,8 +568,10 @@ def check_prob(self, sample, prob): return 0.0 else: constrained_prob = np.zeros_like(prob) - keep = np.array(self.evaluate_constraints(sample), dtype=bool) - constrained_prob[keep] = prob[keep] * ratio + in_bounds = np.isfinite(prob) + subsample = {key: sample[key][in_bounds] for key in sample} + keep = np.array(self.evaluate_constraints(subsample), dtype=bool) + constrained_prob[in_bounds] = prob[in_bounds] * keep * ratio return constrained_prob def ln_prob(self, sample, axis=None, normalized=True): @@ -608,8 +612,10 @@ def check_ln_prob(self, sample, ln_prob, normalized=True): return -np.inf else: constrained_ln_prob = -np.inf * np.ones_like(ln_prob) - keep = np.array(self.evaluate_constraints(sample), dtype=bool) - constrained_ln_prob[keep] = ln_prob[keep] + np.log(ratio) + in_bounds = np.isfinite(ln_prob) + subsample = {key: sample[key][in_bounds] for key in sample} + keep = np.log(np.array(self.evaluate_constraints(subsample), dtype=bool)) + constrained_ln_prob[in_bounds] = ln_prob[in_bounds] + keep + np.log(ratio) return constrained_ln_prob def cdf(self, sample): @@ -643,12 +649,9 @@ def rescale(self, keys, theta): ======= list: List of floats containing the rescaled sample """ - theta = list(theta) - samples = [] - for key, units in zip(keys, theta): - samps = self[key].rescale(units) - samples += list(np.asarray(samps).flatten()) - return samples + return list( + [self[key].rescale(sample) for key, sample in zip(keys, theta)] + ) def test_redundancy(self, key, disable_logging=False): """Empty redundancy test, should be overwritten in subclasses""" @@ -670,9 +673,7 @@ def test_has_redundant_keys(self): del temp[key] if temp.test_redundancy(key, disable_logging=True): logger.warning( - "{} is a redundant key in this {}.".format( - key, self.__class__.__name__ - ) + f"{key} is a redundant key in this {self.__class__.__name__}." ) redundant = True return redundant @@ -880,6 +881,7 @@ def rescale(self, keys, theta): self._check_resolved() self._update_rescale_keys(keys) result = dict() + joint = dict() for key, index in zip( self.sorted_keys_without_fixed_parameters, self._rescale_indexes ): @@ -887,10 +889,17 @@ def rescale(self, keys, theta): theta[index], **self.get_required_variables(key) ) self[key].least_recently_sampled = result[key] - samples = [] - for key in keys: - samples += list(np.asarray(result[key]).flatten()) - return samples + if isinstance(self[key], JointPrior) and self[key].dist.distname not in joint: + joint[self[key].dist.distname] = [key] + elif isinstance(self[key], JointPrior): + joint[self[key].dist.distname].append(key) + for names in joint.values(): + values = list() + for key in names: + values = np.concatenate([values, result[key]]) + for key, value in zip(names, values): + result[key] = value + return list([result[key] for key in keys]) def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: diff --git a/bilby/core/prior/interpolated.py b/bilby/core/prior/interpolated.py index 2cee669d9..6a7b383a5 100644 --- a/bilby/core/prior/interpolated.py +++ b/bilby/core/prior/interpolated.py @@ -1,8 +1,7 @@ import numpy as np -from scipy.interpolate import interp1d from .base import Prior -from ..utils import logger +from ..utils import logger, WrappedInterp1d as interp1d class Interped(Prior): @@ -86,10 +85,7 @@ def rescale(self, val): This maps to the inverse CDF. This is done using interpolation. """ - rescaled = self.inverse_cumulative_distribution(val) - if rescaled.shape == (): - rescaled = float(rescaled) - return rescaled + return self.inverse_cumulative_distribution(val) @property def minimum(self): diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index ac6fcefcd..e10ce6111 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -1,7 +1,7 @@ import math import numpy as np -from scipy.interpolate import RectBivariateSpline +from scipy.interpolate import interp1d, RectBivariateSpline from scipy.special import logsumexp from .log import logger @@ -219,6 +219,29 @@ def __call__(self, x, y, dx=0, dy=0, grid=False): return result +class WrappedInterp1d(interp1d): + """ + A wrapper around scipy interp1d which sets equality-by-instantiation and + makes sure that the output is a float if the input is a float or int. + """ + def __call__(self, x): + output = super().__call__(x) + if isinstance(x, (float, int)): + output = output.item() + return output + + def __eq__(self, other): + for key in self.__dict__: + if type(self.__dict__[key]) is np.ndarray: + if not np.array_equal(self.__dict__[key], other.__dict__[key]): + return False + elif key == "_spline": + pass + elif getattr(self, key) != getattr(other, key): + return False + return True + + def round_up_to_power_of_two(x): """Round up to the next power of two diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 6885af67b..08d7178e0 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -3,7 +3,7 @@ import numpy as np from scipy.integrate import quad -from scipy.interpolate import InterpolatedUnivariateSpline, interp1d +from scipy.interpolate import InterpolatedUnivariateSpline from scipy.special import hyp2f1 from scipy.stats import norm @@ -13,7 +13,7 @@ ConditionalPriorDict, ConditionalBasePrior, BaseJointPriorDist, JointPrior, JointPriorDistError, ) -from ..core.utils import infer_args_from_method, logger, random +from ..core.utils import infer_args_from_method, logger, random, WrappedInterp1d as interp1d from .conversion import ( convert_to_lal_binary_black_hole_parameters, convert_to_lal_binary_neutron_star_parameters, generate_mass_parameters, @@ -379,21 +379,6 @@ def __init__(self, minimum, maximum, name='chirp_mass', name=name, latex_label=latex_label, unit=unit, boundary=boundary) -class WrappedInterp1d(interp1d): - """ A wrapper around scipy interp1d which sets equality-by-instantiation """ - def __eq__(self, other): - - for key in self.__dict__: - if type(self.__dict__[key]) is np.ndarray: - if not np.array_equal(self.__dict__[key], other.__dict__[key]): - return False - elif key == "_spline": - pass - elif getattr(self, key) != getattr(other, key): - return False - return True - - class UniformInComponentsMassRatio(Prior): r""" Prior distribution for chirp mass which is uniform in component masses. @@ -437,7 +422,7 @@ def __init__(self, minimum, maximum, name='mass_ratio', latex_label='$q$', latex_label=latex_label, unit=unit, boundary=boundary) self.norm = self._integral(maximum) - self._integral(minimum) qs = np.linspace(minimum, maximum, 1000) - self.icdf = WrappedInterp1d( + self.icdf = interp1d( self.cdf(qs), qs, kind='cubic', bounds_error=False, fill_value=(minimum, maximum)) @@ -451,11 +436,7 @@ def cdf(self, val): def rescale(self, val): if self.equal_mass: val = 2 * np.minimum(val, 1 - val) - resc = self.icdf(val) - if resc.ndim == 0: - return resc.item() - else: - return resc + return self.icdf(val) def prob(self, val): in_prior = (val >= self.minimum) & (val <= self.maximum) diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index 37c6d93f3..95edb8ab2 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -380,6 +380,15 @@ def test_cdf_zero_below_domain(self): ) self.assertTrue(all(np.nan_to_num(prior.cdf(outside_domain)) == 0)) + def test_cdf_float_with_float_input(self): + for prior in self.priors: + if ( + bilby.core.prior.JointPrior in prior.__class__.__mro__ + and prior.maximum == np.inf + ): + continue + self.assertIsInstance(prior.cdf(prior.sample()), float) + def test_log_normal_fail(self): with self.assertRaises(ValueError): bilby.core.prior.LogNormal(name="test", unit="unit", mu=0, sigma=-1) From eafd06a3915595aeef289f7fda78509f67800776 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 12 Jul 2024 16:22:00 -0500 Subject: [PATCH 02/14] FEAT: add fermi-dirac CDF --- bilby/core/prior/analytical.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 804cecc4d..4e66cede8 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -1427,6 +1427,33 @@ def ln_prob(self, val): """ return np.log(self.prob(val)) + def cdf(self, val): + """ + Evaluate the CDF of the Fermi-Dirac distribution using a slightly + modified form of Equation 23 of [1]_. + + Parameters + ========== + val: Union[float, int, array_like] + The value(s) to evaluate the CDF at + + Returns + ======= + Union[float, array_like]: + The CDF value(s) + + References + ========== + + .. [1] M. Pitkin, M. Isi, J. Veitch & G. Woan, `arXiv:1705.08978v1 + `_, 2017. + """ + result = ( + (np.logaddexp(0, -self.r) - np.logaddexp(-val / self.sigma, -self.r)) + / np.logaddexp(0, self.r) + ) + return np.clip(result, 0, 1) + class DiscreteValues(Prior): def __init__(self, values, name=None, latex_label=None, From a40cead381d3923370d6ae39633ad608cd70b82b Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 12 Jul 2024 16:35:08 -0500 Subject: [PATCH 03/14] TST: add testing of fermi dirac and symmetric log uniform priors --- test/core/prior/prior_test.py | 39 +++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index 95edb8ab2..d45718204 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -86,6 +86,8 @@ def condition_func(reference_params, test_param): bilby.core.prior.Lorentzian(name="test", unit="unit", alpha=0, beta=1), bilby.core.prior.Gamma(name="test", unit="unit", k=1, theta=1), bilby.core.prior.ChiSquared(name="test", unit="unit", nu=2), + bilby.core.prior.FermiDirac(name="test", unit="unit", mu=1, sigma=1), + bilby.core.prior.SymmetricLogUniform(name="test", unit="unit", minimum=1e-2, maximum=1e2), bilby.gw.prior.AlignedSpin(name="test", unit="unit"), bilby.gw.prior.AlignedSpin( a_prior=bilby.core.prior.Beta(alpha=2.0, beta=2.0), @@ -235,6 +237,9 @@ def tearDown(self): def test_minimum_rescaling(self): """Test the the rescaling works as expected.""" for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue if isinstance(prior, bilby.gw.prior.AlignedSpin): # the edge of the prior is extremely suppressed for these priors # and so the rescale function doesn't quite return the lower bound @@ -263,6 +268,9 @@ def test_maximum_rescaling(self): def test_many_sample_rescaling(self): """Test the the rescaling works as expected.""" for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue many_samples = prior.rescale(np.random.uniform(0, 1, 1000)) if bilby.core.prior.JointPrior in prior.__class__.__mro__: if not prior.dist.filled_rescale(): @@ -281,6 +289,9 @@ def test_least_recently_sampled(self): def test_sampling_single(self): """Test that sampling from the prior always returns values within its domain.""" for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue single_sample = prior.sample() self.assertTrue( (single_sample >= prior.minimum) & (single_sample <= prior.maximum) @@ -289,6 +300,9 @@ def test_sampling_single(self): def test_sampling_many(self): """Test that sampling from the prior always returns values within its domain.""" for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue many_samples = prior.sample(5000) self.assertTrue( (all(many_samples >= prior.minimum)) @@ -311,6 +325,9 @@ def test_probability_above_domain(self): def test_probability_below_domain(self): """Test that the prior probability is non-negative in domain of validity and zero outside.""" for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue if prior.minimum != -np.inf: outside_domain = np.linspace( prior.minimum - 1e4, prior.minimum - 1, 1000 @@ -369,6 +386,9 @@ def test_cdf_one_above_domain(self): def test_cdf_zero_below_domain(self): for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue if ( bilby.core.prior.JointPrior in prior.__class__.__mro__ and prior.maximum == np.inf @@ -538,6 +558,9 @@ def test_probability_surrounding_domain(self): # skip delta function prior in this case if isinstance(prior, bilby.core.prior.DeltaFunction): continue + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000) indomain = (surround_domain >= prior.minimum) | ( surround_domain <= prior.maximum @@ -552,11 +575,18 @@ def test_probability_surrounding_domain(self): self.assertTrue(all(prior.prob(surround_domain[outdomain]) == 0)) def test_normalized(self): - """Test that each of the priors are normalised, this needs care for delta function and Gaussian priors""" + """ + Test that each of the priors are normalised. + This needs extra care for priors defined on infinite domains and the + Cauchy, DeltaFunction, and SymmetricLogUniform priors are skipped + because they are too sharply peaked to be tested efficiently in this way. + """ for prior in self.priors: - if isinstance(prior, bilby.core.prior.DeltaFunction): - continue - if isinstance(prior, bilby.core.prior.Cauchy): + if isinstance(prior, ( + bilby.core.prior.DeltaFunction, + bilby.core.prior.Cauchy, + bilby.core.prior.SymmetricLogUniform + )): continue if bilby.core.prior.JointPrior in prior.__class__.__mro__: continue @@ -770,6 +800,7 @@ def test_set_minimum_setting(self): bilby.core.prior.MultivariateGaussian, bilby.core.prior.FermiDirac, bilby.core.prior.Triangular, + bilby.core.prior.SymmetricLogUniform, bilby.gw.prior.HealPixPrior, ), ): From bd9ee07b13bbb7c1abad40637684c054c27853db Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 12 Jul 2024 16:38:40 -0500 Subject: [PATCH 04/14] FMT: remove extraneous whitespace --- bilby/core/prior/analytical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 4e66cede8..5ac9455ef 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -1436,7 +1436,7 @@ def cdf(self, val): ========== val: Union[float, int, array_like] The value(s) to evaluate the CDF at - + Returns ======= Union[float, array_like]: From 71ac6233f4d953988e8d787cbff5ff78317cbc15 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 23 Jan 2025 10:05:41 -0600 Subject: [PATCH 05/14] BUG: revert bad changes to equal comparisons --- bilby/core/prior/analytical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 5ac9455ef..f218e9cdd 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -375,7 +375,7 @@ def cdf(self, val): * (val <= -self.minimum) * (val >= -self.maximum) + (0.5 + norm * np.log(abs(val) / self.minimum)) * (val >= self.minimum) * (val <= self.maximum) - + 0.5 * (val >= -self.minimum) * (val <= self.minimum) + + 0.5 * (val > -self.minimum) * (val < self.minimum) + 1 * (val > self.maximum) ) return _cdf From 24fc64386202e6aa54e6ff0f11f318a3511311ae Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 23 Jan 2025 10:33:40 -0600 Subject: [PATCH 06/14] BUG: Fix syntax error in return statement --- bilby/core/prior/dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 32c3bc780..d21485cfd 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -899,7 +899,7 @@ def rescale(self, keys, theta): values = np.concatenate([values, result[key]]) for key, value in zip(names, values): result[key] = value - return list([result[key] for key in keys]) + return [list(np.asarray(result[key]).flatten()) for key in keys] def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: From b0bc19670796af1d5dca256fb51e30d94583496d Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 23 Jan 2025 11:11:30 -0600 Subject: [PATCH 07/14] BUG: Fix list return --- bilby/core/prior/dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index d21485cfd..c5751e32b 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -899,7 +899,7 @@ def rescale(self, keys, theta): values = np.concatenate([values, result[key]]) for key, value in zip(names, values): result[key] = value - return [list(np.asarray(result[key]).flatten()) for key in keys] + return [np.asarray(result[key]).flatten() for key in keys] def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: From 9fcdd4a4268af73b35183faec68e354c56e3a6d3 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 23 Jan 2025 11:38:59 -0600 Subject: [PATCH 08/14] BUG: fix array type output for rescale --- bilby/core/prior/dict.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index c5751e32b..0e684635b 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -899,7 +899,10 @@ def rescale(self, keys, theta): values = np.concatenate([values, result[key]]) for key, value in zip(names, values): result[key] = value - return [np.asarray(result[key]).flatten() for key in keys] + # this is gross but can be removed whenever we switch to returning + # arrays, flatten converts 0-d arrays to 1-d and squeeze converts it + # back + return [np.asarray(result[key]).flatten().squeeze for key in keys] def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: From 27fe1df2fa1b5098a3c988964582f778ad72896a Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 24 Jan 2025 02:14:14 -0600 Subject: [PATCH 09/14] BUG: Fix missing parentheses in squeeze method call --- bilby/core/prior/dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 0e684635b..686988ab5 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -902,7 +902,7 @@ def rescale(self, keys, theta): # this is gross but can be removed whenever we switch to returning # arrays, flatten converts 0-d arrays to 1-d and squeeze converts it # back - return [np.asarray(result[key]).flatten().squeeze for key in keys] + return [np.asarray(result[key]).flatten().squeeze() for key in keys] def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: From 0ec8d9b2c59cbe828cbb895b9fb2b31847a294f4 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 4 Feb 2025 15:16:43 +0000 Subject: [PATCH 10/14] BUG: fix flattening logic --- bilby/core/prior/dict.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 686988ab5..5e917b064 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -899,10 +899,19 @@ def rescale(self, keys, theta): values = np.concatenate([values, result[key]]) for key, value in zip(names, values): result[key] = value - # this is gross but can be removed whenever we switch to returning - # arrays, flatten converts 0-d arrays to 1-d and squeeze converts it - # back - return [np.asarray(result[key]).flatten().squeeze() for key in keys] + + def safe_flatten(value): + """ + this is gross but can be removed whenever we switch to returning + arrays, flatten converts 0-d arrays to 1-d so has to be special + cased + """ + if isinstance(value, (float, int)): + return value + else: + return result[key].flatten() + + return [safe_flatten(result[key]) for key in keys] def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: From 7b80c24429b0c7aa83a7b5057c86c51b944eb0ec Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 4 Feb 2025 15:17:09 +0000 Subject: [PATCH 11/14] BUG: stop using arr[np.where(cond)] --- bilby/core/prior/slabspike.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py index 92664b15e..6910be608 100644 --- a/bilby/core/prior/slabspike.py +++ b/bilby/core/prior/slabspike.py @@ -88,11 +88,11 @@ def rescale(self, val): original_is_number = isinstance(val, Number) val = np.atleast_1d(val) - lower_indices = np.where(val < self.inverse_cdf_below_spike)[0] - intermediate_indices = np.where(np.logical_and( + lower_indices = val < self.inverse_cdf_below_spike + intermediate_indices = np.logical_and( self.inverse_cdf_below_spike <= val, - val <= self.inverse_cdf_below_spike + self.spike_height))[0] - higher_indices = np.where(val > self.inverse_cdf_below_spike + self.spike_height)[0] + val <= (self.inverse_cdf_below_spike + self.spike_height)) + higher_indices = val > (self.inverse_cdf_below_spike + self.spike_height) res = np.zeros(len(val)) res[lower_indices] = self._contracted_rescale(val[lower_indices]) @@ -137,7 +137,7 @@ def prob(self, val): original_is_number = isinstance(val, Number) res = self.slab.prob(val) * self.slab_fraction res = np.atleast_1d(res) - res[np.where(val == self.spike_location)] = np.inf + res[val == self.spike_location] = np.inf if original_is_number: try: res = res[0] @@ -161,7 +161,7 @@ def ln_prob(self, val): original_is_number = isinstance(val, Number) res = self.slab.ln_prob(val) + np.log(self.slab_fraction) res = np.atleast_1d(res) - res[np.where(val == self.spike_location)] = np.inf + res[val == self.spike_location] = np.inf if original_is_number: try: res = res[0] @@ -185,7 +185,5 @@ def cdf(self, val): """ res = self.slab.cdf(val) * self.slab_fraction - res = np.atleast_1d(res) - indices_above_spike = np.where(val > self.spike_location)[0] - res[indices_above_spike] += self.spike_height + res += self.spike_height * (val > self.spike_location) return res From 5486aaf75c793cac04a98ac787c9c6f58cb99a73 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 21 Aug 2025 08:56:20 -0500 Subject: [PATCH 12/14] Address comments --- bilby/core/prior/dict.py | 6 ++++++ test/core/prior/prior_test.py | 5 +---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 5e917b064..a3655ce76 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -894,6 +894,12 @@ def rescale(self, keys, theta): elif isinstance(self[key], JointPrior): joint[self[key].dist.distname].append(key) for names in joint.values(): + # this is needed to unpack how joint prior rescaling works + # as an example of a joint prior over {a, b, c, d} we might + # get the following based on the order within the joint prior + # {a: [], b: [], c: [1, 2, 3, 4], d: []} + # -> [1, 2, 3, 4] + # -> {a: 1, b: 2, c: 3, d: 4} values = list() for key in names: values = np.concatenate([values, result[key]]) diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index d45718204..2725981c8 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -402,10 +402,7 @@ def test_cdf_zero_below_domain(self): def test_cdf_float_with_float_input(self): for prior in self.priors: - if ( - bilby.core.prior.JointPrior in prior.__class__.__mro__ - and prior.maximum == np.inf - ): + if bilby.core.prior.JointPrior in prior.__class__.__mro__: continue self.assertIsInstance(prior.cdf(prior.sample()), float) From 983430f220fe2a28331dd5307c232429274b7f3d Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 27 Aug 2025 07:00:57 -0500 Subject: [PATCH 13/14] BUG: fix scipy integrate imports --- bilby/gw/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index a814292f6..e262eaaf3 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -2,6 +2,7 @@ import copy import numpy as np +from scipy.integrate import cumulative_trapezoid, trapezoid, quad from scipy.interpolate import InterpolatedUnivariateSpline from scipy.special import hyp2f1 from scipy.stats import norm @@ -1498,7 +1499,6 @@ def _build_attributes(self): """ Method that builds the inverse cdf of the P(pixel) distribution for rescaling """ - from scipy.integrate import cumulative_trapezoid yy = self._all_interped(self.pix_xx) yy /= trapezoid(yy, self.pix_xx) YY = cumulative_trapezoid(yy, self.pix_xx, initial=0) From 08a2205894e3699a0b467e08dbed886a8c388a5e Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 27 Aug 2025 07:17:42 -0500 Subject: [PATCH 14/14] BUG: Remove unused import of interp1d Removed unused import of interp1d from scipy.interpolate. --- bilby/core/prior/interpolated.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bilby/core/prior/interpolated.py b/bilby/core/prior/interpolated.py index 56213fed0..5fbf8f9c1 100644 --- a/bilby/core/prior/interpolated.py +++ b/bilby/core/prior/interpolated.py @@ -1,6 +1,5 @@ import numpy as np from scipy.integrate import trapezoid -from scipy.interpolate import interp1d from .base import Prior from ..utils import logger, WrappedInterp1d as interp1d