-
Notifications
You must be signed in to change notification settings - Fork 7
fix for if the probe centering constraint shift is not a 2 element vector #62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -84,13 +84,12 @@ def shift(self, shifts: Tensor): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shifted_probe = shifted_probe.view(*self.shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| n_shifts = shifts.shape[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| n_images_each_probe = self.shape[0] * self.shape[1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| probe_straightened = self.tensor.complex().view(n_images_each_probe, *self.shape[-2:]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| probe_straightened = probe_straightened.repeat(n_shifts, 1, 1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shifts = shifts.repeat_interleave(n_images_each_probe, dim=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shifted_probe = ip.fourier_shift(probe_straightened, shifts) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shifted_probe = shifted_probe.reshape(n_shifts, *self.shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shifted_probe = torch.zeros_like(self.data) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shifted_probe[0,...] = ip.shift_images(self.data[0,...], shifts, method='fourier') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shifted_probe[1:,0,...] = ip.shift_images(self.data[1:,0,...], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shifts[0,:].repeat(self.data[1:,0,...].shape[0],1), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| method='fourier' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+87
to
+92
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shifted_probe = torch.zeros_like(self.data) | |
| shifted_probe[0,...] = ip.shift_images(self.data[0,...], shifts, method='fourier') | |
| shifted_probe[1:,0,...] = ip.shift_images(self.data[1:,0,...], | |
| shifts[0,:].repeat(self.data[1:,0,...].shape[0],1), | |
| method='fourier' | |
| ) | |
| # Expect one shift vector per incoherent mode: shifts.shape == (n_modes, 2) | |
| if shifts.shape[-1] != 2: | |
| raise ValueError( | |
| f"Expected shifts to have shape (n_modes, 2), got {tuple(shifts.shape)}." | |
| ) | |
| if shifts.shape[0] != self.n_modes: | |
| raise ValueError( | |
| f"Number of shifts ({shifts.shape[0]}) must match n_modes ({self.n_modes})." | |
| ) | |
| # Flatten (n_opr_modes, n_modes, H, W) -> (n_opr_modes * n_modes, H, W) | |
| n_opr_modes, n_modes, h, w = self.shape | |
| probe_straightened = self.tensor.complex().view(-1, h, w) | |
| # Expand per-mode shifts to all OPR modes: | |
| # (n_modes, 2) -> (n_opr_modes, n_modes, 2) -> (n_opr_modes * n_modes, 2) | |
| shifts_expanded = ( | |
| shifts.unsqueeze(0) # (1, n_modes, 2) | |
| .expand(n_opr_modes, -1, -1) | |
| .reshape(-1, 2) | |
| ) | |
| shifted_flat = ip.fourier_shift(probe_straightened, shifts_expanded) | |
| shifted_probe = shifted_flat.view(n_opr_modes, n_modes, h, w) |
Copilot
AI
Mar 11, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ip.shift_images(self.data[1:, 0, ...], ...) will be called even when there is only 1 OPR mode (i.e., self.data[1:, 0, ...] is an empty batch). Depending on how shift_images handles empty batches (especially the pad branch), this can raise unexpectedly. Consider guarding this assignment with if self.shape[0] > 1: (or similar) to avoid calling shift_images on an empty tensor.
| shifted_probe[0,...] = ip.shift_images(self.data[0,...], shifts, method='fourier') | |
| shifted_probe[1:,0,...] = ip.shift_images(self.data[1:,0,...], | |
| shifts[0,:].repeat(self.data[1:,0,...].shape[0],1), | |
| method='fourier' | |
| ) | |
| shifted_probe[0, ...] = ip.shift_images(self.data[0, ...], shifts, method='fourier') | |
| if self.n_opr_modes > 1: | |
| shifted_probe[1:, 0, ...] = ip.shift_images( | |
| self.data[1:, 0, ...], | |
| shifts[0, :].repeat(self.data[1:, 0, ...].shape[0], 1), | |
| method='fourier', | |
| ) |
Copilot
AI
Mar 11, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With use_com_for_each_mode_independently=True, probe_to_be_shifted becomes a (n_modes, H, W) stack and find_center_of_mass returns (n_modes, 2) shifts. Given that, the centering behavior depends on Probe.shift() correctly interpreting a (n_modes, 2) shift tensor as “shift each incoherent mode by its own shift”. It would be safer to either (a) document/validate that shifts.shape[0] == self.shape[1] in this branch, or (b) make center_probe() apply the per-mode shifting directly so it can’t be misinterpreted as a batch-of-probes shift.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use_com_for_each_mode_independentlyinteracts with the existinguse_intensity_for_comoption, but the docstring doesn’t clarify precedence (or whether they’re meant to be mutually exclusive). Consider updating the field docstring to state how it combines withuse_intensity_for_com(e.g., “if True, this takes precedence and COM is computed per incoherent mode using |probe|”).