Skip to content

fix for if the probe centering constraint shift is not a 2 element vector#62

Closed
a4894z wants to merge 1 commit intomainfrom
probe_centering_each_mode_its_own_com
Closed

fix for if the probe centering constraint shift is not a 2 element vector#62
a4894z wants to merge 1 commit intomainfrom
probe_centering_each_mode_its_own_com

Conversation

@a4894z
Copy link
Copy Markdown
Collaborator

@a4894z a4894z commented Mar 11, 2026

In the probe class definition in probe.py, the centering part has a bug if we want to shift each probe independently according to its own center of mass.

for example, let the probe be of shape [3,5,512,512] (3 OPRs, 5 shared modes); if I use:

probe_to_be_shifted = self.data[0,...]

instead of

if self.options.center_constraint.use_intensity_for_com:
probe_to_be_shifted = torch.sum(torch.abs(self.data[0, ...]) ** 2, dim=0)
else:
probe_to_be_shifted = self.get_mode_and_opr_mode(0, 0)

then after executing this:

        n_shifts = shifts.shape[0]
        n_images_each_probe = self.shape[0] * self.shape[1]
        probe_straightened = self.tensor.complex().view(n_images_each_probe, *self.shape[-2:])
        probe_straightened = probe_straightened.repeat(n_shifts, 1, 1)
        shifts = shifts.repeat_interleave(n_images_each_probe, dim=0)
        shifted_probe = ip.fourier_shift(probe_straightened, shifts)
        shifted_probe = shifted_probe.reshape(n_shifts, *self.shape)

the returned probe shape will be

[5, 3, 5, 512, 512]

which will then cause an error when you call the setter.

I'm introducing a proposed fix here.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes probe centering when computing a separate center-of-mass shift per incoherent probe mode (e.g., probe.shape == (n_opr_modes, n_modes, H, W)), by adding an option to compute COM per mode and updating probe shifting logic to avoid shape mismatches when shifts are not a single (2,) vector.

Changes:

  • Add use_com_for_each_mode_independently to ProbeCenterConstraintOptions.
  • Update Probe.center_probe() to optionally compute COM from self.data[0, ...] (stack of modes) instead of a single mode / total intensity.
  • Rewrite Probe.shift() handling for 2D shifts inputs to use ip.shift_images on selected probe slices.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

File Description
src/ptychi/data_structures/probe.py Changes how batched/per-mode shifting is performed and adds a per-mode COM centering path.
src/ptychi/api/options/base.py Introduces a new center constraint option flag controlling per-mode COM computation.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +87 to +92
shifted_probe = torch.zeros_like(self.data)
shifted_probe[0,...] = ip.shift_images(self.data[0,...], shifts, method='fourier')
shifted_probe[1:,0,...] = ip.shift_images(self.data[1:,0,...],
shifts[0,:].repeat(self.data[1:,0,...].shape[0],1),
method='fourier'
)
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In shift() when shifts is 2D, shifted_probe is initialized with zeros and only shifted_probe[0, ...] plus shifted_probe[1:, 0, ...] are populated. This silently zeroes self.data[1:, 1:, ...] and also doesn’t apply the per-mode shift to all OPR modes of each incoherent mode (which is what a (n_modes, 2) shift tensor implies). Consider building shifted_probe from a clone of self.data and shifting all (opr_mode, mode) images with the appropriate shift (e.g., flatten (n_opr_modes*n_modes, H, W) and expand shifts to match), so no parts of the probe are dropped.

Suggested change
shifted_probe = torch.zeros_like(self.data)
shifted_probe[0,...] = ip.shift_images(self.data[0,...], shifts, method='fourier')
shifted_probe[1:,0,...] = ip.shift_images(self.data[1:,0,...],
shifts[0,:].repeat(self.data[1:,0,...].shape[0],1),
method='fourier'
)
# Expect one shift vector per incoherent mode: shifts.shape == (n_modes, 2)
if shifts.shape[-1] != 2:
raise ValueError(
f"Expected shifts to have shape (n_modes, 2), got {tuple(shifts.shape)}."
)
if shifts.shape[0] != self.n_modes:
raise ValueError(
f"Number of shifts ({shifts.shape[0]}) must match n_modes ({self.n_modes})."
)
# Flatten (n_opr_modes, n_modes, H, W) -> (n_opr_modes * n_modes, H, W)
n_opr_modes, n_modes, h, w = self.shape
probe_straightened = self.tensor.complex().view(-1, h, w)
# Expand per-mode shifts to all OPR modes:
# (n_modes, 2) -> (n_opr_modes, n_modes, 2) -> (n_opr_modes * n_modes, 2)
shifts_expanded = (
shifts.unsqueeze(0) # (1, n_modes, 2)
.expand(n_opr_modes, -1, -1)
.reshape(-1, 2)
)
shifted_flat = ip.fourier_shift(probe_straightened, shifts_expanded)
shifted_probe = shifted_flat.view(n_opr_modes, n_modes, h, w)

Copilot uses AI. Check for mistakes.
Comment on lines +88 to +92
shifted_probe[0,...] = ip.shift_images(self.data[0,...], shifts, method='fourier')
shifted_probe[1:,0,...] = ip.shift_images(self.data[1:,0,...],
shifts[0,:].repeat(self.data[1:,0,...].shape[0],1),
method='fourier'
)
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ip.shift_images(self.data[1:, 0, ...], ...) will be called even when there is only 1 OPR mode (i.e., self.data[1:, 0, ...] is an empty batch). Depending on how shift_images handles empty batches (especially the pad branch), this can raise unexpectedly. Consider guarding this assignment with if self.shape[0] > 1: (or similar) to avoid calling shift_images on an empty tensor.

Suggested change
shifted_probe[0,...] = ip.shift_images(self.data[0,...], shifts, method='fourier')
shifted_probe[1:,0,...] = ip.shift_images(self.data[1:,0,...],
shifts[0,:].repeat(self.data[1:,0,...].shape[0],1),
method='fourier'
)
shifted_probe[0, ...] = ip.shift_images(self.data[0, ...], shifts, method='fourier')
if self.n_opr_modes > 1:
shifted_probe[1:, 0, ...] = ip.shift_images(
self.data[1:, 0, ...],
shifts[0, :].repeat(self.data[1:, 0, ...].shape[0], 1),
method='fourier',
)

Copilot uses AI. Check for mistakes.
Comment on lines +419 to 424
if self.options.center_constraint.use_com_for_each_mode_independently:
probe_to_be_shifted = self.data[0,...]
elif self.options.center_constraint.use_intensity_for_com:
probe_to_be_shifted = torch.sum(torch.abs(self.data[0, ...]) ** 2, dim=0)
else:
probe_to_be_shifted = self.get_mode_and_opr_mode(0, 0)
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With use_com_for_each_mode_independently=True, probe_to_be_shifted becomes a (n_modes, H, W) stack and find_center_of_mass returns (n_modes, 2) shifts. Given that, the centering behavior depends on Probe.shift() correctly interpreting a (n_modes, 2) shift tensor as “shift each incoherent mode by its own shift”. It would be safer to either (a) document/validate that shifts.shape[0] == self.shape[1] in this branch, or (b) make center_probe() apply the per-mode shifting directly so it can’t be misinterpreted as a batch-of-probes shift.

Copilot uses AI. Check for mistakes.
Comment on lines +639 to +640
Determine and use the center of mass of the probe magnitude to keep each
shared mode centered independently.
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use_com_for_each_mode_independently interacts with the existing use_intensity_for_com option, but the docstring doesn’t clarify precedence (or whether they’re meant to be mutually exclusive). Consider updating the field docstring to state how it combines with use_intensity_for_com (e.g., “if True, this takes precedence and COM is computed per incoherent mode using |probe|”).

Suggested change
Determine and use the center of mass of the probe magnitude to keep each
shared mode centered independently.
Determine and use the center of mass (COM) of the probe magnitude to keep each
shared mode centered independently.
When set to ``True``, the COM is computed separately for each incoherent/shared
probe mode using ``|probe|`` and this per-mode COM centering takes precedence
over :attr:`use_intensity_for_com`. In this case,
:attr:`use_intensity_for_com` is ignored for the COM computation.
When set to ``False``, the COM is computed according to
:attr:`use_intensity_for_com` (e.g., using either the dominant shared probe
mode or the total probe intensity) rather than per-mode centering.

Copilot uses AI. Check for mistakes.
@a4894z a4894z changed the title fix for if the probe centering constraint shift is not a 2 element ve… fix for if the probe centering constraint shift is not a 2 element vector Mar 11, 2026
@mdw771
Copy link
Copy Markdown
Collaborator

mdw771 commented Mar 13, 2026

Probe.shift is meant to create copies of the whole 4D probes for all shift vectors in shift, not for individual mode. What you mentioned is actually the desired behavior of this method; it just shouldn't be used for shifting individual modes. For now I'll close this PR and create a new routine to realize the per-mode centering you mentioned.

@mdw771 mdw771 closed this Mar 13, 2026
@a4894z a4894z deleted the probe_centering_each_mode_its_own_com branch March 17, 2026 18:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants