fix for if the probe centering constraint shift is not a 2 element vector#62
fix for if the probe centering constraint shift is not a 2 element vector#62
Conversation
There was a problem hiding this comment.
Pull request overview
This PR fixes probe centering when computing a separate center-of-mass shift per incoherent probe mode (e.g., probe.shape == (n_opr_modes, n_modes, H, W)), by adding an option to compute COM per mode and updating probe shifting logic to avoid shape mismatches when shifts are not a single (2,) vector.
Changes:
- Add
use_com_for_each_mode_independentlytoProbeCenterConstraintOptions. - Update
Probe.center_probe()to optionally compute COM fromself.data[0, ...](stack of modes) instead of a single mode / total intensity. - Rewrite
Probe.shift()handling for 2Dshiftsinputs to useip.shift_imageson selected probe slices.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
src/ptychi/data_structures/probe.py |
Changes how batched/per-mode shifting is performed and adds a per-mode COM centering path. |
src/ptychi/api/options/base.py |
Introduces a new center constraint option flag controlling per-mode COM computation. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| shifted_probe = torch.zeros_like(self.data) | ||
| shifted_probe[0,...] = ip.shift_images(self.data[0,...], shifts, method='fourier') | ||
| shifted_probe[1:,0,...] = ip.shift_images(self.data[1:,0,...], | ||
| shifts[0,:].repeat(self.data[1:,0,...].shape[0],1), | ||
| method='fourier' | ||
| ) |
There was a problem hiding this comment.
In shift() when shifts is 2D, shifted_probe is initialized with zeros and only shifted_probe[0, ...] plus shifted_probe[1:, 0, ...] are populated. This silently zeroes self.data[1:, 1:, ...] and also doesn’t apply the per-mode shift to all OPR modes of each incoherent mode (which is what a (n_modes, 2) shift tensor implies). Consider building shifted_probe from a clone of self.data and shifting all (opr_mode, mode) images with the appropriate shift (e.g., flatten (n_opr_modes*n_modes, H, W) and expand shifts to match), so no parts of the probe are dropped.
| shifted_probe = torch.zeros_like(self.data) | |
| shifted_probe[0,...] = ip.shift_images(self.data[0,...], shifts, method='fourier') | |
| shifted_probe[1:,0,...] = ip.shift_images(self.data[1:,0,...], | |
| shifts[0,:].repeat(self.data[1:,0,...].shape[0],1), | |
| method='fourier' | |
| ) | |
| # Expect one shift vector per incoherent mode: shifts.shape == (n_modes, 2) | |
| if shifts.shape[-1] != 2: | |
| raise ValueError( | |
| f"Expected shifts to have shape (n_modes, 2), got {tuple(shifts.shape)}." | |
| ) | |
| if shifts.shape[0] != self.n_modes: | |
| raise ValueError( | |
| f"Number of shifts ({shifts.shape[0]}) must match n_modes ({self.n_modes})." | |
| ) | |
| # Flatten (n_opr_modes, n_modes, H, W) -> (n_opr_modes * n_modes, H, W) | |
| n_opr_modes, n_modes, h, w = self.shape | |
| probe_straightened = self.tensor.complex().view(-1, h, w) | |
| # Expand per-mode shifts to all OPR modes: | |
| # (n_modes, 2) -> (n_opr_modes, n_modes, 2) -> (n_opr_modes * n_modes, 2) | |
| shifts_expanded = ( | |
| shifts.unsqueeze(0) # (1, n_modes, 2) | |
| .expand(n_opr_modes, -1, -1) | |
| .reshape(-1, 2) | |
| ) | |
| shifted_flat = ip.fourier_shift(probe_straightened, shifts_expanded) | |
| shifted_probe = shifted_flat.view(n_opr_modes, n_modes, h, w) |
| shifted_probe[0,...] = ip.shift_images(self.data[0,...], shifts, method='fourier') | ||
| shifted_probe[1:,0,...] = ip.shift_images(self.data[1:,0,...], | ||
| shifts[0,:].repeat(self.data[1:,0,...].shape[0],1), | ||
| method='fourier' | ||
| ) |
There was a problem hiding this comment.
ip.shift_images(self.data[1:, 0, ...], ...) will be called even when there is only 1 OPR mode (i.e., self.data[1:, 0, ...] is an empty batch). Depending on how shift_images handles empty batches (especially the pad branch), this can raise unexpectedly. Consider guarding this assignment with if self.shape[0] > 1: (or similar) to avoid calling shift_images on an empty tensor.
| shifted_probe[0,...] = ip.shift_images(self.data[0,...], shifts, method='fourier') | |
| shifted_probe[1:,0,...] = ip.shift_images(self.data[1:,0,...], | |
| shifts[0,:].repeat(self.data[1:,0,...].shape[0],1), | |
| method='fourier' | |
| ) | |
| shifted_probe[0, ...] = ip.shift_images(self.data[0, ...], shifts, method='fourier') | |
| if self.n_opr_modes > 1: | |
| shifted_probe[1:, 0, ...] = ip.shift_images( | |
| self.data[1:, 0, ...], | |
| shifts[0, :].repeat(self.data[1:, 0, ...].shape[0], 1), | |
| method='fourier', | |
| ) |
| if self.options.center_constraint.use_com_for_each_mode_independently: | ||
| probe_to_be_shifted = self.data[0,...] | ||
| elif self.options.center_constraint.use_intensity_for_com: | ||
| probe_to_be_shifted = torch.sum(torch.abs(self.data[0, ...]) ** 2, dim=0) | ||
| else: | ||
| probe_to_be_shifted = self.get_mode_and_opr_mode(0, 0) |
There was a problem hiding this comment.
With use_com_for_each_mode_independently=True, probe_to_be_shifted becomes a (n_modes, H, W) stack and find_center_of_mass returns (n_modes, 2) shifts. Given that, the centering behavior depends on Probe.shift() correctly interpreting a (n_modes, 2) shift tensor as “shift each incoherent mode by its own shift”. It would be safer to either (a) document/validate that shifts.shape[0] == self.shape[1] in this branch, or (b) make center_probe() apply the per-mode shifting directly so it can’t be misinterpreted as a batch-of-probes shift.
| Determine and use the center of mass of the probe magnitude to keep each | ||
| shared mode centered independently. |
There was a problem hiding this comment.
use_com_for_each_mode_independently interacts with the existing use_intensity_for_com option, but the docstring doesn’t clarify precedence (or whether they’re meant to be mutually exclusive). Consider updating the field docstring to state how it combines with use_intensity_for_com (e.g., “if True, this takes precedence and COM is computed per incoherent mode using |probe|”).
| Determine and use the center of mass of the probe magnitude to keep each | |
| shared mode centered independently. | |
| Determine and use the center of mass (COM) of the probe magnitude to keep each | |
| shared mode centered independently. | |
| When set to ``True``, the COM is computed separately for each incoherent/shared | |
| probe mode using ``|probe|`` and this per-mode COM centering takes precedence | |
| over :attr:`use_intensity_for_com`. In this case, | |
| :attr:`use_intensity_for_com` is ignored for the COM computation. | |
| When set to ``False``, the COM is computed according to | |
| :attr:`use_intensity_for_com` (e.g., using either the dominant shared probe | |
| mode or the total probe intensity) rather than per-mode centering. |
|
|
In the probe class definition in probe.py, the centering part has a bug if we want to shift each probe independently according to its own center of mass.
for example, let the probe be of shape [3,5,512,512] (3 OPRs, 5 shared modes); if I use:
probe_to_be_shifted = self.data[0,...]
instead of
if self.options.center_constraint.use_intensity_for_com:
probe_to_be_shifted = torch.sum(torch.abs(self.data[0, ...]) ** 2, dim=0)
else:
probe_to_be_shifted = self.get_mode_and_opr_mode(0, 0)
then after executing this:
the returned probe shape will be
[5, 3, 5, 512, 512]
which will then cause an error when you call the setter.
I'm introducing a proposed fix here.