Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions heracles/twopoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import logging
import time
from collections.abc import Mapping
from datetime import timedelta
from itertools import combinations_with_replacement, product
from typing import TYPE_CHECKING, Any
Expand All @@ -39,7 +40,7 @@
from dataclasses import replace

if TYPE_CHECKING:
from collections.abc import Mapping, MutableMapping
from collections.abc import MutableMapping

from numpy.typing import ArrayLike, NDArray

Expand Down Expand Up @@ -404,15 +405,15 @@ def mixing_matrices(

def invert_mixing_matrix(
M,
rtol: float = 1e-5,
rcond: float = 1e-5,
progress: Progress | None = None,
):
"""
Inversion model for the unmixing E/B modes.

Args:
M: Mixing matrix (mapping of keys -> Result objects)
rtol: relative tolerance for pseudo-inverse
rcond: relative tolerance for pseudo-inverse
progress: optional progress reporter

Returns:
Expand All @@ -432,14 +433,21 @@ def invert_mixing_matrix(
s1, s2 = value.spin
*_, _n, _m = _M.shape

if isinstance(rcond, Mapping):
if key not in rcond:
raise KeyError(f"Missing rcond value for wm key: {key}")
_rcond = rcond[key]
else:
_rcond = rcond

with progress.task(f"invert {key}"):
if (s1 != 0) and (s2 != 0):
# Cl^EE+Cl^BB and Cl^EE-Cl^BB transformation
# makes the mixing matrix block-diagonal
M_p = _M[0] + _M[1]
M_m = _M[0] - _M[1]
inv_M_p = np.linalg.pinv(M_p, rcond=rtol)
inv_M_m = np.linalg.pinv(M_m, rcond=rtol)
inv_M_p = np.linalg.pinv(M_p, rcond=_rcond)
inv_M_m = np.linalg.pinv(M_m, rcond=_rcond)
_inv_m = np.vstack(
(
np.hstack(((inv_M_p + inv_M_m) / 2, (inv_M_p - inv_M_m) / 2)),
Expand All @@ -448,16 +456,16 @@ def invert_mixing_matrix(
)
_inv_M_EEEE = _inv_m[:_m, :_n]
_inv_M_EEBB = _inv_m[_m:, :_n]
_inv_M_EBEB = np.linalg.pinv(_M[2], rcond=rtol)
_inv_M_EBEB = np.linalg.pinv(_M[2], rcond=_rcond)
_inv_M = np.array([_inv_M_EEEE, _inv_M_EEBB, _inv_M_EBEB])
else:
_inv_M = np.linalg.pinv(_M, rcond=rtol)
_inv_M = np.linalg.pinv(_M, rcond=_rcond)

inv_M[key] = replace(M[key], array=_inv_M)
return inv_M


def apply_mixing_matrix(d, M, lmax=None):
def apply_mixing_matrix(d, M):
"""
Apply mixing matrix to the data Cl.
Args:
Expand All @@ -471,7 +479,6 @@ def apply_mixing_matrix(d, M, lmax=None):
if lmax is None:
*_, lmax = d[key].shape
dtype = d[key].array.dtype
ell_mask = M[key].ell
s1, s2 = d[key].spin
_d = np.atleast_2d(d[key].array)
_M = M[key].array
Expand All @@ -487,7 +494,5 @@ def apply_mixing_matrix(d, M, lmax=None):
_corr_d.append(_M @ cl)
_corr_d = np.squeeze(_corr_d)
_corr_d = np.array(list(_corr_d), dtype=dtype)
corr_d[key] = replace(d[key], array=_corr_d, ell=ell_mask)
# truncate
corr_d = binned(corr_d, np.arange(0, lmax + 1))
corr_d[key] = replace(d[key], array=_corr_d)
return corr_d
11 changes: 9 additions & 2 deletions tests/test_twopoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def test_inverting_mixing_matrices():
("WHT", "WHT", 0, 0): Result(cl, spin=(0, 0), axis=(0,)),
}
mms = mixing_matrices(fields, cls2, l1max=10, l2max=20)
inv_mms = invert_mixing_matrix(mms)
inv_mms = invert_mixing_matrix(mms, rcond=1e-2)

# test for correct shape
for key in mms.keys():
Expand All @@ -387,7 +387,14 @@ def test_inverting_mixing_matrices():
_m = np.ones_like(mms[key].array)
mms[key] = Result(_m, spin=mms[key].spin, axis=mms[key].axis, ell=mms[key].ell)

inv_mms = invert_mixing_matrix(mms)
inv_mms = invert_mixing_matrix(
mms,
rcond={
("POS", "POS", 0, 0): 1e-2,
("POS", "SHE", 0, 0): 1e-3,
("SHE", "SHE", 0, 0): 1e-4,
},
)
assert inv_mms.keys() == mms.keys()
for key in mms:
inv_mm = inv_mms[key].array
Expand Down
Loading