Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions heracles/twopoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import logging
import time
from collections.abc import Mapping
from datetime import timedelta
from itertools import combinations_with_replacement, product
from typing import TYPE_CHECKING, Any
Expand All @@ -39,7 +40,7 @@
from dataclasses import replace

if TYPE_CHECKING:
from collections.abc import Mapping, MutableMapping
from collections.abc import MutableMapping

from numpy.typing import ArrayLike, NDArray

Expand Down Expand Up @@ -404,15 +405,15 @@ def mixing_matrices(

def invert_mixing_matrix(
M,
rtol: float = 1e-5,
rcond: float = 1e-5,
progress: Progress | None = None,
):
"""
Inversion model for the unmixing E/B modes.

Args:
M: Mixing matrix (mapping of keys -> Result objects)
rtol: relative tolerance for pseudo-inverse
rcond: relative tolerance for pseudo-inverse
progress: optional progress reporter

Returns:
Expand All @@ -432,14 +433,21 @@ def invert_mixing_matrix(
s1, s2 = value.spin
*_, _n, _m = _M.shape

if isinstance(rcond, Mapping):
if key not in rcond:
raise KeyError(f"Missing rcond value for wm key: {key}")
_rcond = rcond[key]
else:
_rcond = rcond

with progress.task(f"invert {key}"):
if (s1 != 0) and (s2 != 0):
# Cl^EE+Cl^BB and Cl^EE-Cl^BB transformation
# makes the mixing matrix block-diagonal
M_p = _M[0] + _M[1]
M_m = _M[0] - _M[1]
inv_M_p = np.linalg.pinv(M_p, rcond=rtol)
inv_M_m = np.linalg.pinv(M_m, rcond=rtol)
inv_M_p = np.linalg.pinv(M_p, rcond=_rcond)
inv_M_m = np.linalg.pinv(M_m, rcond=_rcond)
_inv_m = np.vstack(
(
np.hstack(((inv_M_p + inv_M_m) / 2, (inv_M_p - inv_M_m) / 2)),
Expand All @@ -448,10 +456,10 @@ def invert_mixing_matrix(
)
_inv_M_EEEE = _inv_m[:_m, :_n]
_inv_M_EEBB = _inv_m[_m:, :_n]
_inv_M_EBEB = np.linalg.pinv(_M[2], rcond=rtol)
_inv_M_EBEB = np.linalg.pinv(_M[2], rcond=_rcond)
_inv_M = np.array([_inv_M_EEEE, _inv_M_EEBB, _inv_M_EBEB])
else:
_inv_M = np.linalg.pinv(_M, rcond=rtol)
_inv_M = np.linalg.pinv(_M, rcond=_rcond)

inv_M[key] = replace(M[key], array=_inv_M)
return inv_M
Expand Down
11 changes: 9 additions & 2 deletions tests/test_twopoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def test_inverting_mixing_matrices():
("WHT", "WHT", 0, 0): Result(cl, spin=(0, 0), axis=(0,)),
}
mms = mixing_matrices(fields, cls2, l1max=10, l2max=20)
inv_mms = invert_mixing_matrix(mms)
inv_mms = invert_mixing_matrix(mms, rcond=1e-2)

# test for correct shape
for key in mms.keys():
Expand All @@ -387,7 +387,14 @@ def test_inverting_mixing_matrices():
_m = np.ones_like(mms[key].array)
mms[key] = Result(_m, spin=mms[key].spin, axis=mms[key].axis, ell=mms[key].ell)

inv_mms = invert_mixing_matrix(mms)
inv_mms = invert_mixing_matrix(
mms,
rcond={
("POS", "POS", 0, 0): 1e-2,
("POS", "SHE", 0, 0): 1e-3,
("SHE", "SHE", 0, 0): 1e-4,
},
)
assert inv_mms.keys() == mms.keys()
for key in mms:
inv_mm = inv_mms[key].array
Expand Down