Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 43 additions & 43 deletions src/discovery/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(self, Phi, Fs, Phi_inv=None):

def CompoundGlobalGP(gplist):
if all(isinstance(gp, GlobalVariableGP) for gp in gplist):
fmats = [np.hstack(F) for F in zip(*[gp.Fs for gp in gplist])]
fmats = [jnp.hstack(F) for F in zip(*[gp.Fs for gp in gplist])]

npsr = len(fmats)

Expand Down Expand Up @@ -206,7 +206,7 @@ def VectorCompoundGP(gplist):
if all(isinstance(gp, (VariableGP, GlobalVariableGP)) for gp in gplist):
# each gp.F is a tuple of F matrices, one for each pulsar
# globalgp has gp.Fs instead, which maybe is not ideal
F = [np.hstack(Fs) for Fs in zip(*[gp.F if hasattr(gp, 'F') else gp.Fs for gp in gplist])]
F = [jnp.hstack(Fs) for Fs in zip(*[gp.F if hasattr(gp, 'F') else gp.Fs for gp in gplist])]

if all(isinstance(gp.Phi, VectorNoiseMatrix1D_var) for gp in gplist):
def Phi(params):
Expand Down Expand Up @@ -251,13 +251,13 @@ def CompoundGP(gplist):

if all(isinstance(gp, ConstantGP) for gp in gplist):
if all(isinstance(gp.Phi, NoiseMatrix1D_novar) for gp in gplist):
F = np.hstack([gp.F for gp in gplist])
PhiN = np.concatenate([gp.Phi.N for gp in gplist])
F = jnp.hstack([gp.F for gp in gplist])
PhiN = jnp.concatenate([gp.Phi.N for gp in gplist])

multigp = ConstantGP(NoiseMatrix1D_novar(PhiN), F)
elif all(isinstance(gp.Phi, (NoiseMatrix1D_novar, NoiseMatrix2D_novar)) for gp in gplist):
F = np.hstack([gp.F for gp in gplist])
PhiN = jsp.linalg.block_diag(*[np.diag(gp.Phi.N) if isinstance(gp.Phi, NoiseMatrix1D_novar)
F = jnp.hstack([gp.F for gp in gplist])
PhiN = jsp.linalg.block_diag(*[jnp.diag(gp.Phi.N) if isinstance(gp.Phi, NoiseMatrix1D_novar)
else gp.Phi.N
for gp in gplist])

Expand All @@ -268,7 +268,7 @@ def F(params):
return jnp.hstack([gp.F(params) if callable(gp.F) else gp.F for gp in gplist])
F.params = sum((gp.F.params if callable(gp.F) else [] for gp in gplist), [])
else:
F = np.hstack([gp.F for gp in gplist])
F = jnp.hstack([gp.F for gp in gplist])

if all(isinstance(gp.Phi, NoiseMatrix1D_var) for gp in gplist):
def Phi(params):
Expand Down Expand Up @@ -323,22 +323,22 @@ def NoiseMatrix12D_var(getN):
class NoiseMatrix1D_novar(ConstantKernel):
def __init__(self, N):
self.N = N
self.ld = np.logdet(N)
self.ld = jnp.logdet(N)

self.params = []

def make_kernelproduct(self, y):
if callable(y):
y_var = y
N, ld = jnparray(self.N), np.logdet(self.N)
N, ld = jnparray(self.N), jnp.logdet(self.N)

def kernelproduct(params):
yp = y_var(params)

return -0.5 * jnp.sum(yp**2 / N) - 0.5 * ld
kernelproduct.params = sorted(set(y.params))
else:
product = -0.5 * np.sum(y**2 / self.N) - 0.5 * np.logdet(self.N)
product = -0.5 * jnp.sum(y**2 / self.N) - 0.5 * jnp.logdet(self.N)

def kernelproduct(params):
return product
Expand Down Expand Up @@ -371,18 +371,18 @@ def solve_1d(y):
return solve_1d

def solve_2d(self, T):
return T / self.N[:, np.newaxis], self.ld
return T / self.N[:, jnp.newaxis], self.ld

def make_solve_2d(self):
N, ld = jnparray(self.N[:, np.newaxis]), jnparray(self.ld)
N, ld = jnparray(self.N[:, jnp.newaxis]), jnparray(self.ld)

def solve_2d(T):
return T / N, ld

return solve_2d

def make_sample(self):
N12 = jnparray(np.sqrt(self.N))
N12 = jnparray(jnp.sqrt(self.N))

def sample(key):
key, subkey = jnpsplit(key)
Expand Down Expand Up @@ -422,7 +422,7 @@ def SM_2d_fused(Y, N, F, P):
# indexed, carefully handwritten

def make_uind(U):
Uind = np.zeros((U.shape[1], jnp.max(jnp.sum(U, axis=0)) + 1), 'i')
Uind = np.zeros((U.shape[1], np.max(np.sum(U, axis=0)) + 1), 'i')

for i in range(U.shape[1]):
ind = np.where(U[:,i])[0]
Expand Down Expand Up @@ -598,7 +598,7 @@ def kernelproduct(params):

def inv(self, params):
N = self.getN(params)
return np.diag(1.0 / N), np.logdet(N)
return jnp.diag(1.0 / N), jnp.logdet(N)

def make_inv(self):
getN = self.getN
Expand Down Expand Up @@ -637,12 +637,12 @@ def sample(key, params):
def solve_1d(self, params, y):
N = self.getN(params)

return y / N, np.logdet(N)
return y / N, jnp.logdet(N)

def solve_2d(self, params, F):
N = self.getN(params)

return F / N[:, np.newaxis], np.logdet(N)
return F / N[:, jnp.newaxis], jnp.logdet(N)

def make_solve_1d(self):
getN = self.getN
Expand Down Expand Up @@ -671,8 +671,8 @@ class NoiseMatrix2D_novar(ConstantKernel):
def __init__(self, N):
self.N = N

self.invN = np.linalg.inv(N)
self.ld = np.linalg.slogdet(N)[1]
self.invN = jnp.linalg.inv(N)
self.ld = jnp.linalg.slogdet(N)[1]

def inv(self):
return self.invN, self.ld
Expand Down Expand Up @@ -832,8 +832,8 @@ def __init__(self, N, F, P):
FtNmF = F.T @ self.NmF

Pinv, ldP = P.inv()
self.cf = sp.linalg.cho_factor(Pinv + FtNmF)
self.ld = ldN + ldP + 2.0 * np.logdet(np.diag(self.cf[0]))
self.cf = matrix_factor(Pinv + FtNmF)
self.ld = ldN + ldP + matrix_norm * jnp.logdet(jnp.diag(self.cf[0]))

self.params = []

Expand All @@ -860,12 +860,12 @@ def make_kernelproduct(self, y):
def kernelproduct(params):
yp = y_var(params)

Nmy = N_solve_1d(yp)[0] - self.NmF @ jsp.linalg.cho_solve(cf, NmF.T @ yp)
Nmy = N_solve_1d(yp)[0] - self.NmF @ matrix_solve(cf, NmF.T @ yp)

return -0.5 * yp @ Nmy - 0.5 * ld
kernelproduct.params = sorted(set(y.params))
else:
Nmy = self.N.solve_1d(y)[0] - self.NmF @ sp.linalg.cho_solve(self.cf, self.NmF.T @ y)
Nmy = self.N.solve_1d(y)[0] - self.NmF @ matrix_solve(self.cf, self.NmF.T @ y)
product = -0.5 * y @ Nmy - 0.5 * self.ld

# closes on product
Expand Down Expand Up @@ -897,8 +897,8 @@ def make_kernelterms(self, y, T):
FtNmT = self.F.T @ NmT
TtNmT = T.T @ NmT

sol = sp.linalg.cho_solve(self.cf, FtNmy)
sol2 = sp.linalg.cho_solve(self.cf, FtNmT)
sol = matrix_solve(self.cf, FtNmy)
sol2 = matrix_solve(self.cf, FtNmT)

a = -0.5 * (ytNmy - FtNmy.T @ sol) - 0.5 * self.ld
b = jnparray(TtNmy - TtNmF @ sol)
Expand Down Expand Up @@ -938,8 +938,8 @@ def kernelsolve(params):
FtNmT = F.T @ NmT
TtNmT = Tmat.T @ NmT

TtSy = TtNmy - TtNmF @ jsp.linalg.cho_solve(cf, FtNmy)
TtST = TtNmT - TtNmF @ jsp.linalg.cho_solve(cf, FtNmT)
TtSy = TtNmy - TtNmF @ matrix_solve(cf, FtNmy)
TtST = TtNmT - TtNmF @ matrix_solve(cf, FtNmT)

return TtSy, TtST

Expand All @@ -952,8 +952,8 @@ def kernelsolve(params):
FtNmT = self.F.T @ NmT
TtNmT = T.T @ NmT

TtSy = jnparray(TtNmy - TtNmF @ sp.linalg.cho_solve(self.cf, FtNmy))
TtST = jnparray(TtNmT - TtNmF @ sp.linalg.cho_solve(self.cf, FtNmT))
TtSy = jnparray(TtNmy - TtNmF @ matrix_solve(self.cf, FtNmy))
TtST = jnparray(TtNmT - TtNmF @ matrix_solve(self.cf, FtNmT))

# closes on TtSy and TtST
def kernelsolve(params={}):
Expand All @@ -964,7 +964,7 @@ def kernelsolve(params={}):
return kernelsolve

def solve_1d(self, y):
return self.N.solve_1d(y)[0] - self.NmF @ sp.linalg.cho_solve(self.cf, self.NmF.T @ y), self.ld
return self.N.solve_1d(y)[0] - self.NmF @ matrix_solve(self.cf, self.NmF.T @ y), self.ld

def make_solve_1d(self):
N_solve_1d = self.N.make_solve_1d()
Expand All @@ -974,12 +974,12 @@ def make_solve_1d(self):

# closes on N_solve_1d, NmF, cf, ld
def solve_1d(y):
return N_solve_1d(y)[0] - NmF @ jsp.linalg.cho_solve(cf, NmF.T @ y), ld
return N_solve_1d(y)[0] - NmF @ matrix_solve(cf, NmF.T @ y), ld

return solve_1d

def solve_2d(self, y):
return self.N.solve_2d(y)[0] - self.NmF @ sp.linalg.cho_solve(self.cf, self.NmF.T @ y), self.ld
return self.N.solve_2d(y)[0] - self.NmF @ matrix_solve(self.cf, self.NmF.T @ y), self.ld

def make_solve_2d(self):
N_solve_2d = self.N.make_solve_2d()
Expand All @@ -988,7 +988,7 @@ def make_solve_2d(self):
ld = jnp.array(self.ld)

def solve_2d(F):
return N_solve_2d(F)[0] - NmF @ jsp.linalg.cho_solve(cf, NmF.T @ F), self.ld
return N_solve_2d(F)[0] - NmF @ matrix_solve(cf, NmF.T @ F), self.ld

return solve_2d

Expand Down Expand Up @@ -1121,10 +1121,10 @@ def kernelsolve(params):
TtNmF = T.T @ NmF

Pinv, _ = P_var_inv(params)
cf = jsp.linalg.cho_factor(Pinv + FtNmF)
cf = matrix_factor(Pinv + FtNmF)

TtSy = TtNmy - TtNmF @ jsp.linalg.cho_solve(cf, FtNmy)
TtST = TtNmT - TtNmF @ jsp.linalg.cho_solve(cf, FtNmT)
TtSy = TtNmy - TtNmF @ matrix_solve(cf, FtNmy)
TtST = TtNmT - TtNmF @ matrix_solve(cf, FtNmT)

return TtSy, TtST

Expand Down Expand Up @@ -1532,10 +1532,10 @@ def solve_1d(self, params, y):
NmF, ldN = self.N_var.solve_2d(params, self.F)
NmFty = NmF.T @ y

cf = sp.linalg.cho_factor(self.Pinv + self.F.T @ NmF)
ld = ldN + self.ldP + 2.0 * jnp.logdet(np.diag(cf[0]))
cf = matrix_factor(self.Pinv + self.F.T @ NmF)
ld = ldN + self.ldP + matrix_norm * jnp.logdet(jnp.diag(cf[0]))

return self.N_var.solve_1d(params, y)[0] - NmF @ sp.linalg.cho_solve(cf, NmFty), ld
return self.N_var.solve_1d(params, y)[0] - NmF @ matrix_solve(cf, NmFty), ld

def make_solve_1d(self):
N_solve_1d = self.N_var.make_solve_1d()
Expand All @@ -1558,10 +1558,10 @@ def solve_2d(self, params, Fr):
NmFl, ldN = self.N_var.solve_2d(params, self.F)
NmFltFr = NmFl.T @ Fr

cf = sp.linalg.cho_factor(self.Pinv + self.F.T @ NmFl)
ld = ldN + self.ldP + 2.0 * np.logdet(np.diag(cf[0]))
cf = matrix_factor(self.Pinv + self.F.T @ NmFl)
ld = ldN + self.ldP + matrix_norm * jnp.logdet(jnp.diag(cf[0]))

return self.N_var.solve_2d(params, Fr)[0] - NmFl @ sp.linalg.cho_solve(cf, NmFltFr), ld
return self.N_var.solve_2d(params, Fr)[0] - NmFl @ matrix_solve(cf, NmFltFr), ld

def make_solve_2d(self):
N_solve_2d = self.N_var.make_solve_2d()
Expand Down Expand Up @@ -1854,4 +1854,4 @@ def kernelterms(params):

kernelterms.params = self.P_var.params

return kernelterms
return kernelterms
Loading