Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/ptychi/api/options/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,11 @@ class ProbeCenterConstraintOptions(FeatureOptions):
to keep it centered, or to use the total probe intensity.
"""

use_com_for_each_mode_independently: bool = False
"""
Determine and use the center of mass of the probe magnitude to keep each
shared mode centered independently.
Comment on lines +639 to +640
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use_com_for_each_mode_independently interacts with the existing use_intensity_for_com option, but the docstring doesn’t clarify precedence (or whether they’re meant to be mutually exclusive). Consider updating the field docstring to state how it combines with use_intensity_for_com (e.g., “if True, this takes precedence and COM is computed per incoherent mode using |probe|”).

Suggested change
Determine and use the center of mass of the probe magnitude to keep each
shared mode centered independently.
Determine and use the center of mass (COM) of the probe magnitude to keep each
shared mode centered independently.
When set to ``True``, the COM is computed separately for each incoherent/shared
probe mode using ``|probe|`` and this per-mode COM centering takes precedence
over :attr:`use_intensity_for_com`. In this case,
:attr:`use_intensity_for_com` is ignored for the COM computation.
When set to ``False``, the COM is computed according to
:attr:`use_intensity_for_com` (e.g., using either the dominant shared probe
mode or the total probe intensity) rather than per-mode centering.

Copilot uses AI. Check for mistakes.
"""

@dataclasses.dataclass
class ProbeOptions(ParameterOptions):
Expand Down
17 changes: 9 additions & 8 deletions src/ptychi/data_structures/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,12 @@ def shift(self, shifts: Tensor):
)
shifted_probe = shifted_probe.view(*self.shape)
else:
n_shifts = shifts.shape[0]
n_images_each_probe = self.shape[0] * self.shape[1]
probe_straightened = self.tensor.complex().view(n_images_each_probe, *self.shape[-2:])
probe_straightened = probe_straightened.repeat(n_shifts, 1, 1)
shifts = shifts.repeat_interleave(n_images_each_probe, dim=0)
shifted_probe = ip.fourier_shift(probe_straightened, shifts)
shifted_probe = shifted_probe.reshape(n_shifts, *self.shape)
shifted_probe = torch.zeros_like(self.data)
shifted_probe[0,...] = ip.shift_images(self.data[0,...], shifts, method='fourier')
shifted_probe[1:,0,...] = ip.shift_images(self.data[1:,0,...],
shifts[0,:].repeat(self.data[1:,0,...].shape[0],1),
method='fourier'
)
Comment on lines +87 to +92
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In shift() when shifts is 2D, shifted_probe is initialized with zeros and only shifted_probe[0, ...] plus shifted_probe[1:, 0, ...] are populated. This silently zeroes self.data[1:, 1:, ...] and also doesn’t apply the per-mode shift to all OPR modes of each incoherent mode (which is what a (n_modes, 2) shift tensor implies). Consider building shifted_probe from a clone of self.data and shifting all (opr_mode, mode) images with the appropriate shift (e.g., flatten (n_opr_modes*n_modes, H, W) and expand shifts to match), so no parts of the probe are dropped.

Suggested change
shifted_probe = torch.zeros_like(self.data)
shifted_probe[0,...] = ip.shift_images(self.data[0,...], shifts, method='fourier')
shifted_probe[1:,0,...] = ip.shift_images(self.data[1:,0,...],
shifts[0,:].repeat(self.data[1:,0,...].shape[0],1),
method='fourier'
)
# Expect one shift vector per incoherent mode: shifts.shape == (n_modes, 2)
if shifts.shape[-1] != 2:
raise ValueError(
f"Expected shifts to have shape (n_modes, 2), got {tuple(shifts.shape)}."
)
if shifts.shape[0] != self.n_modes:
raise ValueError(
f"Number of shifts ({shifts.shape[0]}) must match n_modes ({self.n_modes})."
)
# Flatten (n_opr_modes, n_modes, H, W) -> (n_opr_modes * n_modes, H, W)
n_opr_modes, n_modes, h, w = self.shape
probe_straightened = self.tensor.complex().view(-1, h, w)
# Expand per-mode shifts to all OPR modes:
# (n_modes, 2) -> (n_opr_modes, n_modes, 2) -> (n_opr_modes * n_modes, 2)
shifts_expanded = (
shifts.unsqueeze(0) # (1, n_modes, 2)
.expand(n_opr_modes, -1, -1)
.reshape(-1, 2)
)
shifted_flat = ip.fourier_shift(probe_straightened, shifts_expanded)
shifted_probe = shifted_flat.view(n_opr_modes, n_modes, h, w)

Copilot uses AI. Check for mistakes.
Comment on lines +88 to +92
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ip.shift_images(self.data[1:, 0, ...], ...) will be called even when there is only 1 OPR mode (i.e., self.data[1:, 0, ...] is an empty batch). Depending on how shift_images handles empty batches (especially the pad branch), this can raise unexpectedly. Consider guarding this assignment with if self.shape[0] > 1: (or similar) to avoid calling shift_images on an empty tensor.

Suggested change
shifted_probe[0,...] = ip.shift_images(self.data[0,...], shifts, method='fourier')
shifted_probe[1:,0,...] = ip.shift_images(self.data[1:,0,...],
shifts[0,:].repeat(self.data[1:,0,...].shape[0],1),
method='fourier'
)
shifted_probe[0, ...] = ip.shift_images(self.data[0, ...], shifts, method='fourier')
if self.n_opr_modes > 1:
shifted_probe[1:, 0, ...] = ip.shift_images(
self.data[1:, 0, ...],
shifts[0, :].repeat(self.data[1:, 0, ...].shape[0], 1),
method='fourier',
)

Copilot uses AI. Check for mistakes.
return shifted_probe

@property
Expand Down Expand Up @@ -417,7 +416,9 @@ def center_probe(self):
Move the probe's center of mass to the center of the probe array.
"""

if self.options.center_constraint.use_intensity_for_com:
if self.options.center_constraint.use_com_for_each_mode_independently:
probe_to_be_shifted = self.data[0,...]
elif self.options.center_constraint.use_intensity_for_com:
probe_to_be_shifted = torch.sum(torch.abs(self.data[0, ...]) ** 2, dim=0)
else:
probe_to_be_shifted = self.get_mode_and_opr_mode(0, 0)
Comment on lines +419 to 424
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With use_com_for_each_mode_independently=True, probe_to_be_shifted becomes a (n_modes, H, W) stack and find_center_of_mass returns (n_modes, 2) shifts. Given that, the centering behavior depends on Probe.shift() correctly interpreting a (n_modes, 2) shift tensor as “shift each incoherent mode by its own shift”. It would be safer to either (a) document/validate that shifts.shape[0] == self.shape[1] in this branch, or (b) make center_probe() apply the per-mode shifting directly so it can’t be misinterpreted as a batch-of-probes shift.

Copilot uses AI. Check for mistakes.
Expand Down
Loading