Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/ptychi/api/options/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
from math import ceil
import enum
import warnings

from numpy import ndarray
from torch import Tensor
Expand Down Expand Up @@ -627,13 +628,39 @@ class ProbeCenterConstraintOptions(FeatureOptions):

optimization_plan: OptimizationPlan = dataclasses.field(default_factory=OptimizationPlan)

use_intensity_for_com: bool = False
use_total_intensity_for_com: bool = False
"""
Whether to use the magnitude of the dominant shared probe
mode for computing the center of mass of the probe in order
to keep it centered, or to use the total probe intensity.
"""

use_intensity_for_com: bool = False
"""
Deprecated alias for `use_total_intensity_for_com`.
"""

center_modes_individually: bool = False
"""
If True, each mode is shifted individually based on their own center of mass.
"""

def check(self, options: "task_options.PtychographyTaskOptions"):
super().check(options)
if self.use_intensity_for_com:
warnings.warn(
"`probe_options.center_constraint.use_intensity_for_com` is deprecated; "
"use `probe_options.center_constraint.use_total_intensity_for_com` instead.",
DeprecationWarning,
stacklevel=2,
)
self.use_total_intensity_for_com = True
if self.center_modes_individually and self.use_total_intensity_for_com:
raise ValueError(
"`probe_options.center_constraint.use_total_intensity_for_com` must be False when "
"`probe_options.center_constraint.center_modes_individually` is True."
)


@dataclasses.dataclass
class ProbeOptions(ParameterOptions):
Expand Down Expand Up @@ -680,6 +707,7 @@ def check(self, options: "task_options.PtychographyTaskOptions"):
super().check(options)
if not (self.initial_guess is not None and self.initial_guess.ndim == 4):
raise ValueError("Probe initial_guess must be a (n_opr_modes, n_modes, h, w) tensor.")
self.center_constraint.check(options)
if self.power_constraint.enabled and options.object_options.remove_object_probe_ambiguity.enabled:
logger.warning(
"`ObjectOptions.remove_object_probe_ambiguity` and `ProbeOptions.power_constraint` "
Expand Down
28 changes: 24 additions & 4 deletions src/ptychi/data_structures/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,14 +416,34 @@ def center_probe(self):
"""
Move the probe's center of mass to the center of the probe array.
"""

if self.options.center_constraint.use_intensity_for_com:
center_constraint = self.options.center_constraint
use_total_intensity_for_com = (
center_constraint.use_total_intensity_for_com or center_constraint.use_intensity_for_com
)
center = utils.to_tensor(
self.shape[-2:],
device=self.data.device,
dtype=torch.get_default_dtype(),
) // 2

if center_constraint.center_modes_individually:
probe = self.data.clone()
probe_to_be_shifted = probe[0]
if use_total_intensity_for_com:
probe_to_be_shifted = torch.abs(probe_to_be_shifted) ** 2

shifts = center.to(probe_to_be_shifted.device) - ip.find_center_of_mass(probe_to_be_shifted)
probe[0] = ip.fourier_shift(probe[0], shifts)
self.set_data(probe)
return

if use_total_intensity_for_com:
probe_to_be_shifted = torch.sum(torch.abs(self.data[0, ...]) ** 2, dim=0)
else:
probe_to_be_shifted = self.get_mode_and_opr_mode(0, 0)

com = ip.find_center_of_mass(probe_to_be_shifted)
shift = utils.to_tensor(self.shape[-2:]) // 2 - com
shift = center.to(com.device) - com
shifted_probe = self.shift(shift)
self.set_data(shifted_probe)

Expand Down
80 changes: 80 additions & 0 deletions tests/test_2d_ptycho_probe_center_constraint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import pytest
import torch
from types import SimpleNamespace

import ptychi.image_proc as ip
from ptychi.api.options import base as obase
from ptychi.data_structures.probe import Probe


def _make_probe(
data: torch.Tensor,
*,
center_modes_individually: bool,
use_total_intensity_for_com: bool,
) -> Probe:
options = obase.ProbeOptions()
options.optimizable = False
options.center_constraint.enabled = True
options.center_constraint.center_modes_individually = center_modes_individually
options.center_constraint.use_total_intensity_for_com = use_total_intensity_for_com
return Probe(data=data, options=options)


def test_center_probe_can_shift_incoherent_modes_individually():
data = torch.zeros((2, 2, 7, 7), dtype=torch.complex64)
data[0, 0, 1, 2] = 1 + 0j
data[0, 1, 4, 1] = 1 + 0j

data[1, 0, 0, 0] = 2 + 0j
data[1, 0, 2, 5] = 3 + 0j
data[1, 1, 5, 6] = 4 + 0j
data[1, 1, 6, 1] = 5 + 0j

secondary_opr_modes_before = data[1:].clone()
probe = _make_probe(
data,
center_modes_individually=True,
use_total_intensity_for_com=False,
)

probe.center_probe()

expected_center = torch.tensor([[3.0, 3.0], [3.0, 3.0]])
centered_mode_com = ip.find_center_of_mass(torch.abs(probe.data[0]) ** 2)

assert torch.allclose(centered_mode_com, expected_center, atol=1e-4)
assert torch.allclose(probe.data[1:], secondary_opr_modes_before)


def test_probe_center_constraint_check_rejects_individual_mode_centering_with_intensity_com():
options = obase.ProbeOptions()
options.initial_guess = torch.zeros((1, 1, 7, 7), dtype=torch.complex64)
options.center_constraint.center_modes_individually = True
options.center_constraint.use_total_intensity_for_com = True

task_options = SimpleNamespace(
object_options=SimpleNamespace(
remove_object_probe_ambiguity=SimpleNamespace(enabled=False)
)
)

with pytest.raises(ValueError, match="use_total_intensity_for_com"):
options.check(task_options)


def test_probe_center_constraint_check_promotes_deprecated_intensity_flag():
options = obase.ProbeOptions()
options.initial_guess = torch.zeros((1, 1, 7, 7), dtype=torch.complex64)
options.center_constraint.use_intensity_for_com = True

task_options = SimpleNamespace(
object_options=SimpleNamespace(
remove_object_probe_ambiguity=SimpleNamespace(enabled=False)
)
)

with pytest.warns(DeprecationWarning, match="use_total_intensity_for_com"):
options.check(task_options)

assert options.center_constraint.use_total_intensity_for_com is True
Loading