diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py index 4c7ed14..ee5032b 100644 --- a/src/ptychi/api/options/base.py +++ b/src/ptychi/api/options/base.py @@ -7,6 +7,7 @@ import logging from math import ceil import enum +import warnings from numpy import ndarray from torch import Tensor @@ -627,13 +628,39 @@ class ProbeCenterConstraintOptions(FeatureOptions): optimization_plan: OptimizationPlan = dataclasses.field(default_factory=OptimizationPlan) - use_intensity_for_com: bool = False + use_total_intensity_for_com: bool = False """ Whether to use the magnitude of the dominant shared probe mode for computing the center of mass of the probe in order to keep it centered, or to use the total probe intensity. """ + use_intensity_for_com: bool = False + """ + Deprecated alias for `use_total_intensity_for_com`. + """ + + center_modes_individually: bool = False + """ + If True, each mode is shifted individually based on their own center of mass. + """ + + def check(self, options: "task_options.PtychographyTaskOptions"): + super().check(options) + if self.use_intensity_for_com: + warnings.warn( + "`probe_options.center_constraint.use_intensity_for_com` is deprecated; " + "use `probe_options.center_constraint.use_total_intensity_for_com` instead.", + DeprecationWarning, + stacklevel=2, + ) + self.use_total_intensity_for_com = True + if self.center_modes_individually and self.use_total_intensity_for_com: + raise ValueError( + "`probe_options.center_constraint.use_total_intensity_for_com` must be False when " + "`probe_options.center_constraint.center_modes_individually` is True." + ) + @dataclasses.dataclass class ProbeOptions(ParameterOptions): @@ -680,6 +707,7 @@ def check(self, options: "task_options.PtychographyTaskOptions"): super().check(options) if not (self.initial_guess is not None and self.initial_guess.ndim == 4): raise ValueError("Probe initial_guess must be a (n_opr_modes, n_modes, h, w) tensor.") + self.center_constraint.check(options) if self.power_constraint.enabled and options.object_options.remove_object_probe_ambiguity.enabled: logger.warning( "`ObjectOptions.remove_object_probe_ambiguity` and `ProbeOptions.power_constraint` " diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index 9016cec..d839836 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -416,14 +416,34 @@ def center_probe(self): """ Move the probe's center of mass to the center of the probe array. """ - - if self.options.center_constraint.use_intensity_for_com: + center_constraint = self.options.center_constraint + use_total_intensity_for_com = ( + center_constraint.use_total_intensity_for_com or center_constraint.use_intensity_for_com + ) + center = utils.to_tensor( + self.shape[-2:], + device=self.data.device, + dtype=torch.get_default_dtype(), + ) // 2 + + if center_constraint.center_modes_individually: + probe = self.data.clone() + probe_to_be_shifted = probe[0] + if use_total_intensity_for_com: + probe_to_be_shifted = torch.abs(probe_to_be_shifted) ** 2 + + shifts = center.to(probe_to_be_shifted.device) - ip.find_center_of_mass(probe_to_be_shifted) + probe[0] = ip.fourier_shift(probe[0], shifts) + self.set_data(probe) + return + + if use_total_intensity_for_com: probe_to_be_shifted = torch.sum(torch.abs(self.data[0, ...]) ** 2, dim=0) else: probe_to_be_shifted = self.get_mode_and_opr_mode(0, 0) - + com = ip.find_center_of_mass(probe_to_be_shifted) - shift = utils.to_tensor(self.shape[-2:]) // 2 - com + shift = center.to(com.device) - com shifted_probe = self.shift(shift) self.set_data(shifted_probe) diff --git a/tests/test_2d_ptycho_probe_center_constraint.py b/tests/test_2d_ptycho_probe_center_constraint.py new file mode 100644 index 0000000..6d955af --- /dev/null +++ b/tests/test_2d_ptycho_probe_center_constraint.py @@ -0,0 +1,80 @@ +import pytest +import torch +from types import SimpleNamespace + +import ptychi.image_proc as ip +from ptychi.api.options import base as obase +from ptychi.data_structures.probe import Probe + + +def _make_probe( + data: torch.Tensor, + *, + center_modes_individually: bool, + use_total_intensity_for_com: bool, +) -> Probe: + options = obase.ProbeOptions() + options.optimizable = False + options.center_constraint.enabled = True + options.center_constraint.center_modes_individually = center_modes_individually + options.center_constraint.use_total_intensity_for_com = use_total_intensity_for_com + return Probe(data=data, options=options) + + +def test_center_probe_can_shift_incoherent_modes_individually(): + data = torch.zeros((2, 2, 7, 7), dtype=torch.complex64) + data[0, 0, 1, 2] = 1 + 0j + data[0, 1, 4, 1] = 1 + 0j + + data[1, 0, 0, 0] = 2 + 0j + data[1, 0, 2, 5] = 3 + 0j + data[1, 1, 5, 6] = 4 + 0j + data[1, 1, 6, 1] = 5 + 0j + + secondary_opr_modes_before = data[1:].clone() + probe = _make_probe( + data, + center_modes_individually=True, + use_total_intensity_for_com=False, + ) + + probe.center_probe() + + expected_center = torch.tensor([[3.0, 3.0], [3.0, 3.0]]) + centered_mode_com = ip.find_center_of_mass(torch.abs(probe.data[0]) ** 2) + + assert torch.allclose(centered_mode_com, expected_center, atol=1e-4) + assert torch.allclose(probe.data[1:], secondary_opr_modes_before) + + +def test_probe_center_constraint_check_rejects_individual_mode_centering_with_intensity_com(): + options = obase.ProbeOptions() + options.initial_guess = torch.zeros((1, 1, 7, 7), dtype=torch.complex64) + options.center_constraint.center_modes_individually = True + options.center_constraint.use_total_intensity_for_com = True + + task_options = SimpleNamespace( + object_options=SimpleNamespace( + remove_object_probe_ambiguity=SimpleNamespace(enabled=False) + ) + ) + + with pytest.raises(ValueError, match="use_total_intensity_for_com"): + options.check(task_options) + + +def test_probe_center_constraint_check_promotes_deprecated_intensity_flag(): + options = obase.ProbeOptions() + options.initial_guess = torch.zeros((1, 1, 7, 7), dtype=torch.complex64) + options.center_constraint.use_intensity_for_com = True + + task_options = SimpleNamespace( + object_options=SimpleNamespace( + remove_object_probe_ambiguity=SimpleNamespace(enabled=False) + ) + ) + + with pytest.warns(DeprecationWarning, match="use_total_intensity_for_com"): + options.check(task_options) + + assert options.center_constraint.use_total_intensity_for_com is True