Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/ptychi/api/options/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,38 @@ def check(self, options: "task_options.PtychographyTaskOptions"):
raise ValueError("LBFGS optimizer is currently only supported for Autodiff reconstructors.")


@dataclasses.dataclass
class RealSpaceScalingOptions(ParameterOptions):
initial_guess: float = 1.0
"""Initial global real-space scaling factor."""

optimizable: bool = False
"""Whether the real-space scaling factor is optimizable."""

differentiation_method: enums.ImageGradientMethods = enums.ImageGradientMethods.FOURIER_DIFFERENTIATION
"""Method used to compute object gradients for the scaling update."""

def check(self, options: "task_options.PtychographyTaskOptions"):
super().check(options)
if self.initial_guess <= 0:
raise ValueError("`real_space_scaling_options.initial_guess` must be positive.")
if self.optimizer == enums.Optimizers.LBFGS and "Autodiff" not in options.__class__.__name__:
raise ValueError("LBFGS optimizer is currently only supported for Autodiff reconstructors.")
affine_constraint = options.probe_position_options.affine_transform_constraint
if (
self.optimizable
and options.probe_position_options.optimizable
and affine_constraint.enabled
and enums.AffineDegreesOfFreedom.SCALE in affine_constraint.degrees_of_freedom
):
logger.warning(
"Do not enable `real_space_scaling_options.optimizable` together with "
"`probe_position_options.affine_transform_constraint` when "
"`AffineDegreesOfFreedom.SCALE` is included. Both refine the same "
"far-field scale ambiguity."
)


@dataclasses.dataclass
class OPRModeWeightsSmoothingOptions(FeatureOptions):
"""Settings for smoothing OPR mode weights."""
Expand Down
3 changes: 3 additions & 0 deletions src/ptychi/api/options/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class PtychographyTaskOptions(base.TaskOptions):

probe_position_options: base.ProbePositionOptions = field(default_factory=base.ProbePositionOptions)

real_space_scaling_options: base.RealSpaceScalingOptions = field(default_factory=base.RealSpaceScalingOptions)

opr_mode_weight_options: base.OPRModeWeightsOptions = field(default_factory=base.OPRModeWeightsOptions)

def check(self, *args, **kwargs):
Expand All @@ -31,6 +33,7 @@ def check(self, *args, **kwargs):
self.object_options,
self.probe_options,
self.probe_position_options,
self.real_space_scaling_options,
self.opr_mode_weight_options,
):
options.check(self)
38 changes: 34 additions & 4 deletions src/ptychi/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import ptychi.data_structures.probe as probe
import ptychi.data_structures.probe_positions as probepos
import ptychi.data_structures.parameter_group as paramgrp
import ptychi.data_structures.real_space_scaling as rsscaling
import ptychi.maps as maps
from ptychi.io_handles import PtychographyDataset
from ptychi.reconstructors.base import Reconstructor
Expand Down Expand Up @@ -67,13 +68,15 @@ def __init__(self, options: api.options.task.PtychographyTaskOptions, *args, **k
self.object_options = options.object_options
self.probe_options = options.probe_options
self.position_options = options.probe_position_options
self.real_space_scaling_options = options.real_space_scaling_options
self.opr_mode_weight_options = options.opr_mode_weight_options
self.reconstructor_options = options.reconstructor_options

self.dataset = None
self.object = None
self.probe = None
self.probe_positions = None
self.real_space_scaling = None
self.opr_mode_weights = None
self.reconstructor: Reconstructor | None = None

Expand All @@ -92,6 +95,7 @@ def build(self):
self.build_object()
self.build_probe()
self.build_probe_positions()
self.build_real_space_scaling()
self.build_opr_mode_weights()
self.build_reconstructor()

Expand Down Expand Up @@ -218,6 +222,20 @@ def build_probe_positions(self):
data = torch.stack([pos_y, pos_x], dim=1)
self.probe_positions = probepos.ProbePositions(data=data, options=self.position_options)

def build_real_space_scaling(self):
"""Build the global real-space scaling parameter.

The constructed parameter stores a real tensor of shape ``(1,)``.
"""
data = torch.tensor(
[self.real_space_scaling_options.initial_guess],
device=torch.get_default_device(),
dtype=torch.get_default_dtype(),
)
self.real_space_scaling = rsscaling.RealSpaceScaling(
data=data, options=self.real_space_scaling_options
)

def build_opr_mode_weights(self):
if self.opr_mode_weight_options.initial_weights is None:
initial_weights = torch.ones([self.data_options.data.shape[0], 1])
Expand All @@ -237,6 +255,7 @@ def build_reconstructor(self):
object=self.object,
probe=self.probe,
probe_positions=self.probe_positions,
real_space_scaling=self.real_space_scaling,
opr_mode_weights=self.opr_mode_weights,
)

Expand Down Expand Up @@ -280,13 +299,14 @@ def run(self, n_epochs: int = None, reset_timer_globals: bool = True):
self.reconstructor.run(n_epochs=n_epochs)

def get_data(
self, name: Literal["object", "probe", "probe_positions", "opr_mode_weights"]
self,
name: Literal["object", "probe", "probe_positions", "real_space_scaling", "opr_mode_weights"],
) -> Tensor:
"""Get a detached copy of the data of the given name.

Parameters
----------
name : Literal["object", "probe", "probe_positions", "opr_mode_weights"]
name : Literal["object", "probe", "probe_positions", "real_space_scaling", "opr_mode_weights"]
The name of the data to get.

Returns
Expand All @@ -304,7 +324,7 @@ def get_data(

def get_data_to_cpu(
self,
name: Literal["object", "probe", "probe_positions", "opr_mode_weights"],
name: Literal["object", "probe", "probe_positions", "real_space_scaling", "opr_mode_weights"],
as_numpy: bool = False,
) -> Union[Tensor, ndarray]:
data = self.get_data(name).cpu()
Expand All @@ -327,7 +347,13 @@ def get_probe_positions_x(self, as_numpy: bool = False) -> Union[Tensor, ndarray
def copy_data_from_task(
self,
task: "PtychographyTask",
params_to_copy: tuple[str, ...] = ("object", "probe", "probe_positions", "opr_mode_weights")
params_to_copy: tuple[str, ...] = (
"object",
"probe",
"probe_positions",
"real_space_scaling",
"opr_mode_weights",
)
) -> None:
"""Copy data of reconstruction parameters from another task object.

Expand All @@ -352,6 +378,10 @@ def copy_data_from_task(
self.reconstructor.parameter_group.probe_positions.set_data(
task.get_data("probe_positions")
)
elif param == "real_space_scaling":
self.reconstructor.parameter_group.real_space_scaling.set_data(
task.get_data("real_space_scaling")
)
elif param == "opr_mode_weights":
self.reconstructor.parameter_group.opr_mode_weights.set_data(
task.get_data("opr_mode_weights")
Expand Down
3 changes: 3 additions & 0 deletions src/ptychi/data_structures/parameter_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import ptychi.data_structures.opr_mode_weights as oprweights
import ptychi.data_structures.probe as probe
import ptychi.data_structures.probe_positions as probepos
import ptychi.data_structures.real_space_scaling as rsscaling
from ptychi.parallel import MultiprocessMixin


Expand Down Expand Up @@ -95,6 +96,8 @@ class PtychographyParameterGroup(ParameterGroup):

probe_positions: "probepos.ProbePositions"

real_space_scaling: "rsscaling.RealSpaceScaling"

opr_mode_weights: "oprweights.OPRModeWeights"

def __post_init__(self):
Expand Down
114 changes: 114 additions & 0 deletions src/ptychi/data_structures/real_space_scaling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright © 2025 UChicago Argonne, LLC All right reserved
# Full license accessible at https://github.com//AdvancedPhotonSource/pty-chi/blob/main/LICENSE

from typing import TYPE_CHECKING

import torch
from torch import Tensor

import ptychi.api.enums as enums
import ptychi.data_structures.base as dsbase
import ptychi.image_proc as ip

if TYPE_CHECKING:
import ptychi.api.options.base as base_options


class RealSpaceScaling(dsbase.ReconstructParameter):
options: "base_options.RealSpaceScalingOptions"

def __init__(
self,
*args,
name: str = "real_space_scaling",
options: "base_options.RealSpaceScalingOptions" = None,
**kwargs,
):
"""Global real-space scaling factor applied to the exit wavefield.

Parameters
----------
data : Tensor, optional
A real tensor of shape ``(1,)`` containing the scaling factor.
"""
super().__init__(*args, name=name, options=options, is_complex=False, **kwargs)
if self.shape != (1,):
raise ValueError("RealSpaceScaling must contain exactly one scalar element.")

def post_update_hook(self, *args, **kwargs):
"""Clamp the parameter tensor in place.

The updated tensor has shape ``(1,)``.
"""
with torch.no_grad():
self.tensor.clamp_(min=1e-6)

def get_update(
self,
chi: Tensor,
obj_patches: Tensor,
probe: Tensor,
eps: float = 1e-6,
) -> Tensor:
"""Estimate the update direction of the real-space scaling factor.

This follows the same first-order approximation used by PtychoShelves
for detector-scale refinement: object gradients are combined with a
radial weighting and projected onto the exit-wave update.

Parameters
----------
chi : Tensor
A complex tensor of shape ``(batch_size, n_probe_modes, h, w)``
giving the exit-wave update at the current slice.
obj_patches : Tensor
A complex tensor of shape ``(batch_size, n_slices, h, w)``
containing object patches for the current batch.
probe : Tensor
A complex tensor of shape ``(batch_size, n_probe_modes, h, w)``
containing the incident wavefields at the current slice.
eps : float
Small stabilizer added to the denominator.

Returns
-------
Tensor
A real tensor of shape ``(1,)`` containing the additive update
direction for the global scaling factor.
"""
obj_patches = obj_patches[:, 0]
probe = probe[:, 0]
chi_m0 = chi[:, 0]

if self.options.differentiation_method == enums.ImageGradientMethods.GAUSSIAN:
dody, dodx = ip.gaussian_gradient(obj_patches, sigma=0.33)
elif self.options.differentiation_method == enums.ImageGradientMethods.FOURIER_DIFFERENTIATION:
dody, dodx = ip.fourier_gradient(obj_patches)
elif self.options.differentiation_method == enums.ImageGradientMethods.FOURIER_SHIFT:
dody, dodx = ip.fourier_shift_gradient(obj_patches)
elif self.options.differentiation_method == enums.ImageGradientMethods.NEAREST:
dody, dodx = ip.nearest_neighbor_gradient(obj_patches, "backward")
else:
raise ValueError(
f"Unsupported differentiation method: {self.options.differentiation_method}"
)

h, w = obj_patches.shape[-2:]
xgrid = -torch.linspace(-1, 1, w, device=obj_patches.device, dtype=obj_patches.real.dtype)
ygrid = -torch.linspace(-1, 1, h, device=obj_patches.device, dtype=obj_patches.real.dtype)
xgrid = xgrid * ip.tukey_window(w, 0.1, device=xgrid.device, dtype=xgrid.dtype)
ygrid = ygrid * ip.tukey_window(h, 0.1, device=ygrid.device, dtype=ygrid.dtype)
xgrid = xgrid.view(1, 1, w)
ygrid = ygrid.view(1, h, 1)

dm_o = dodx * xgrid + dody * ygrid
dm_op = dm_o * probe
nom = torch.real(dm_op.conj() * chi_m0).sum(dim=(-1, -2))
denom = (dm_op.abs() ** 2).sum(dim=(-1, -2))
denom_bias = torch.maximum(
denom.max(),
torch.tensor(eps, device=denom.device, dtype=denom.dtype),
)
delta_scale = nom / (denom + denom_bias)
delta_scale = 0.5 * delta_scale.mean() / ((h + w) * 0.5)
return delta_scale.reshape(1)
49 changes: 47 additions & 2 deletions src/ptychi/forward_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def __init__(
self.object = parameter_group.object
self.probe = parameter_group.probe
self.probe_positions = parameter_group.probe_positions
self.real_space_scaling = parameter_group.real_space_scaling
self.opr_mode_weights = parameter_group.opr_mode_weights

self.wavelength_m = wavelength_m
Expand Down Expand Up @@ -450,10 +451,55 @@ def forward_far_field(self, psi: Tensor) -> Tensor:
Tensor
A (batch_size, n_probe_modes, h, w) tensor of far field waves.
"""
psi = self.apply_real_space_scaling(psi)
psi_far = self.free_space_propagator.propagate_forward(psi)
self.record_intermediate_variable("psi_far", psi_far)
return psi_far

@timer()
def apply_real_space_scaling(self, psi: Tensor) -> Tensor:
"""Apply the global real-space scaling before detector propagation.

Parameters
----------
psi : Tensor
A complex tensor of shape ``(batch_size, n_probe_modes, h, w)``.

Returns
-------
Tensor
A complex tensor of shape ``(batch_size, n_probe_modes, h, w)``.
"""
scale = self.real_space_scaling.data[0]
if (not scale.requires_grad) and torch.allclose(scale, torch.ones_like(scale)):
return psi
orig_shape = psi.shape
psi = psi.reshape(-1, *orig_shape[-2:])
psi = ip.rescale_images(psi, scale)
return psi.reshape(orig_shape)

@timer()
def apply_real_space_scaling_adjoint(self, psi: Tensor) -> Tensor:
"""Apply the adjoint of the global real-space scaling operator.

Parameters
----------
psi : Tensor
A complex tensor of shape ``(batch_size, n_probe_modes, h, w)``.

Returns
-------
Tensor
A complex tensor of shape ``(batch_size, n_probe_modes, h, w)``.
"""
scale = self.real_space_scaling.data[0]
if (not scale.requires_grad) and torch.allclose(scale, torch.ones_like(scale)):
return psi
orig_shape = psi.shape
psi = psi.reshape(-1, *orig_shape[-2:])
psi = ip.rescale_images(psi, scale, adjoint=True)
return psi.reshape(orig_shape)

@timer()
def propagate_to_next_slice(self, psi: Tensor, slice_index: int):
"""
Expand Down Expand Up @@ -640,8 +686,7 @@ def forward_low_memory(self, indices: Tensor, return_object_patches: bool = Fals
probe[..., i_mode : i_mode + 1, :, :], obj_patches
)

psi_far = self.free_space_propagator.propagate_forward(exit_psi)
self.record_intermediate_variable("psi_far", psi_far)
psi_far = self.forward_far_field(exit_psi)

y = y + psi_far[..., 0, :, :].abs() ** 2

Expand Down
Loading
Loading