From 764bc955f2dbf70c457e53e14def3cf495b82d5c Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 13 Mar 2026 11:04:02 -0500 Subject: [PATCH 1/2] FEAT: allow shifting each mode individually in probe centering --- src/ptychi/api/options/base.py | 14 +++++ src/ptychi/data_structures/probe.py | 25 ++++++-- .../test_2d_ptycho_probe_center_constraint.py | 63 +++++++++++++++++++ 3 files changed, 98 insertions(+), 4 deletions(-) create mode 100644 tests/test_2d_ptycho_probe_center_constraint.py diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py index 4c7ed14..0a49638 100644 --- a/src/ptychi/api/options/base.py +++ b/src/ptychi/api/options/base.py @@ -634,6 +634,19 @@ class ProbeCenterConstraintOptions(FeatureOptions): to keep it centered, or to use the total probe intensity. """ + center_modes_individually: bool = False + """ + If True, each mode is shifted individually based on their own center of mass. + """ + + def check(self, options: "task_options.PtychographyTaskOptions"): + super().check(options) + if self.center_modes_individually and self.use_intensity_for_com: + raise ValueError( + "`probe_options.center_constraint.use_intensity_for_com` must be False when " + "`probe_options.center_constraint.center_modes_individually` is True." + ) + @dataclasses.dataclass class ProbeOptions(ParameterOptions): @@ -680,6 +693,7 @@ def check(self, options: "task_options.PtychographyTaskOptions"): super().check(options) if not (self.initial_guess is not None and self.initial_guess.ndim == 4): raise ValueError("Probe initial_guess must be a (n_opr_modes, n_modes, h, w) tensor.") + self.center_constraint.check(options) if self.power_constraint.enabled and options.object_options.remove_object_probe_ambiguity.enabled: logger.warning( "`ObjectOptions.remove_object_probe_ambiguity` and `ProbeOptions.power_constraint` " diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index 9016cec..48c67d5 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -416,14 +416,31 @@ def center_probe(self): """ Move the probe's center of mass to the center of the probe array. """ - - if self.options.center_constraint.use_intensity_for_com: + center_constraint = self.options.center_constraint + center = utils.to_tensor( + self.shape[-2:], + device=self.data.device, + dtype=torch.get_default_dtype(), + ) // 2 + + if center_constraint.center_modes_individually: + probe = self.data.clone() + probe_to_be_shifted = probe[0] + if center_constraint.use_intensity_for_com: + probe_to_be_shifted = torch.abs(probe_to_be_shifted) ** 2 + + shifts = center.to(probe_to_be_shifted.device) - ip.find_center_of_mass(probe_to_be_shifted) + probe[0] = ip.fourier_shift(probe[0], shifts) + self.set_data(probe) + return + + if center_constraint.use_intensity_for_com: probe_to_be_shifted = torch.sum(torch.abs(self.data[0, ...]) ** 2, dim=0) else: probe_to_be_shifted = self.get_mode_and_opr_mode(0, 0) - + com = ip.find_center_of_mass(probe_to_be_shifted) - shift = utils.to_tensor(self.shape[-2:]) // 2 - com + shift = center.to(com.device) - com shifted_probe = self.shift(shift) self.set_data(shifted_probe) diff --git a/tests/test_2d_ptycho_probe_center_constraint.py b/tests/test_2d_ptycho_probe_center_constraint.py new file mode 100644 index 0000000..a7d7ca5 --- /dev/null +++ b/tests/test_2d_ptycho_probe_center_constraint.py @@ -0,0 +1,63 @@ +import pytest +import torch +from types import SimpleNamespace + +import ptychi.image_proc as ip +from ptychi.api.options import base as obase +from ptychi.data_structures.probe import Probe + + +def _make_probe( + data: torch.Tensor, + *, + center_modes_individually: bool, + use_intensity_for_com: bool, +) -> Probe: + options = obase.ProbeOptions() + options.optimizable = False + options.center_constraint.enabled = True + options.center_constraint.center_modes_individually = center_modes_individually + options.center_constraint.use_intensity_for_com = use_intensity_for_com + return Probe(data=data, options=options) + + +def test_center_probe_can_shift_incoherent_modes_individually(): + data = torch.zeros((2, 2, 7, 7), dtype=torch.complex64) + data[0, 0, 1, 2] = 1 + 0j + data[0, 1, 4, 1] = 1 + 0j + + data[1, 0, 0, 0] = 2 + 0j + data[1, 0, 2, 5] = 3 + 0j + data[1, 1, 5, 6] = 4 + 0j + data[1, 1, 6, 1] = 5 + 0j + + secondary_opr_modes_before = data[1:].clone() + probe = _make_probe( + data, + center_modes_individually=True, + use_intensity_for_com=False, + ) + + probe.center_probe() + + expected_center = torch.tensor([[3.0, 3.0], [3.0, 3.0]]) + centered_mode_com = ip.find_center_of_mass(torch.abs(probe.data[0]) ** 2) + + assert torch.allclose(centered_mode_com, expected_center, atol=1e-4) + assert torch.allclose(probe.data[1:], secondary_opr_modes_before) + + +def test_probe_center_constraint_check_rejects_individual_mode_centering_with_intensity_com(): + options = obase.ProbeOptions() + options.initial_guess = torch.zeros((1, 1, 7, 7), dtype=torch.complex64) + options.center_constraint.center_modes_individually = True + options.center_constraint.use_intensity_for_com = True + + task_options = SimpleNamespace( + object_options=SimpleNamespace( + remove_object_probe_ambiguity=SimpleNamespace(enabled=False) + ) + ) + + with pytest.raises(ValueError, match="use_intensity_for_com"): + options.check(task_options) From f27acf594e5ad3c8c6251707d750a97a118261a3 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 13 Mar 2026 11:09:08 -0500 Subject: [PATCH 2/2] REFACTOR: rename `use_intensity_for_com` to `use_total_intensity_for_com`, add deprecation warning --- src/ptychi/api/options/base.py | 20 ++++++++++++--- src/ptychi/data_structures/probe.py | 7 ++++-- .../test_2d_ptycho_probe_center_constraint.py | 25 ++++++++++++++++--- 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py index 0a49638..ee5032b 100644 --- a/src/ptychi/api/options/base.py +++ b/src/ptychi/api/options/base.py @@ -7,6 +7,7 @@ import logging from math import ceil import enum +import warnings from numpy import ndarray from torch import Tensor @@ -627,13 +628,18 @@ class ProbeCenterConstraintOptions(FeatureOptions): optimization_plan: OptimizationPlan = dataclasses.field(default_factory=OptimizationPlan) - use_intensity_for_com: bool = False + use_total_intensity_for_com: bool = False """ Whether to use the magnitude of the dominant shared probe mode for computing the center of mass of the probe in order to keep it centered, or to use the total probe intensity. """ + use_intensity_for_com: bool = False + """ + Deprecated alias for `use_total_intensity_for_com`. + """ + center_modes_individually: bool = False """ If True, each mode is shifted individually based on their own center of mass. @@ -641,9 +647,17 @@ class ProbeCenterConstraintOptions(FeatureOptions): def check(self, options: "task_options.PtychographyTaskOptions"): super().check(options) - if self.center_modes_individually and self.use_intensity_for_com: + if self.use_intensity_for_com: + warnings.warn( + "`probe_options.center_constraint.use_intensity_for_com` is deprecated; " + "use `probe_options.center_constraint.use_total_intensity_for_com` instead.", + DeprecationWarning, + stacklevel=2, + ) + self.use_total_intensity_for_com = True + if self.center_modes_individually and self.use_total_intensity_for_com: raise ValueError( - "`probe_options.center_constraint.use_intensity_for_com` must be False when " + "`probe_options.center_constraint.use_total_intensity_for_com` must be False when " "`probe_options.center_constraint.center_modes_individually` is True." ) diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index 48c67d5..d839836 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -417,6 +417,9 @@ def center_probe(self): Move the probe's center of mass to the center of the probe array. """ center_constraint = self.options.center_constraint + use_total_intensity_for_com = ( + center_constraint.use_total_intensity_for_com or center_constraint.use_intensity_for_com + ) center = utils.to_tensor( self.shape[-2:], device=self.data.device, @@ -426,7 +429,7 @@ def center_probe(self): if center_constraint.center_modes_individually: probe = self.data.clone() probe_to_be_shifted = probe[0] - if center_constraint.use_intensity_for_com: + if use_total_intensity_for_com: probe_to_be_shifted = torch.abs(probe_to_be_shifted) ** 2 shifts = center.to(probe_to_be_shifted.device) - ip.find_center_of_mass(probe_to_be_shifted) @@ -434,7 +437,7 @@ def center_probe(self): self.set_data(probe) return - if center_constraint.use_intensity_for_com: + if use_total_intensity_for_com: probe_to_be_shifted = torch.sum(torch.abs(self.data[0, ...]) ** 2, dim=0) else: probe_to_be_shifted = self.get_mode_and_opr_mode(0, 0) diff --git a/tests/test_2d_ptycho_probe_center_constraint.py b/tests/test_2d_ptycho_probe_center_constraint.py index a7d7ca5..6d955af 100644 --- a/tests/test_2d_ptycho_probe_center_constraint.py +++ b/tests/test_2d_ptycho_probe_center_constraint.py @@ -11,13 +11,13 @@ def _make_probe( data: torch.Tensor, *, center_modes_individually: bool, - use_intensity_for_com: bool, + use_total_intensity_for_com: bool, ) -> Probe: options = obase.ProbeOptions() options.optimizable = False options.center_constraint.enabled = True options.center_constraint.center_modes_individually = center_modes_individually - options.center_constraint.use_intensity_for_com = use_intensity_for_com + options.center_constraint.use_total_intensity_for_com = use_total_intensity_for_com return Probe(data=data, options=options) @@ -35,7 +35,7 @@ def test_center_probe_can_shift_incoherent_modes_individually(): probe = _make_probe( data, center_modes_individually=True, - use_intensity_for_com=False, + use_total_intensity_for_com=False, ) probe.center_probe() @@ -51,6 +51,21 @@ def test_probe_center_constraint_check_rejects_individual_mode_centering_with_in options = obase.ProbeOptions() options.initial_guess = torch.zeros((1, 1, 7, 7), dtype=torch.complex64) options.center_constraint.center_modes_individually = True + options.center_constraint.use_total_intensity_for_com = True + + task_options = SimpleNamespace( + object_options=SimpleNamespace( + remove_object_probe_ambiguity=SimpleNamespace(enabled=False) + ) + ) + + with pytest.raises(ValueError, match="use_total_intensity_for_com"): + options.check(task_options) + + +def test_probe_center_constraint_check_promotes_deprecated_intensity_flag(): + options = obase.ProbeOptions() + options.initial_guess = torch.zeros((1, 1, 7, 7), dtype=torch.complex64) options.center_constraint.use_intensity_for_com = True task_options = SimpleNamespace( @@ -59,5 +74,7 @@ def test_probe_center_constraint_check_rejects_individual_mode_centering_with_in ) ) - with pytest.raises(ValueError, match="use_intensity_for_com"): + with pytest.warns(DeprecationWarning, match="use_total_intensity_for_com"): options.check(task_options) + + assert options.center_constraint.use_total_intensity_for_com is True