diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py index 4c7ed148..f7910fcc 100644 --- a/src/ptychi/api/options/base.py +++ b/src/ptychi/api/options/base.py @@ -572,7 +572,9 @@ class ProbeOrthogonalizeIncoherentModesOptions(FeatureOptions): method: enums.OrthogonalizationMethods = enums.OrthogonalizationMethods.SVD """The method to use for incoherent_mode orthogonalization.""" - + + sort_by_occupancy: bool = False + """If True, keep the probes sorted so that mode with highest occupancy is the 0th shared mode.""" @dataclasses.dataclass class ProbeOrthogonalizeOPRModesOptions(FeatureOptions): diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index 9016cecf..1ef8200e 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -221,6 +221,11 @@ def constrain_incoherent_modes_orthogonality(self): return probe = self.data + if self.options.orthogonalize_incoherent_modes.sort_by_occupancy: + shared_occupancy = torch.sum(torch.abs(probe[0, ...]) ** 2, (-2, -1)) + shared_occupancy = torch.sort(shared_occupancy, dim=0, descending=True) + sorted_idx = shared_occupancy[1] + probe[0] = torch.index_select(probe[0], 0, sorted_idx) norm_first_mode_orig = pmath.norm(probe[0, 0], dim=(-2, -1))