diff --git a/heracles/twopoint.py b/heracles/twopoint.py index d53c792..5f7bd97 100644 --- a/heracles/twopoint.py +++ b/heracles/twopoint.py @@ -22,6 +22,7 @@ import logging import time +from collections.abc import Mapping from datetime import timedelta from itertools import combinations_with_replacement, product from typing import TYPE_CHECKING, Any @@ -39,7 +40,7 @@ from dataclasses import replace if TYPE_CHECKING: - from collections.abc import Mapping, MutableMapping + from collections.abc import MutableMapping from numpy.typing import ArrayLike, NDArray @@ -404,7 +405,7 @@ def mixing_matrices( def invert_mixing_matrix( M, - rtol: float = 1e-5, + rcond: float = 1e-5, progress: Progress | None = None, ): """ @@ -412,7 +413,7 @@ def invert_mixing_matrix( Args: M: Mixing matrix (mapping of keys -> Result objects) - rtol: relative tolerance for pseudo-inverse + rcond: relative tolerance for pseudo-inverse progress: optional progress reporter Returns: @@ -432,14 +433,21 @@ def invert_mixing_matrix( s1, s2 = value.spin *_, _n, _m = _M.shape + if isinstance(rcond, Mapping): + if key not in rcond: + raise KeyError(f"Missing rcond value for wm key: {key}") + _rcond = rcond[key] + else: + _rcond = rcond + with progress.task(f"invert {key}"): if (s1 != 0) and (s2 != 0): # Cl^EE+Cl^BB and Cl^EE-Cl^BB transformation # makes the mixing matrix block-diagonal M_p = _M[0] + _M[1] M_m = _M[0] - _M[1] - inv_M_p = np.linalg.pinv(M_p, rcond=rtol) - inv_M_m = np.linalg.pinv(M_m, rcond=rtol) + inv_M_p = np.linalg.pinv(M_p, rcond=_rcond) + inv_M_m = np.linalg.pinv(M_m, rcond=_rcond) _inv_m = np.vstack( ( np.hstack(((inv_M_p + inv_M_m) / 2, (inv_M_p - inv_M_m) / 2)), @@ -448,10 +456,10 @@ def invert_mixing_matrix( ) _inv_M_EEEE = _inv_m[:_m, :_n] _inv_M_EEBB = _inv_m[_m:, :_n] - _inv_M_EBEB = np.linalg.pinv(_M[2], rcond=rtol) + _inv_M_EBEB = np.linalg.pinv(_M[2], rcond=_rcond) _inv_M = np.array([_inv_M_EEEE, _inv_M_EEBB, _inv_M_EBEB]) else: - _inv_M = np.linalg.pinv(_M, rcond=rtol) + _inv_M = np.linalg.pinv(_M, rcond=_rcond) inv_M[key] = replace(M[key], array=_inv_M) return inv_M diff --git a/tests/test_twopoint.py b/tests/test_twopoint.py index 4863a7d..802771a 100644 --- a/tests/test_twopoint.py +++ b/tests/test_twopoint.py @@ -368,7 +368,7 @@ def test_inverting_mixing_matrices(): ("WHT", "WHT", 0, 0): Result(cl, spin=(0, 0), axis=(0,)), } mms = mixing_matrices(fields, cls2, l1max=10, l2max=20) - inv_mms = invert_mixing_matrix(mms) + inv_mms = invert_mixing_matrix(mms, rcond=1e-2) # test for correct shape for key in mms.keys(): @@ -387,7 +387,14 @@ def test_inverting_mixing_matrices(): _m = np.ones_like(mms[key].array) mms[key] = Result(_m, spin=mms[key].spin, axis=mms[key].axis, ell=mms[key].ell) - inv_mms = invert_mixing_matrix(mms) + inv_mms = invert_mixing_matrix( + mms, + rcond={ + ("POS", "POS", 0, 0): 1e-2, + ("POS", "SHE", 0, 0): 1e-3, + ("SHE", "SHE", 0, 0): 1e-4, + }, + ) assert inv_mms.keys() == mms.keys() for key in mms: inv_mm = inv_mms[key].array