Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/discovery/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ def delayfunc(params):

# standard parameters t, pos, d;
def makedelay_deterministic(psr, delay, name='deterministic'):
argspec = inspect.getfullargspec(prior)
argspec = inspect.getfullargspec(delay)
argmap = [f'{name}_{arg}' + (f'({components})' if argspec.annotations.get(arg) == typing.Sequence else '')
for arg in argspec.args if arg not in ['t', 'pos', 'd']]

Expand Down
78 changes: 75 additions & 3 deletions src/discovery/solar.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import numpy as np
import inspect
import jax.numpy as jnp

from . import const
from . import matrix
from . import fourierbasis
from . import quantize

AU_light_sec = const.AU / const.c # 1 AU in light seconds
AU_pc = const.AU / const.pc # 1 AU in parsecs (for DM normalization)
Expand Down Expand Up @@ -31,9 +35,9 @@ def make_solardm(psr):
theta, r_earth, _, _ = theta_impact(psr)
shape = matrix.jnparray(AU_light_sec * AU_pc / r_earth / np.sinc(1 - theta/np.pi) * 4.148808e3 / psr.freqs**2)

def solardm(n_earth):
return n_earth * shape

def solardm(params):
return params['n_earth'] * shape
solardm.params = ['n_earth']
return solardm

def make_chromaticdecay(psr):
Expand All @@ -46,3 +50,71 @@ def decay(t0, log10_Amp, log10_tau, idx):
return matrix.jnp.where(dt > 0.0, -1.0 * (10**log10_Amp) * matrix.jnp.exp(-dt / (10**log10_tau)) * normfreqs**idx, 0.0)

return decay

def _dm_solar_close(n_earth, r_earth):
return (n_earth * AU_light_sec * AU_pc / r_earth)


def _dm_solar(n_earth, theta, r_earth):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is necessary. I think the np.sinc in make_solardm should be faster than using a jnp.where. Note that the definition of np.sinc is sin(pi * x) / (pi * x) hence the 1 - theta / pi in the argument. But make_solardm above should give back the You 2007 function.

return ((np.pi - theta) *
(n_earth * AU_light_sec * AU_pc
/ (r_earth * np.sin(theta))))


def dm_solar(n_earth, theta, r_earth):
"""
Calculates Dispersion measure due to 1/r^2 solar wind density model.
::param :n_earth Solar wind proton/electron density at Earth (1/cm^3)
::param :theta: angle between sun and line-of-sight to pulsar (rad)
::param :r_earth :distance from Earth to Sun in (light seconds).
See You et al. 2007 for more details.
"""
return matrix.jnp.where(np.pi - theta >= 1e-5,
_dm_solar(n_earth, theta, r_earth),
_dm_solar_close(n_earth, r_earth))

def fourierbasis_solar_dm(psr,
components,
T=None):
"""
From enterprise_extions: construct DM-Solar Model Fourier design matrix.

:param psr: Pulsar object
:param components: Number of Fourier components in the model
:param T: Total timespan of the data

:return: F: SW DM-variation fourier design matrix
:return: f: Sampling frequencies
"""

# get base Fourier design matrix and frequencies
f, df, fmat = fourierbasis(psr, components, T)
theta, R_earth, _, _ = theta_impact(psr)
dm_sol_wind = dm_solar(1.0, theta, R_earth)
dt_DM = dm_sol_wind * 4.148808e3 / (psr.freqs**2)

return f, df, fmat * dt_DM[:, None]

def makegp_timedomain_solar_dm(psr, covariance, dt=1.0, common=[], name='timedomain_sw_gp'):
argspec = inspect.getfullargspec(covariance)
argmap = [(arg if arg in common else f'{name}_{arg}' if f'{name}_{arg}' in common else f'{psr.name}_{name}_{arg}')
for arg in argspec.args if arg not in ['tau']]

# get solar wind ingredients
theta, R_earth, _, _ = theta_impact(psr)
dm_sol_wind = dm_solar(1.0, theta, R_earth)
dt_DM = dm_sol_wind * 4.148808e3 / (psr.freqs**2)

bins = quantize(psr.toas, dt)
Umat = np.vstack([bins == i for i in range(bins.max() + 1)]).T.astype('d')
Umat = Umat * dt_DM[:, None]
toas = psr.toas @ Umat / Umat.sum(axis=0)

get_tmat = covariance
tau = jnp.abs(toas[:, jnp.newaxis] - toas[jnp.newaxis, :])

def getphi(params):
return get_tmat(tau, *[params[arg] for arg in argmap])
getphi.params = argmap

return matrix.VariableGP(matrix.NoiseMatrix2D_var(getphi), Umat)