diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py index 62f2bfd..431a76d 100644 --- a/src/ptychi/api/options/base.py +++ b/src/ptychi/api/options/base.py @@ -913,6 +913,38 @@ def check(self, options: "task_options.PtychographyTaskOptions"): raise ValueError("LBFGS optimizer is currently only supported for Autodiff reconstructors.") +@dataclasses.dataclass +class RealSpaceScalingOptions(ParameterOptions): + initial_guess: float = 1.0 + """Initial global real-space scaling factor.""" + + optimizable: bool = False + """Whether the real-space scaling factor is optimizable.""" + + differentiation_method: enums.ImageGradientMethods = enums.ImageGradientMethods.FOURIER_DIFFERENTIATION + """Method used to compute object gradients for the scaling update.""" + + def check(self, options: "task_options.PtychographyTaskOptions"): + super().check(options) + if self.initial_guess <= 0: + raise ValueError("`real_space_scaling_options.initial_guess` must be positive.") + if self.optimizer == enums.Optimizers.LBFGS and "Autodiff" not in options.__class__.__name__: + raise ValueError("LBFGS optimizer is currently only supported for Autodiff reconstructors.") + affine_constraint = options.probe_position_options.affine_transform_constraint + if ( + self.optimizable + and options.probe_position_options.optimizable + and affine_constraint.enabled + and enums.AffineDegreesOfFreedom.SCALE in affine_constraint.degrees_of_freedom + ): + logger.warning( + "Do not enable `real_space_scaling_options.optimizable` together with " + "`probe_position_options.affine_transform_constraint` when " + "`AffineDegreesOfFreedom.SCALE` is included. Both refine the same " + "far-field scale ambiguity." 
+ ) + + @dataclasses.dataclass class OPRModeWeightsSmoothingOptions(FeatureOptions): """Settings for smoothing OPR mode weights.""" diff --git a/src/ptychi/api/options/task.py b/src/ptychi/api/options/task.py index 4c3855b..8db8f15 100644 --- a/src/ptychi/api/options/task.py +++ b/src/ptychi/api/options/task.py @@ -21,6 +21,8 @@ class PtychographyTaskOptions(base.TaskOptions): probe_position_options: base.ProbePositionOptions = field(default_factory=base.ProbePositionOptions) + real_space_scaling_options: base.RealSpaceScalingOptions = field(default_factory=base.RealSpaceScalingOptions) + opr_mode_weight_options: base.OPRModeWeightsOptions = field(default_factory=base.OPRModeWeightsOptions) def check(self, *args, **kwargs): @@ -31,6 +33,7 @@ def check(self, *args, **kwargs): self.object_options, self.probe_options, self.probe_position_options, + self.real_space_scaling_options, self.opr_mode_weight_options, ): options.check(self) diff --git a/src/ptychi/api/task.py b/src/ptychi/api/task.py index 548ead4..49d4bf9 100644 --- a/src/ptychi/api/task.py +++ b/src/ptychi/api/task.py @@ -18,6 +18,7 @@ import ptychi.data_structures.probe as probe import ptychi.data_structures.probe_positions as probepos import ptychi.data_structures.parameter_group as paramgrp +import ptychi.data_structures.real_space_scaling as rsscaling import ptychi.maps as maps from ptychi.io_handles import PtychographyDataset from ptychi.reconstructors.base import Reconstructor @@ -67,6 +68,7 @@ def __init__(self, options: api.options.task.PtychographyTaskOptions, *args, **k self.object_options = options.object_options self.probe_options = options.probe_options self.position_options = options.probe_position_options + self.real_space_scaling_options = options.real_space_scaling_options self.opr_mode_weight_options = options.opr_mode_weight_options self.reconstructor_options = options.reconstructor_options @@ -74,6 +76,7 @@ def __init__(self, options: api.options.task.PtychographyTaskOptions, *args, **k 
self.object = None self.probe = None self.probe_positions = None + self.real_space_scaling = None self.opr_mode_weights = None self.reconstructor: Reconstructor | None = None @@ -92,6 +95,7 @@ def build(self): self.build_object() self.build_probe() self.build_probe_positions() + self.build_real_space_scaling() self.build_opr_mode_weights() self.build_reconstructor() @@ -218,6 +222,20 @@ def build_probe_positions(self): data = torch.stack([pos_y, pos_x], dim=1) self.probe_positions = probepos.ProbePositions(data=data, options=self.position_options) + def build_real_space_scaling(self): + """Build the global real-space scaling parameter. + + The constructed parameter stores a real tensor of shape ``(1,)``. + """ + data = torch.tensor( + [self.real_space_scaling_options.initial_guess], + device=torch.get_default_device(), + dtype=torch.get_default_dtype(), + ) + self.real_space_scaling = rsscaling.RealSpaceScaling( + data=data, options=self.real_space_scaling_options + ) + def build_opr_mode_weights(self): if self.opr_mode_weight_options.initial_weights is None: initial_weights = torch.ones([self.data_options.data.shape[0], 1]) @@ -237,6 +255,7 @@ def build_reconstructor(self): object=self.object, probe=self.probe, probe_positions=self.probe_positions, + real_space_scaling=self.real_space_scaling, opr_mode_weights=self.opr_mode_weights, ) @@ -280,13 +299,14 @@ def run(self, n_epochs: int = None, reset_timer_globals: bool = True): self.reconstructor.run(n_epochs=n_epochs) def get_data( - self, name: Literal["object", "probe", "probe_positions", "opr_mode_weights"] + self, + name: Literal["object", "probe", "probe_positions", "real_space_scaling", "opr_mode_weights"], ) -> Tensor: """Get a detached copy of the data of the given name. Parameters ---------- - name : Literal["object", "probe", "probe_positions", "opr_mode_weights"] + name : Literal["object", "probe", "probe_positions", "real_space_scaling", "opr_mode_weights"] The name of the data to get. 
Returns @@ -304,7 +324,7 @@ def get_data( def get_data_to_cpu( self, - name: Literal["object", "probe", "probe_positions", "opr_mode_weights"], + name: Literal["object", "probe", "probe_positions", "real_space_scaling", "opr_mode_weights"], as_numpy: bool = False, ) -> Union[Tensor, ndarray]: data = self.get_data(name).cpu() @@ -327,7 +347,13 @@ def get_probe_positions_x(self, as_numpy: bool = False) -> Union[Tensor, ndarray def copy_data_from_task( self, task: "PtychographyTask", - params_to_copy: tuple[str, ...] = ("object", "probe", "probe_positions", "opr_mode_weights") + params_to_copy: tuple[str, ...] = ( + "object", + "probe", + "probe_positions", + "real_space_scaling", + "opr_mode_weights", + ) ) -> None: """Copy data of reconstruction parameters from another task object. @@ -352,6 +378,10 @@ def copy_data_from_task( self.reconstructor.parameter_group.probe_positions.set_data( task.get_data("probe_positions") ) + elif param == "real_space_scaling": + self.reconstructor.parameter_group.real_space_scaling.set_data( + task.get_data("real_space_scaling") + ) elif param == "opr_mode_weights": self.reconstructor.parameter_group.opr_mode_weights.set_data( task.get_data("opr_mode_weights") diff --git a/src/ptychi/data_structures/parameter_group.py b/src/ptychi/data_structures/parameter_group.py index 24c865c..88a7794 100644 --- a/src/ptychi/data_structures/parameter_group.py +++ b/src/ptychi/data_structures/parameter_group.py @@ -11,6 +11,7 @@ import ptychi.data_structures.opr_mode_weights as oprweights import ptychi.data_structures.probe as probe import ptychi.data_structures.probe_positions as probepos +import ptychi.data_structures.real_space_scaling as rsscaling from ptychi.parallel import MultiprocessMixin @@ -95,6 +96,8 @@ class PtychographyParameterGroup(ParameterGroup): probe_positions: "probepos.ProbePositions" + real_space_scaling: "rsscaling.RealSpaceScaling" + opr_mode_weights: "oprweights.OPRModeWeights" def __post_init__(self): diff --git 
a/src/ptychi/data_structures/real_space_scaling.py b/src/ptychi/data_structures/real_space_scaling.py new file mode 100644 index 0000000..3df94f5 --- /dev/null +++ b/src/ptychi/data_structures/real_space_scaling.py @@ -0,0 +1,114 @@ +# Copyright © 2025 UChicago Argonne, LLC All right reserved +# Full license accessible at https://github.com//AdvancedPhotonSource/pty-chi/blob/main/LICENSE + +from typing import TYPE_CHECKING + +import torch +from torch import Tensor + +import ptychi.api.enums as enums +import ptychi.data_structures.base as dsbase +import ptychi.image_proc as ip + +if TYPE_CHECKING: + import ptychi.api.options.base as base_options + + +class RealSpaceScaling(dsbase.ReconstructParameter): + options: "base_options.RealSpaceScalingOptions" + + def __init__( + self, + *args, + name: str = "real_space_scaling", + options: "base_options.RealSpaceScalingOptions" = None, + **kwargs, + ): + """Global real-space scaling factor applied to the exit wavefield. + + Parameters + ---------- + data : Tensor, optional + A real tensor of shape ``(1,)`` containing the scaling factor. + """ + super().__init__(*args, name=name, options=options, is_complex=False, **kwargs) + if self.shape != (1,): + raise ValueError("RealSpaceScaling must contain exactly one scalar element.") + + def post_update_hook(self, *args, **kwargs): + """Clamp the parameter tensor in place. + + The updated tensor has shape ``(1,)``. + """ + with torch.no_grad(): + self.tensor.clamp_(min=1e-6) + + def get_update( + self, + chi: Tensor, + obj_patches: Tensor, + probe: Tensor, + eps: float = 1e-6, + ) -> Tensor: + """Estimate the update direction of the real-space scaling factor. + + This follows the same first-order approximation used by PtychoShelves + for detector-scale refinement: object gradients are combined with a + radial weighting and projected onto the exit-wave update. 
+ + Parameters + ---------- + chi : Tensor + A complex tensor of shape ``(batch_size, n_probe_modes, h, w)`` + giving the exit-wave update at the current slice. + obj_patches : Tensor + A complex tensor of shape ``(batch_size, n_slices, h, w)`` + containing object patches for the current batch. + probe : Tensor + A complex tensor of shape ``(batch_size, n_probe_modes, h, w)`` + containing the incident wavefields at the current slice. + eps : float + Small stabilizer added to the denominator. + + Returns + ------- + Tensor + A real tensor of shape ``(1,)`` containing the additive update + direction for the global scaling factor. + """ + obj_patches = obj_patches[:, 0] + probe = probe[:, 0] + chi_m0 = chi[:, 0] + + if self.options.differentiation_method == enums.ImageGradientMethods.GAUSSIAN: + dody, dodx = ip.gaussian_gradient(obj_patches, sigma=0.33) + elif self.options.differentiation_method == enums.ImageGradientMethods.FOURIER_DIFFERENTIATION: + dody, dodx = ip.fourier_gradient(obj_patches) + elif self.options.differentiation_method == enums.ImageGradientMethods.FOURIER_SHIFT: + dody, dodx = ip.fourier_shift_gradient(obj_patches) + elif self.options.differentiation_method == enums.ImageGradientMethods.NEAREST: + dody, dodx = ip.nearest_neighbor_gradient(obj_patches, "backward") + else: + raise ValueError( + f"Unsupported differentiation method: {self.options.differentiation_method}" + ) + + h, w = obj_patches.shape[-2:] + xgrid = -torch.linspace(-1, 1, w, device=obj_patches.device, dtype=obj_patches.real.dtype) + ygrid = -torch.linspace(-1, 1, h, device=obj_patches.device, dtype=obj_patches.real.dtype) + xgrid = xgrid * ip.tukey_window(w, 0.1, device=xgrid.device, dtype=xgrid.dtype) + ygrid = ygrid * ip.tukey_window(h, 0.1, device=ygrid.device, dtype=ygrid.dtype) + xgrid = xgrid.view(1, 1, w) + ygrid = ygrid.view(1, h, 1) + + dm_o = dodx * xgrid + dody * ygrid + dm_op = dm_o * probe + nom = torch.real(dm_op.conj() * chi_m0).sum(dim=(-1, -2)) + denom = 
(dm_op.abs() ** 2).sum(dim=(-1, -2)) + denom_bias = torch.maximum( + denom.max(), + torch.tensor(eps, device=denom.device, dtype=denom.dtype), + ) + delta_scale = nom / (denom + denom_bias) + delta_scale = 0.5 * delta_scale.mean() / ((h + w) * 0.5) + return delta_scale.reshape(1) diff --git a/src/ptychi/forward_models.py b/src/ptychi/forward_models.py index 43b26a1..b12bcf9 100644 --- a/src/ptychi/forward_models.py +++ b/src/ptychi/forward_models.py @@ -208,6 +208,7 @@ def __init__( self.object = parameter_group.object self.probe = parameter_group.probe self.probe_positions = parameter_group.probe_positions + self.real_space_scaling = parameter_group.real_space_scaling self.opr_mode_weights = parameter_group.opr_mode_weights self.wavelength_m = wavelength_m @@ -450,10 +451,55 @@ def forward_far_field(self, psi: Tensor) -> Tensor: Tensor A (batch_size, n_probe_modes, h, w) tensor of far field waves. """ + psi = self.apply_real_space_scaling(psi) psi_far = self.free_space_propagator.propagate_forward(psi) self.record_intermediate_variable("psi_far", psi_far) return psi_far + @timer() + def apply_real_space_scaling(self, psi: Tensor) -> Tensor: + """Apply the global real-space scaling before detector propagation. + + Parameters + ---------- + psi : Tensor + A complex tensor of shape ``(batch_size, n_probe_modes, h, w)``. + + Returns + ------- + Tensor + A complex tensor of shape ``(batch_size, n_probe_modes, h, w)``. + """ + scale = self.real_space_scaling.data[0] + if (not scale.requires_grad) and torch.allclose(scale, torch.ones_like(scale)): + return psi + orig_shape = psi.shape + psi = psi.reshape(-1, *orig_shape[-2:]) + psi = ip.rescale_images(psi, scale) + return psi.reshape(orig_shape) + + @timer() + def apply_real_space_scaling_adjoint(self, psi: Tensor) -> Tensor: + """Apply the adjoint of the global real-space scaling operator. + + Parameters + ---------- + psi : Tensor + A complex tensor of shape ``(batch_size, n_probe_modes, h, w)``. 
+ + Returns + ------- + Tensor + A complex tensor of shape ``(batch_size, n_probe_modes, h, w)``. + """ + scale = self.real_space_scaling.data[0] + if (not scale.requires_grad) and torch.allclose(scale, torch.ones_like(scale)): + return psi + orig_shape = psi.shape + psi = psi.reshape(-1, *orig_shape[-2:]) + psi = ip.rescale_images(psi, scale, adjoint=True) + return psi.reshape(orig_shape) + @timer() def propagate_to_next_slice(self, psi: Tensor, slice_index: int): """ @@ -640,8 +686,7 @@ def forward_low_memory(self, indices: Tensor, return_object_patches: bool = Fals probe[..., i_mode : i_mode + 1, :, :], obj_patches ) - psi_far = self.free_space_propagator.propagate_forward(exit_psi) - self.record_intermediate_variable("psi_far", psi_far) + psi_far = self.forward_far_field(exit_psi) y = y + psi_far[..., 0, :, :].abs() ** 2 diff --git a/src/ptychi/image_proc.py b/src/ptychi/image_proc.py index 1ac0465..d693ccd 100644 --- a/src/ptychi/image_proc.py +++ b/src/ptychi/image_proc.py @@ -7,7 +7,7 @@ import torch from torch import Tensor -import torch.signal +import torch.nn.functional as F import ptychi.maths as pmath from ptychi.api.types import ComplexTensor, RealTensor @@ -26,6 +26,46 @@ def __call__(self, image: Tensor, positions: Tensor, shape: Tuple[int, int]) -> logger = logging.getLogger(__name__) +def tukey_window( + length: int, + alpha: float = 0.5, + *, + device: torch.device | None = None, + dtype: torch.dtype | None = None, +) -> Tensor: + """Return a 1D Tukey window. + + Returns + ------- + Tensor + A real tensor of shape ``(length,)``. 
+ """ + if length <= 0: + raise ValueError("Window length must be positive.") + if alpha <= 0: + return torch.ones(length, device=device, dtype=dtype) + if alpha >= 1: + return torch.hann_window(length, periodic=False, device=device, dtype=dtype) + if length == 1: + return torch.ones(1, device=device, dtype=dtype) + + x = torch.linspace(0, 1, length, device=device, dtype=dtype) + window = torch.ones(length, device=device, dtype=dtype) + edge = alpha / 2 + + left = x < edge + if torch.any(left): + window[left] = 0.5 * (1 + torch.cos(math.pi * ((2 * x[left] / alpha) - 1))) + + right = x >= (1 - edge) + if torch.any(right): + window[right] = 0.5 * ( + 1 + torch.cos(math.pi * ((2 * x[right] / alpha) - (2 / alpha) + 1)) + ) + + return window + + @timer() def batch_slice(image: Tensor, sy: Tensor, sx: Tensor, patch_size: Tuple[int, int]) -> Tensor: """ @@ -565,6 +605,161 @@ def shift_images( return images +def _prepare_rescale_factors( + scales: float | Tensor, n_images: int, device: torch.device, dtype: torch.dtype +) -> Tensor: + """Broadcast scaling factors to match a batch of images. + + Parameters + ---------- + scales : float | Tensor + Scalar or real tensor of shape ``(n_images,)``. + n_images : int + Batch size. + + Returns + ------- + Tensor + A real tensor of shape ``(n_images,)``. + """ + scales = torch.as_tensor(scales, device=device, dtype=dtype) + if scales.ndim == 0: + scales = scales.expand(n_images) + elif scales.numel() == 1: + scales = scales.reshape(1).expand(n_images) + elif scales.shape != (n_images,): + raise ValueError( + f"`scales` should be a scalar or a ({n_images},)-shaped tensor, got {tuple(scales.shape)}." + ) + return scales + + +def _rescale_images_channels(images: Tensor, scales: Tensor, mode: str) -> Tensor: + """Rescale channel-packed images. + + Parameters + ---------- + images : Tensor + A real tensor of shape ``(n_images, n_channels, h, w)``. + scales : Tensor + A real tensor of shape ``(n_images,)``. 
+ + Returns + ------- + Tensor + A real tensor of shape ``(n_images, n_channels, h, w)``. + """ + n_images, n_channels, height, width = images.shape + theta = torch.zeros((n_images, 2, 3), device=images.device, dtype=images.dtype) + theta[:, 0, 0] = 1.0 / scales + theta[:, 1, 1] = 1.0 / scales + grid = F.affine_grid(theta, images.shape, align_corners=False) + return F.grid_sample( + images, + grid, + mode=mode, + padding_mode="border", + align_corners=False, + ) + + +def _pack_rescale_channels(images: Tensor) -> tuple[Tensor, bool]: + """Pack real/imaginary parts into a channel dimension. + + Parameters + ---------- + images : Tensor + A real or complex tensor of shape ``(n_images, h, w)``. + + Returns + ------- + tuple[Tensor, bool] + The first entry is a real tensor of shape + ``(n_images, n_channels, h, w)``, where ``n_channels`` is 1 for + real inputs and 2 for complex inputs. The second entry indicates + whether the original tensor was complex. + """ + is_complex = images.dtype.is_complex + if is_complex: + images = torch.stack([images.real, images.imag], dim=1) + else: + images = images.unsqueeze(1) + return images, is_complex + + +def _unpack_rescale_channels(images: Tensor, is_complex: bool) -> Tensor: + """Convert channel-packed images back to their original dtype. + + Parameters + ---------- + images : Tensor + A real tensor of shape ``(n_images, n_channels, h, w)``. + is_complex : bool + Whether to reconstruct a complex output. + + Returns + ------- + Tensor + A tensor of shape ``(n_images, h, w)``. + """ + if is_complex: + return images[:, 0] + 1j * images[:, 1] + return images[:, 0] + + +def rescale_images( + images: Tensor, + scales: float | Tensor, + *, + adjoint: bool = False, + mode: Literal["bilinear"] = "bilinear", +) -> Tensor: + """Rescale a batch of images while keeping the same output size. + + Parameters + ---------- + images : Tensor + A real or complex tensor of shape ``(n_images, h, w)``. 
+ scales : float | Tensor + Scalar zoom factor or a real tensor of shape ``(n_images,)``. + adjoint : bool + If True, apply the exact adjoint of the forward rescaling operator. + mode : Literal["bilinear"] + Interpolation mode for the forward operator. + + Returns + ------- + Tensor + A tensor with the same shape and dtype as ``images``, i.e. + ``(n_images, h, w)``. + """ + if images.ndim != 3: + raise ValueError(f"`images` should be a (N, H, W) tensor, got shape {tuple(images.shape)}.") + + scales = _prepare_rescale_factors( + scales, images.shape[0], device=images.device, dtype=images.real.dtype + ) + if (not scales.requires_grad) and torch.allclose(scales, torch.ones_like(scales)): + return images + + images_ch, is_complex = _pack_rescale_channels(images) + if not adjoint: + scaled = _rescale_images_channels(images_ch, scales, mode=mode) + return _unpack_rescale_channels(scaled, is_complex) + + with torch.enable_grad(): + basis = torch.zeros_like(images_ch, requires_grad=True) + forward = _rescale_images_channels(basis, scales, mode=mode) + adjoint_out = torch.autograd.grad( + forward, + basis, + grad_outputs=images_ch, + create_graph=False, + retain_graph=False, + )[0] + return _unpack_rescale_channels(adjoint_out, is_complex) + + @timer() def fourier_shift(images: Tensor, shifts: Tensor, strictly_preserve_zeros: bool = False) -> Tensor: """ diff --git a/src/ptychi/reconstructors/base.py b/src/ptychi/reconstructors/base.py index d5bc456..c3a0f40 100644 --- a/src/ptychi/reconstructors/base.py +++ b/src/ptychi/reconstructors/base.py @@ -731,6 +731,43 @@ def replace_propagated_exit_wave_magnitude( if constrained_pixel_mask is not None: return torch.where(constrained_pixel_mask[:, None], psi_prime, psi) return psi_prime + + @timer() + def propagate_exit_wave_to_detector(self, psi: Tensor) -> Tensor: + """Apply real-space scaling and propagate an exit wave to the detector. 
+ + Parameters + ---------- + psi : Tensor + A complex tensor of shape ``(batch_size, n_probe_modes, h, w)``. + + Returns + ------- + Tensor + A complex tensor of shape ``(batch_size, n_probe_modes, h, w)`` + at the detector plane. + """ + psi = self.forward_model.apply_real_space_scaling(psi) + return self.forward_model.free_space_propagator.propagate_forward(psi) + + @timer() + def propagate_detector_wave_to_exit_adjoint(self, psi_far: Tensor) -> Tensor: + """Apply the detector-to-exit-wave adjoint propagation. + + Parameters + ---------- + psi_far : Tensor + A complex tensor of shape ``(batch_size, n_probe_modes, h, w)`` + at the detector plane. + + Returns + ------- + Tensor + A complex tensor of shape ``(batch_size, n_probe_modes, h, w)`` + at the exit plane. + """ + psi = self.forward_model.free_space_propagator.propagate_backward(psi_far) + return self.forward_model.apply_real_space_scaling_adjoint(psi) @timer() def adjoint_shift_probe_update_direction(self, indices, delta_p, first_mode_only=False): @@ -772,3 +809,34 @@ def adjoint_shift_probe_update_direction(self, indices, delta_p, first_mode_only else: delta_p = delta_p_shifted.reshape(orig_shape) return delta_p + + @timer() + def update_real_space_scaling( + self, + chi: Tensor, + obj_patches: Tensor, + unique_probes: Tensor, + apply_updates: bool = True, + ) -> None: + """Update the global real-space scaling parameter. + + Parameters + ---------- + chi : Tensor + A complex tensor of shape ``(batch_size, n_probe_modes, h, w)`` + giving the current exit-wave update. + obj_patches : Tensor + A complex tensor of shape ``(batch_size, n_slices, h, w)``. + unique_probes : Tensor + A complex tensor of shape ``(batch_size, n_probe_modes, h, w)`` + giving the incident wavefields used for the update. + apply_updates : bool + If True, apply the optimizer step immediately. Otherwise, only + populate the gradient tensor of shape ``(1,)``. 
+ """ + scaling = self.parameter_group.real_space_scaling + delta_scale = scaling.get_update(chi, obj_patches, unique_probes) + scaling.set_grad(-delta_scale) + if apply_updates: + scaling.step_optimizer() + scaling.post_update_hook() diff --git a/src/ptychi/reconstructors/bh.py b/src/ptychi/reconstructors/bh.py index 2ffc948..8eaa369 100644 --- a/src/ptychi/reconstructors/bh.py +++ b/src/ptychi/reconstructors/bh.py @@ -117,6 +117,10 @@ def apply_updates(self, delta_o, delta_p, delta_pos, *args, **kwargs): probe_positions.set_grad(-delta_pos) probe_positions.optimizer.step() + if self.parameter_group.real_space_scaling.optimization_enabled(self.current_epoch): + self.parameter_group.real_space_scaling.step_optimizer() + self.parameter_group.real_space_scaling.post_update_hook() + def compute_updates( self, indices: torch.Tensor, y_true: torch.Tensor ) -> tuple[torch.Tensor, ...]: @@ -145,6 +149,15 @@ def compute_updates( # Gradient for the Gaussian model gradF = self.gradientF(psi_far, d) + if self.parameter_group.real_space_scaling.optimization_enabled(self.current_epoch): + unique_probe = p.unsqueeze(0).expand(op.shape[0], -1, -1, -1) + self.update_real_space_scaling( + -0.5 * gradF, + op, + unique_probe, + apply_updates=False, + ) + # shortcuts o_opt = object_.optimization_enabled(self.current_epoch) p_opt = probe.optimization_enabled(self.current_epoch) @@ -299,7 +312,10 @@ def gradientF(self, psi_far, d): # Compensate FFT normalization only for the far-field Fourier propagator. 
if isinstance(self.forward_model.free_space_propagator, fm.FourierPropagator): td *= psi_far.shape[-1] * psi_far.shape[-2] - res = 2 * self.forward_model.free_space_propagator.propagate_backward(td) + if hasattr(self, "propagate_detector_wave_to_exit_adjoint"): + res = 2 * self.propagate_detector_wave_to_exit_adjoint(td) + else: + res = 2 * self.forward_model.free_space_propagator.propagate_backward(td) return res def hessianF(self, psi_far, psi_far1, psi_far2, data): @@ -362,7 +378,7 @@ def calc_alpha_object(self, p, do1, do2, d, psi_far, dop2): top = -redot(do1, do2) dm2 = p * dop2 - Ldm2 = self.forward_model.free_space_propagator.propagate_forward(dm2) + Ldm2 = self.propagate_exit_wave_to_detector(dm2) bottom = self.hessianF(psi_far, Ldm2, Ldm2, d) return top / bottom @@ -374,7 +390,7 @@ def calc_alpha_object_probe(self, p, do1, dp1, do2, dp2, d, gradF, psi_far, op, dm2 = dp2 * op + p * dop2 d2m2 = 2 * dp2 * dop2 - Ldm2 = self.forward_model.free_space_propagator.propagate_forward(dm2) + Ldm2 = self.propagate_exit_wave_to_detector(dm2) bottom = redot(gradF, d2m2) + self.hessianF(psi_far, Ldm2, Ldm2, d) return top / bottom @@ -409,7 +425,7 @@ def calc_alpha_object_probe_positions( dm2 = dp2 * op + p * (dop2 + dt) d2m2 = p * (2 * dt2 + d2t) + 2 * dp2 * dop2 + 2 * dp2 * dt - Ldm2 = self.forward_model.free_space_propagator.propagate_forward(dm2) + Ldm2 = self.propagate_exit_wave_to_detector(dm2) bottom = redot(gradF, d2m2) + self.hessianF(psi_far, Ldm2, Ldm2, d) return top / bottom @@ -419,8 +435,8 @@ def calc_beta_object(self, p, d, psi_far, dop1, dop2): dm1 = p * dop1 dm2 = p * dop2 - Ldm1 = self.forward_model.free_space_propagator.propagate_forward(dm1) - Ldm2 = self.forward_model.free_space_propagator.propagate_forward(dm2) + Ldm1 = self.propagate_exit_wave_to_detector(dm1) + Ldm2 = self.propagate_exit_wave_to_detector(dm2) top = self.hessianF(psi_far, Ldm1, Ldm2, d) bottom = self.hessianF(psi_far, Ldm2, Ldm2, d) @@ -438,8 +454,8 @@ def 
calc_beta_object_probe(self, p, dp1, dp2, d, gradF, psi_far, op, dop1, dop2) d2m1 = dp1 * dop2 + dp2 * dop1 d2m2 = 2 * dp2 * dop2 - Ldm1 = self.forward_model.free_space_propagator.propagate_forward(dm1) - Ldm2 = self.forward_model.free_space_propagator.propagate_forward(dm2) + Ldm1 = self.propagate_exit_wave_to_detector(dm1) + Ldm2 = self.propagate_exit_wave_to_detector(dm2) top = redot(gradF, d2m1) + self.hessianF(psi_far, Ldm1, Ldm2, d) bottom = redot(gradF, d2m2) + self.hessianF(psi_far, Ldm2, Ldm2, d) @@ -495,8 +511,8 @@ def calc_beta_object_probe_positions( d2m1 = p * (dt12 + dt21 + d2t1) + dp1 * (dop2 + dt2) + dp2 * (dop1 + dt1) d2m2 = p * (2 * dt22 + d2t2) + 2 * dp2 * dop2 + 2 * dp2 * dt2 - Ldm1 = self.forward_model.free_space_propagator.propagate_forward(dm1) - Ldm2 = self.forward_model.free_space_propagator.propagate_forward(dm2) + Ldm1 = self.propagate_exit_wave_to_detector(dm1) + Ldm2 = self.propagate_exit_wave_to_detector(dm2) top = redot(gradF, d2m1) + self.hessianF(psi_far, Ldm1, Ldm2, d) bottom = redot(gradF, d2m2) + self.hessianF(psi_far, Ldm2, Ldm2, d) diff --git a/src/ptychi/reconstructors/dm.py b/src/ptychi/reconstructors/dm.py index b7a211b..32871b9 100644 --- a/src/ptychi/reconstructors/dm.py +++ b/src/ptychi/reconstructors/dm.py @@ -144,6 +144,7 @@ def compute_updates(self, y_true: Tensor) -> Tensor: probe_numerator = torch.zeros_like(probe.get_opr_mode(0)) probe_denominator = torch.zeros_like(probe.get_opr_mode(0).abs()) delta_pos = torch.zeros_like(probe_positions.data) + delta_rsscale = torch.zeros_like(self.parameter_group.real_space_scaling.data) dm_error_squared = 0 for i in range(n_chunks): obj_patches, dm_error_squared, new_psi = self.apply_dm_update_to_exit_wave_chunk( @@ -163,6 +164,23 @@ def compute_updates(self, y_true: Tensor) -> Tensor: obj_patches=obj_patches, chi=self.psi[start_pts[i] : end_pts[i]] - new_psi, ) + if self.parameter_group.real_space_scaling.optimization_enabled(self.current_epoch): + indices_chunk = torch.arange( 
+ start_pts[i], end_pts[i], device=obj_patches.device, dtype=torch.long + ) + unique_probes = self.forward_model.get_unique_probes( + indices_chunk, + always_return_probe_batch=True, + ) + if self.forward_model.apply_subpixel_shifts_on_probe: + unique_probes = self.forward_model.shift_unique_probes( + indices_chunk, unique_probes, first_mode_only=True + ) + delta_rsscale += self.parameter_group.real_space_scaling.get_update( + self.psi[start_pts[i] : end_pts[i]] - new_psi, + obj_patches, + unique_probes, + ) # Update the probe if probe.optimization_enabled(self.current_epoch): @@ -178,6 +196,11 @@ def compute_updates(self, y_true: Tensor) -> Tensor: probe_positions.set_grad(-delta_pos) probe_positions.step_optimizer() + if self.parameter_group.real_space_scaling.optimization_enabled(self.current_epoch): + self.parameter_group.real_space_scaling.set_grad(-delta_rsscale / n_chunks) + self.parameter_group.real_space_scaling.step_optimizer() + self.parameter_group.real_space_scaling.post_update_hook() + return dm_error_squared @timer() @@ -239,9 +262,7 @@ def apply_dm_update_to_exit_wave_chunk( start_pt, end_pt, return_obj_patches=True ) # Propagate to detector plane - revised_psi = self.forward_model.free_space_propagator.propagate_forward( - 2 * new_psi - self.psi[start_pt:end_pt] - ) + revised_psi = self.propagate_exit_wave_to_detector(2 * new_psi - self.psi[start_pt:end_pt]) # Replace intensities revised_psi = self.replace_propagated_exit_wave_magnitude( revised_psi, @@ -249,7 +270,7 @@ def apply_dm_update_to_exit_wave_chunk( constrained_pixel_mask=self.get_constrained_pixel_mask(y_true[start_pt:end_pt]), ) # Propagate back to sample plane - revised_psi = self.forward_model.free_space_propagator.propagate_backward(revised_psi) + revised_psi = self.propagate_detector_wave_to_exit_adjoint(revised_psi) # Update the exit wave psi_update = (revised_psi - new_psi) * self.options.exit_wave_update_relaxation self.psi[start_pt:end_pt] += psi_update diff --git 
diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py
index 71f2335..7914700 100644
--- a/src/ptychi/reconstructors/lsqml.py
+++ b/src/ptychi/reconstructors/lsqml.py
@@ -383,7 +383,7 @@ def run_reciprocal_space_step(self, y_pred, y_true, indices):
 
         self.alpha_psi_far = self.get_psi_far_step_size(y_pred, y_true, indices)
         psi_far = psi_far_0 - self.alpha_psi_far.view(-1, 1, 1, 1) * dl_dpsi_far  # Eq. 14
-        psi_opt = self.forward_model.free_space_propagator.propagate_backward(psi_far)
+        psi_opt = self.propagate_detector_wave_to_exit_adjoint(psi_far)
         return psi_opt
 
     @timer()
@@ -432,6 +432,7 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions):
         self._initialize_object_gradient()
         self.parameter_group.probe.initialize_grad()
         self.parameter_group.probe_positions.initialize_grad()
+        self.parameter_group.real_space_scaling.initialize_grad()
         self.parameter_group.opr_mode_weights.initialize_grad()
         self._initialize_object_step_size_buffer()
         self._initialize_probe_step_size_buffer()
@@ -519,6 +520,19 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions):
                 apply_updates=False,
             )
 
+            if (
+                self.parameter_group.real_space_scaling.optimization_enabled(self.current_epoch)
+                and i_slice == self.parameter_group.probe_positions.get_slice_for_correction(
+                    object_.n_slices
+                )
+            ):
+                self.update_real_space_scaling(
+                    chi,
+                    obj_patches[:, i_slice : i_slice + 1],
+                    self.forward_model.intermediate_variables.shifted_unique_probes[i_slice],
+                    apply_updates=False,
+                )
+
             # Set chi to conjugate-modulated wavefield.
             chi = delta_p_i_before_adj_shift
 
@@ -555,6 +569,14 @@ def apply_reconstruction_parameter_updates(self, indices: torch.Tensor):
             self.parameter_group.probe_positions.step_optimizer(
                 clip_update=self.parameter_group.probe_positions.options.momentum_acceleration_gain <= 0
             )
+
+        real_space_scaling = getattr(self.parameter_group, "real_space_scaling", None)
+        if (
+            real_space_scaling is not None
+            and real_space_scaling.optimization_enabled(self.current_epoch)
+        ):
+            real_space_scaling.step_optimizer()
+            real_space_scaling.post_update_hook()
 
         # Update OPR modes and weights.
         if self.parameter_group.opr_mode_weights.optimization_enabled(self.current_epoch):
diff --git a/src/ptychi/reconstructors/pie.py b/src/ptychi/reconstructors/pie.py
index ab9a38f..04f02ce 100644
--- a/src/ptychi/reconstructors/pie.py
+++ b/src/ptychi/reconstructors/pie.py
@@ -66,6 +66,7 @@ def check_inputs(self, *args, **kwargs):
     @timer()
     def run_minibatch(self, input_data, y_true, *args, **kwargs):
         self.parameter_group.probe.initialize_grad()
+        self.parameter_group.real_space_scaling.initialize_grad()
         (delta_o, delta_p_i, delta_pos), y_pred = self.compute_updates(*input_data, y_true)
         self.apply_updates(delta_o, delta_p_i, delta_pos)
         self.loss_tracker.update_batch_loss_with_metric_function(y_pred, y_true)
@@ -97,7 +98,7 @@ def compute_updates(
             y_true,
             constrained_pixel_mask=self.get_constrained_pixel_mask(y_true),
         )
-        psi_prime = self.forward_model.free_space_propagator.propagate_backward(psi_prime)
+        psi_prime = self.propagate_detector_wave_to_exit_adjoint(psi_prime)
         delta_exwv_i = psi_prime - psi
 
         delta_o = torch.zeros_like(object_.data)
@@ -131,6 +132,17 @@ def compute_updates(
                 object_.step_size,
             )
 
+            if (
+                self.parameter_group.real_space_scaling.optimization_enabled(self.current_epoch)
+                and i_slice == self.parameter_group.probe_positions.get_slice_for_correction(object_.n_slices)
+            ):
+                self.update_real_space_scaling(
+                    delta_exwv_i,
+                    obj_patches[:, i_slice : i_slice + 1, ...],
+                    unique_probes[i_slice],
+                    apply_updates=False,
+                )
+
             delta_p_i = None
             if (i_slice == 0) and (probe.optimization_enabled(self.current_epoch)):
                 if (self.parameter_group.probe.representation == "sparse_code"):
@@ -280,6 +292,10 @@ def apply_updates(self, delta_o, delta_p_i, delta_pos, *args, **kwargs):
             probe_positions.set_grad(-delta_pos)
             probe_positions.step_optimizer()
 
+        if self.parameter_group.real_space_scaling.optimization_enabled(self.current_epoch):
+            self.parameter_group.real_space_scaling.step_optimizer()
+            self.parameter_group.real_space_scaling.post_update_hook()
+
 
 class EPIEReconstructor(PIEReconstructor):
     """
diff --git a/tests/test_real_space_scaling.py b/tests/test_real_space_scaling.py
new file mode 100644
index 0000000..708a913
--- /dev/null
+++ b/tests/test_real_space_scaling.py
@@ -0,0 +1,90 @@
+import torch
+
+import ptychi.api as api
+import ptychi.image_proc as ip
+from ptychi.api.task import PtychographyTask
+from ptychi.api.options.base import RealSpaceScalingOptions
+from ptychi.data_structures.real_space_scaling import RealSpaceScaling
+
+
+def make_valid_options():
+    options = api.LSQMLOptions()
+    options.data_options.data = torch.ones((2, 4, 4), dtype=torch.float32)
+    options.object_options.initial_guess = torch.ones((1, 8, 8), dtype=torch.complex64)
+    options.object_options.pixel_size_m = 1.0
+    options.probe_options.initial_guess = torch.ones((1, 1, 4, 4), dtype=torch.complex64)
+    options.probe_position_options.position_y_px = torch.tensor([-1.0, 1.0])
+    options.probe_position_options.position_x_px = torch.tensor([-1.0, 1.0])
+    options.reconstructor_options.default_device = api.Devices.CPU
+    return options
+
+
+def test_real_space_scaling_defaults_and_task_build():
+    options = make_valid_options()
+
+    assert options.real_space_scaling_options.initial_guess == 1.0
+    assert options.real_space_scaling_options.optimizable is False
+
+    task = PtychographyTask(options)
+    scaling = task.reconstructor.parameter_group.real_space_scaling
+
+    assert scaling.shape == (1,)
+    assert torch.allclose(scaling.data, torch.tensor([1.0], device=scaling.data.device))
+    assert task.get_data("real_space_scaling").shape == (1,)
+
+
+def test_rescale_images_adjoint():
+    torch.manual_seed(0)
+    x = torch.randn((3, 7, 5), dtype=torch.complex64)
+    y = torch.randn((3, 7, 5), dtype=torch.complex64)
+    scale = torch.tensor(1.07)
+
+    ax = ip.rescale_images(x, scale)
+    a_star_y = ip.rescale_images(y, scale, adjoint=True)
+
+    lhs = torch.sum(ax.conj() * y)
+    rhs = torch.sum(x.conj() * a_star_y)
+    assert torch.allclose(lhs, rhs, atol=1e-4, rtol=1e-4)
+
+
+def test_real_space_scaling_update_matches_formula():
+    options = RealSpaceScalingOptions(
+        optimizable=True,
+        differentiation_method=api.ImageGradientMethods.NEAREST,
+    )
+    scaling = RealSpaceScaling(data=torch.tensor([1.0]), options=options)
+
+    obj = torch.tensor(
+        [[
+            [0.0, 1.0, 2.0, 3.0],
+            [1.0, 2.0, 3.0, 4.0],
+            [2.0, 3.0, 4.0, 5.0],
+            [3.0, 4.0, 5.0, 6.0],
+        ]],
+        dtype=torch.complex64,
+    ).repeat(2, 1, 1, 1)
+    probe = torch.ones((2, 1, 4, 4), dtype=torch.complex64)
+
+    dody, dodx = ip.nearest_neighbor_gradient(obj[:, 0], "backward")
+    xgrid = -torch.linspace(-1, 1, 4)
+    ygrid = -torch.linspace(-1, 1, 4)
+    xgrid = xgrid * ip.tukey_window(4, 0.1)
+    ygrid = ygrid * ip.tukey_window(4, 0.1)
+    dm_o = dodx * xgrid.view(1, 1, 4) + dody * ygrid.view(1, 4, 1)
+    chi = (dm_o * probe[:, 0]).unsqueeze(1)
+
+    delta = scaling.get_update(chi, obj, probe)
+
+    dm_op = dm_o * probe[:, 0]
+    nom = torch.real(dm_op.conj() * chi[:, 0]).sum(dim=(-1, -2))
+    denom = (dm_op.abs() ** 2).sum(dim=(-1, -2))
+    denom_bias = torch.maximum(
+        denom.max(),
+        torch.tensor(1e-6, device=denom.device, dtype=denom.dtype),
+    )
+    expected = nom / (denom + denom_bias)
+    expected = 0.5 * expected.mean() / 4.0
+
+    assert delta.shape == (1,)
+    assert torch.allclose(delta[0], expected)
+    assert delta[0] > 0