Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 255 additions & 0 deletions fastshermanmorrison/cython_fastshermanmorrison.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ cnp.import_array()
from libc.math cimport log, sqrt, hypot
import cython
from scipy.linalg.cython_blas cimport dgemm, dger, dgemv
from scipy.linalg.cython_lapack cimport dpotrf, dpotri


cdef public void dgemm_(char *transa, char *transb, int *m, int *n, int *k,
Expand Down Expand Up @@ -1312,3 +1313,257 @@ def cython_blas_idx_block_shermor_2D_asymm(Z1, Z2, Nvec, Jvec, Uinds, slc_isort)
Jldet = c_blas_block_shermor_2D_asymm(Z1n, Z2n, Nvecn, Jvec, np.array(Uinds, order="C"), ZNZ)

return Jldet, ZNZ

cpdef cnp.ndarray[cnp.double_t, ndim=1] cython_block_shermor_solve_D1_small(
    cnp.ndarray[cnp.double_t, ndim=1] r,
    cnp.ndarray[cnp.double_t, ndim=1] Nvec,
    cnp.ndarray[cnp.double_t, ndim=1] Jvec,
    cnp.ndarray[cnp.double_t, ndim=2] U
):
    """
    Solve one block: return (N + U J U^T)^{-1} r via the Woodbury identity.

    With N = diag(Nvec) and J = diag(Jvec):

        (N + U J U^T)^{-1} r
            = N^{-1} r - N^{-1} U (J^{-1} + U^T N^{-1} U)^{-1} U^T N^{-1} r

    Parameters
    ----------
    r : 1D double array, length D
        Right-hand side.
    Nvec : 1D double array, length D
        Diagonal of N.
    Jvec : 1D double array, length k
        Diagonal of J.  May be empty, in which case the result is r / Nvec.
    U : 2D double array, shape (D, k)
        Dense design matrix (a 0/1 matrix according to the original
        docstring; the arithmetic below uses its values as given).

    Returns
    -------
    1D double array, length D

    Raises
    ------
    RuntimeError
        If the Cholesky factorisation (dpotrf) or the inversion (dpotri) of
        the k x k matrix J^{-1} + U^T N^{-1} U reports a non-zero info.
    """
    cdef int KK = Jvec.shape[0]
    cdef int info
    cdef char uplo = b'L'[0]
    cdef Py_ssize_t i, j
    cdef cnp.ndarray[cnp.double_t, ndim=1] Dinv = 1.0 / Nvec
    cdef cnp.ndarray[cnp.double_t, ndim=1] out = r * Dinv
    cdef cnp.ndarray[cnp.double_t, ndim=2] M
    cdef cnp.ndarray[cnp.double_t, ndim=1] kvec
    cdef cnp.ndarray[cnp.double_t, ndim=1] corr

    # No low-rank term: the system is purely diagonal.  This also avoids
    # taking &M[0,0] of an empty matrix further down.
    if KK == 0:
        return out

    # M = J^{-1} + U^T N^{-1} U, stored column-major for LAPACK.
    M = np.diag(1.0 / Jvec)
    M += (U.T * Dinv[None, :]).dot(U)
    M = np.asfortranarray(M)

    # Invert M in place: Cholesky factor first, then dpotri.
    dpotrf(&uplo, &KK, &M[0,0], &KK, &info)
    if info != 0:
        raise RuntimeError("Cholesky failed in cython_block_shermor_solve_D1_small")
    dpotri(&uplo, &KK, &M[0,0], &KK, &info)
    if info != 0:
        raise RuntimeError("Inversion failed in cython_block_shermor_solve_D1_small")

    # dpotri with uplo='L' only fills the lower triangle; mirror it so the
    # dense matrix products below see the full symmetric inverse.
    for i in range(KK):
        for j in range(i + 1, KK):
            M[i, j] = M[j, i]

    # `out` still holds N^{-1} r here, so reuse it for U^T N^{-1} r.
    kvec = U.T.dot(out)
    corr = Dinv * (U.dot(M.dot(kvec)))

    out -= corr

    return out

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void cython_block_shermor_solve_D1_k(
    cnp.ndarray[cnp.double_t, ndim=1] r,
    cnp.ndarray[cnp.double_t, ndim=1] Nvec,
    object idxs_list,                          # list of intp-arrays
    object j_list,                             # list of 1D double-arrays
    object U_list,                             # list of 2D double-arrays
    cnp.ndarray[cnp.double_t, ndim=1] out      # pre-allocated output
) except *:
    """
    Batch the per-component D1 solves and scatter the results into `out`.

    For each component i:

        out[idxs_i] = (diag(Nvec[idxs_i]) + U_i diag(J_i) U_i^T)^{-1} r[idxs_i]

    Parameters
    ----------
    r : 1D double array
        Full right-hand side.
    Nvec : 1D double array
        Full diagonal, indexed the same way as `r`.
    idxs_list : sequence of index arrays
        Entry i selects the elements of `r` / `Nvec` / `out` for component i.
    j_list : sequence of 1D double arrays
        Entry i is the Jvec passed to the small-block solver.
    U_list : sequence of 2D double arrays
        Entry i is the U matrix passed to the small-block solver.
    out : 1D double array
        Written in place; only the positions named in `idxs_list` are touched.

    Notes
    -----
    The function returns void, so `except *` is required for exceptions to
    reach the caller: without an exception clause, Cython versions before 3
    print and discard errors raised inside void C-level functions (for
    example the RuntimeError of the small-block solver), leaving `out`
    silently incomplete.
    """
    cdef Py_ssize_t ncomps = len(idxs_list)
    cdef Py_ssize_t ci
    cdef cnp.ndarray[cnp.double_t, ndim=1] rsub, nsub, res
    cdef object idxs, jv, U

    for ci in range(ncomps):
        idxs = idxs_list[ci]
        jv = j_list[ci]
        U = U_list[ci]

        # Fancy indexing copies the sub-blocks out of the full vectors.
        rsub = r[idxs]
        nsub = Nvec[idxs]

        res = cython_block_shermor_solve_D1_small(rsub, nsub, jv, U)

        # Scatter the block solution back to its positions.
        out[idxs] = res


# ----------------------------------------------------------------
# 1D1 solve: low-rank (Woodbury) correction to y^T N^-1 x for N + U J U^T, plus log-det terms
# ----------------------------------------------------------------
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef tuple cython_block_shermor_solve_1D1_small(
    cnp.ndarray[cnp.double_t, ndim=1] x,
    cnp.ndarray[cnp.double_t, ndim=1] y,
    cnp.ndarray[cnp.double_t, ndim=1] Nvec,
    cnp.ndarray[cnp.double_t, ndim=1] Jvec,
    cnp.ndarray[cnp.double_t, ndim=2] U
):
    """
    Low-rank (Woodbury) part of y^T (N + U J U^T)^{-1} x and of the log-det.

    With N = diag(Nvec), J = diag(Jvec) and M = J^{-1} + U^T N^{-1} U, the
    function returns the tuple

        (sum(log Jvec) + logdet(M),  -(U^T N^{-1} y)^T M^{-1} (U^T N^{-1} x))

    The diagonal contributions are NOT included: the first element lacks
    sum(log Nvec) and the second lacks y^T N^{-1} x.
    NOTE(review): presumably the caller adds those diagonal terms itself;
    confirm against the call site before relying on this as the full result.

    Parameters
    ----------
    x, y : 1D double arrays, length D
    Nvec : 1D double array, length D
        Diagonal of N.
    Jvec : 1D double array, length k
        Diagonal of J.  If empty, (0.0, 0.0) is returned.
    U : 2D double array, shape (D, k)

    Returns
    -------
    tuple (logdet, yNx) of two floats as described above.

    Raises
    ------
    RuntimeError
        If dpotrf or dpotri reports a non-zero info for M.
    """
    cdef int k = Jvec.shape[0]
    cdef char uplo = b'L'[0]
    cdef int info
    cdef Py_ssize_t i, j
    cdef double yNx = 0.0
    cdef double logdetM = 0.0
    cdef double logdet
    cdef cnp.ndarray[cnp.double_t, ndim=1] Dinv
    cdef cnp.ndarray[cnp.double_t, ndim=2] M
    cdef cnp.ndarray[cnp.double_t, ndim=1] kx
    cdef cnp.ndarray[cnp.double_t, ndim=1] ky

    # No low-rank term: both corrections vanish.  This also avoids calling
    # LAPACK with a zero leading dimension.
    if k == 0:
        return 0.0, 0.0

    Dinv = 1.0 / Nvec

    # M = J^{-1} + U^T N^{-1} U, stored column-major for LAPACK.
    M = np.diag(1.0 / Jvec)
    M += (U.T * Dinv[None, :]).dot(U)
    M = np.asfortranarray(M)

    # Cholesky-factor M in place; the factor's diagonal gives logdet(M).
    dpotrf(&uplo, &k, &M[0,0], &k, &info)
    if info != 0:
        raise RuntimeError("dpotrf failed in 1D1 solver")

    for i in range(k):
        logdetM += 2.0 * log(M[i, i])

    # Turn the factor into M^{-1} (lower triangle only).
    dpotri(&uplo, &k, &M[0,0], &k, &info)
    if info != 0:
        raise RuntimeError("dpotri failed in 1D1 solver")

    # Mirror the lower triangle so M holds the full symmetric inverse.
    for i in range(k):
        for j in range(i + 1, k):
            M[i, j] = M[j, i]

    # kx = U^T N^{-1} x, ky = U^T N^{-1} y
    kx = U.T.dot(Dinv * x)
    ky = U.T.dot(Dinv * y)

    # Low-rank piece only; yNx started at zero, not at y^T N^{-1} x.
    yNx -= np.dot(ky, M.dot(kx))

    # sum(log N) is deliberately not part of this value (see docstring).
    logdet = np.sum(np.log(Jvec)) + logdetM

    return logdet, yNx



# ----------------------------------------------------------------
# 2D2 solve: low-rank (Woodbury) correction to Z2ᵀ N⁻¹ Z1 for N + U J Uᵀ, plus log-det terms
# ----------------------------------------------------------------
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef tuple cython_block_shermor_solve_2D2_small(
    cnp.ndarray[cnp.double_t, ndim=2] Z1,   # shape (D, c1)
    cnp.ndarray[cnp.double_t, ndim=2] Z2,   # shape (D, c2)
    cnp.ndarray[cnp.double_t, ndim=1] Nvec,
    cnp.ndarray[cnp.double_t, ndim=1] Jvec,
    cnp.ndarray[cnp.double_t, ndim=2] U
):
    """
    Low-rank (Woodbury) part of Z2ᵀ (N + U J Uᵀ)⁻¹ Z1 and of the log-det.

    With N = diag(Nvec), J = diag(Jvec) and M = J⁻¹ + Uᵀ N⁻¹ U, the function
    returns the tuple

        (sum(log Jvec) + logdet(M),  -(Z2ᵀ N⁻¹ U) M⁻¹ (Uᵀ N⁻¹ Z1))

    The diagonal contributions are NOT included: the first element lacks
    sum(log Nvec) and the second lacks Z2ᵀ N⁻¹ Z1.
    NOTE(review): presumably the caller adds those diagonal terms itself;
    confirm against the call site before relying on this as the full result.

    Parameters
    ----------
    Z1 : 2D double array, shape (D, c1)
    Z2 : 2D double array, shape (D, c2)
    Nvec : 1D double array, length D
        Diagonal of N.
    Jvec : 1D double array, length k
        Diagonal of J.  If empty, (0.0, zeros((c2, c1))) is returned.
    U : 2D double array, shape (D, k)

    Returns
    -------
    tuple (logdet, ZNZ) with ZNZ a (c2, c1) double array.

    Raises
    ------
    RuntimeError
        If dpotrf or dpotri reports a non-zero info for M.
    """
    cdef int c1 = Z1.shape[1]
    cdef int c2 = Z2.shape[1]
    cdef int k = Jvec.shape[0]
    cdef char uplo = b'L'[0]
    cdef int info
    cdef Py_ssize_t i, j
    cdef double logdetM = 0.0
    cdef double logdet
    cdef cnp.ndarray[cnp.double_t, ndim=1] Dinv
    cdef cnp.ndarray[cnp.double_t, ndim=2] M
    cdef cnp.ndarray[cnp.double_t, ndim=2] left
    cdef cnp.ndarray[cnp.double_t, ndim=2] right

    # Starts at zero rather than at Z2ᵀ N⁻¹ Z1: only the correction is
    # accumulated here.
    cdef cnp.ndarray[cnp.double_t, ndim=2] ZNZ = np.zeros((c2, c1))

    # No low-rank term: both corrections vanish.  This also avoids calling
    # LAPACK with a zero leading dimension.
    if k == 0:
        return 0.0, ZNZ

    Dinv = 1.0 / Nvec

    # M = J⁻¹ + Uᵀ N⁻¹ U, stored column-major for LAPACK.
    M = np.diag(1.0 / Jvec)
    M += (U.T * Dinv[None, :]).dot(U)
    M = np.asfortranarray(M)

    # Cholesky-factor M in place; the factor's diagonal gives logdet(M).
    dpotrf(&uplo, &k, &M[0,0], &k, &info)
    if info != 0:
        raise RuntimeError("dpotrf failed in 2D2 solver")
    for i in range(k):
        logdetM += 2.0 * log(M[i, i])

    # Turn the factor into M⁻¹ (lower triangle only), then mirror it.
    dpotri(&uplo, &k, &M[0,0], &k, &info)
    if info != 0:
        raise RuntimeError("dpotri failed in 2D2 solver")
    for i in range(k):
        for j in range(i + 1, k):
            M[i, j] = M[j, i]

    # Low-rank correction: (Z2ᵀ N⁻¹ U) · M⁻¹ · (Uᵀ N⁻¹ Z1)
    left = (Z2 * Dinv[:, None]).T.dot(U)     # (c2 × k)
    right = U.T.dot(Dinv[:, None] * Z1)      # (k × c1)
    ZNZ -= left.dot(M.dot(right))

    # sum(log N) is deliberately not part of this value (see docstring).
    logdet = np.sum(np.log(Jvec)) + logdetM

    return logdet, ZNZ



# ----------------------------------------------------------------
# sqrt-solve: Cholesky whitening L^{-1} X with L Lᵀ = D + U J Uᵀ (dense small-block factorisation)
# ----------------------------------------------------------------
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef cnp.ndarray[cnp.double_t, ndim=2] cython_block_shermor_sqrtsolve_small(
    cnp.ndarray[cnp.double_t, ndim=2] X,
    cnp.ndarray[cnp.double_t, ndim=1] Nvec,
    cnp.ndarray[cnp.double_t, ndim=1] Jvec,
    cnp.ndarray[cnp.double_t, ndim=2] U
):
    """
    Whiten X with the Cholesky factor of A = diag(Nvec) + U diag(Jvec) Uᵀ.

    Returns L⁻¹ X, where A = L Lᵀ is the lower Cholesky factorisation, so
    that outᵀ out = Xᵀ A⁻¹ X.  This is a Cholesky "square root" of A⁻¹, not
    the symmetric matrix power A^{-1/2}.

    A is assembled as a dense D × D matrix, so memory and the factorisation
    cost grow with D² and D³; the function is meant for small blocks.

    Parameters
    ----------
    X : 2D double array, shape (D, m)
    Nvec : 1D double array, length D
        Diagonal of A before the low-rank update.
    Jvec : 1D double array, length k
    U : 2D double array, shape (D, k)
        Only the zero / non-zero pattern is used: column `col` adds
        Jvec[col] to A[i, ii] for every pair of rows whose entry in that
        column is non-zero.

    Returns
    -------
    2D double array, shape (D, m).  For D == 0 the empty array is returned.

    Raises
    ------
    RuntimeError
        If dpotrf reports a non-zero info for A.
    """
    cdef int D = X.shape[0]
    cdef int m = X.shape[1]
    cdef int k = Jvec.shape[0]
    cdef char uplo = b'L'[0]
    cdef int info
    cdef double v
    cdef Py_ssize_t i, p, col
    cdef cnp.ndarray[cnp.double_t, ndim=2] out = np.zeros_like(X)
    cdef cnp.ndarray[cnp.double_t, ndim=2] A

    # Nothing to factor; also avoids calling LAPACK with lda = 0.
    if D == 0:
        return out

    # Build A = diag(Nvec) + sum_col Jvec[col] * 1_rows 1_rowsᵀ, adding one
    # column's contribution at a time.
    A = np.diag(Nvec)
    for col in range(k):
        rows = np.flatnonzero(U[:, col])
        A[np.ix_(rows, rows)] += Jvec[col]

    # Factor A = L Lᵀ; L ends up in the lower triangle (column-major).
    A = np.asfortranarray(A)
    dpotrf(&uplo, &D, &A[0,0], &D, &info)
    if info != 0:
        raise RuntimeError("dpotrf failed in sqrtsolve")

    # Forward substitution L · out = X, one column of X at a time.  Every
    # entry of `out` is written here, so no pre-scaling of `out` is needed.
    for col in range(m):
        for i in range(D):
            v = X[i, col]
            for p in range(i):
                v -= A[i, p] * out[p, col]
            out[i, col] = v / A[i, i]

    return out
Loading
Loading