Support non-numpy array backends #886
base: main
Conversation
Force-pushed from ea348fa to 771a8a9.
This is now ready for review. There are a lot of changes, but most of them are essentially [...]. Bilby can once again be installed without [...]. I've managed to keep test changes minimal:
This required making some changes to the tests for conditional dicts, as I've changed the output types and the backend introspection doesn't work on dict_items for some reason.
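(For context, a hedged illustration of the dict_items failure, assuming the introspection goes through array_api_compat.array_namespace; the snippet below is mine, not the PR's code.)

import array_api_compat as aac
import numpy as np

samples = dict(a=np.ones(3), b=np.zeros(3))
# Passing the view directly fails: a dict_values/dict_items object is not an array.
# xp = aac.array_namespace(samples.values())  # raises TypeError
# Unpacking the values into separate array arguments works:
xp = aac.array_namespace(*samples.values())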
mj-will left a comment
Some initial comments but I'll need to have another look.
elif aac.is_cupy_namespace(xp):
    from cupyx.scipy.special import erfinv
else:
    raise BackendNotImplementedError
I think it would be useful to include the backend in the error.
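For example (a sketch against the hunk above, assuming the exception accepts a message):

raise BackendNotImplementedError(
    f"erfinv is not available for array backend '{xp.__name__}'"
)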
__all__ = ["array_module", "promote_to_array"]


def array_module(arr):
This would benefit from a doc-string
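For example, a possible docstring, assuming array_module simply returns the array-API namespace of its input and falls back to numpy (the return np visible in the next hunk); the wording is illustrative, not the PR's:

def array_module(arr):
    """Return the array namespace (e.g. numpy, cupy, jax.numpy) that ``arr`` belongs to.

    Non-array inputs (Python scalars, lists) fall back to the numpy namespace.
    """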
    return np


def promote_to_array(args, backend, skip=None):
Have you thought about how devices would be handled here?
Moving arrays to and from GPUs can sometimes require more than just calling array.
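For illustration (my example, not the PR's code): with cupy, converting a device array with plain numpy calls is not enough, an explicit transfer is required:

import cupy as cp
import numpy as np

x_gpu = cp.arange(3)
# np.asarray(x_gpu) raises TypeError: implicit conversion of a GPU array to a
# NumPy array is not allowed, so a plain asarray call is not sufficient here.
x_cpu = cp.asnumpy(x_gpu)  # explicit device-to-host transfer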
Suggest adding a doc-string
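A sketch of a possible docstring, guessed from the signature (args coerced to the target backend, entries in skip left alone); the exact semantics would need to match the implementation:

def promote_to_array(args, backend, skip=None):
    """Convert each element of ``args`` to an array of the target ``backend``.

    Parameters
    ----------
    args: iterable
        Values to convert, e.g. floats, lists, or arrays from another backend.
    backend
        The array namespace/backend the outputs should belong to.
    skip: iterable, optional
        Entries to leave unconverted.
    """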
import os

import numpy as np
os.environ["SCIPY_ARRAY_API"] = "1"  # noqa  # flag for scipy backend switching
I worry slightly about having this hard coded. Does it introduce more overhead when using just numpy?
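One option, as a sketch (this doesn't answer the overhead question, but it avoids clobbering a user's own setting):

import os

# Respect an existing user-provided value instead of overriding it unconditionally.
os.environ.setdefault("SCIPY_ARRAY_API", "1")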
    This maps to the inverse CDF. This has been analytically solved for this case.
    """
-   return gammaincinv(self.k, val) * self.theta
+   return xp.asarray(gammaincinv(self.k, val)) * self.theta
Does this mean this is falling back to numpy?
Yeah, I should update/recheck this. At least jax doesn't have good support for this, though it looks like tensorflow has a version that numpyro uses (jax-ml/jax#5350). cupy does have this function, so this workaround may have just been for jax. I could add a BackendNotImplementedError.
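A sketch of what that dispatch could look like (assuming array_api_compat is imported as aac and BackendNotImplementedError is this PR's custom exception; cupyx provides gammaincinv, jax currently does not):

import array_api_compat as aac


def _gammaincinv(xp, k, val):
    # Use the backend's own implementation where one exists and fail loudly
    # otherwise, rather than silently falling back to numpy.
    if aac.is_numpy_namespace(xp):
        from scipy.special import gammaincinv
    elif aac.is_cupy_namespace(xp):
        from cupyx.scipy.special import gammaincinv
    else:
        raise BackendNotImplementedError(
            f"gammaincinv is not implemented for backend '{xp.__name__}'"
        )
    return xp.asarray(gammaincinv(k, val))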
)
)

betaln,
Not sure what this is.
Not anything good.
Suggested change: delete the stray betaln, line.
# return self.check_ln_prob(sample, ln_prob,
#                           normalized=normalized)
Is the removal of this intentional?
I'm fairly sure it was, but I'll double check. I think check_ln_prob was problematic in some way.
self[key].least_recently_sampled = result[key]
if isinstance(self[key], JointPrior) and self[key].dist.distname not in joint:
    joint[self[key].dist.distname] = [key]
elif isinstance(self[key], JointPrior):
    joint[self[key].dist.distname].append(key)
for names in joint.values():
    # this is needed to unpack how joint prior rescaling works
    # as an example of a joint prior over {a, b, c, d} we might
    # get the following based on the order within the joint prior
    # {a: [], b: [], c: [1, 2, 3, 4], d: []}
    # -> [1, 2, 3, 4]
    # -> {a: 1, b: 2, c: 3, d: 4}
    values = list()
    for key in names:
        values = np.concatenate([values, result[key]])
    for key, value in zip(names, values):
        result[key] = value


def safe_flatten(value):
    """
    this is gross but can be removed whenever we switch to returning
    arrays, flatten converts 0-d arrays to 1-d so has to be special
    cased
    """
    if isinstance(value, (float, int)):
        return value
Is removing this intentional?
Yeah, this is in line with one of the other open PRs to update this logic. I'll dig it out in my next pass.
# delta_x = ifos[0].geometry.vertex - ifos[1].geometry.vertex
# theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x)
Suggest we remove this.
Suggested change: delete these two commented-out lines.
ColmTalbot left a comment
Thanks for the initial comments @mj-will, I'll take a pass at them ASAP.
| """ | ||
| at_peak = (val == self.peak) | ||
| return np.nan_to_num(np.multiply(at_peak, np.inf)) | ||
| return at_peak * 1.0 |
Yeah
    (xp.sin(val) - xp.sin(self.minimum)) /
    (xp.sin(self.maximum) - xp.sin(self.minimum))
)
_cdf *= val >= self.minimum
This kind of in-place operation works; it's the assignments to boolean-masked slices that don't, e.g. patterns that sometimes existed like:

_cdf = ...
_cdf[val < self.minimum] = 0
_cdf[val > self.maximum] = 1
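For reference, a backend-agnostic version of that pattern (a sketch against the hunk above; xp.where avoids assigning to masked slices, which jax's immutable arrays don't allow):

_cdf = xp.where(val < self.minimum, 0.0, _cdf)
_cdf = xp.where(val > self.maximum, 1.0, _cdf)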
    The natural logarithm of the bessel function
    """
-   return np.log(i0e(value)) + np.abs(value)
+   xp = array_module(value)
Comment to self: use xp_wrap here.
I've been working on this PR on and off for a few months. It isn't ready yet, but I wanted to share it in case other people had early opinions.
The goal is to make it easier to interface with models/samplers implemented in e.g., JAX, that support GPU/TPU acceleration and JIT compilation.
The general guiding principles are:
- follow the array-api specification and scipy interoperability

The primary changes so far are:
Changed behaviour:
Remaining issues:
- the bilby.gw.jaxstuff file should be removed and the relevant functionality moved elsewhere; it's currently just used for testing