diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py index 4c7ed14..181b3c0 100644 --- a/src/ptychi/api/options/base.py +++ b/src/ptychi/api/options/base.py @@ -634,6 +634,11 @@ class ProbeCenterConstraintOptions(FeatureOptions): to keep it centered, or to use the total probe intensity. """ + use_com_for_each_mode_independently: bool = False + """ + Determine and use the center of mass of the probe magnitude to keep each + shared mode centered independently. + """ @dataclasses.dataclass class ProbeOptions(ParameterOptions): diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index 9016cec..08b90d2 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -84,13 +84,12 @@ def shift(self, shifts: Tensor): ) shifted_probe = shifted_probe.view(*self.shape) else: - n_shifts = shifts.shape[0] - n_images_each_probe = self.shape[0] * self.shape[1] - probe_straightened = self.tensor.complex().view(n_images_each_probe, *self.shape[-2:]) - probe_straightened = probe_straightened.repeat(n_shifts, 1, 1) - shifts = shifts.repeat_interleave(n_images_each_probe, dim=0) - shifted_probe = ip.fourier_shift(probe_straightened, shifts) - shifted_probe = shifted_probe.reshape(n_shifts, *self.shape) + shifted_probe = torch.zeros_like(self.data) + shifted_probe[0,...] = ip.shift_images(self.data[0,...], shifts, method='fourier') + shifted_probe[1:,0,...] = ip.shift_images(self.data[1:,0,...], + shifts[0,:].repeat(self.data[1:,0,...].shape[0],1), + method='fourier' + ) return shifted_probe @property @@ -417,7 +416,9 @@ def center_probe(self): Move the probe's center of mass to the center of the probe array. """ - if self.options.center_constraint.use_intensity_for_com: + if self.options.center_constraint.use_com_for_each_mode_independently: + probe_to_be_shifted = self.data[0,...] + elif self.options.center_constraint.use_intensity_for_com: probe_to_be_shifted = torch.sum(torch.abs(self.data[0, ...]) ** 2, dim=0) else: probe_to_be_shifted = self.get_mode_and_opr_mode(0, 0)