Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,7 @@ workspace/
*.hdf5
*.h5

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# IDE
.idea/
.vscode/
.loglogin
6 changes: 6 additions & 0 deletions src/ptychi/api/options/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,12 @@ class ReconstructorOptions(Options):
and is not involved in the reconstruction math.
"""

exclude_measured_pixels_below: Optional[float] = None
"""
If not None, gradients corresponding to measured diffraction pixels whose intensity
is less than or equal to this value are set to 0 in reconstructors that support it.
"""

forward_model_options: ForwardModelOptions = dataclasses.field(
default_factory=ForwardModelOptions
)
Expand Down
2 changes: 1 addition & 1 deletion src/ptychi/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def build_data(self):
free_space_propagation_distance_m=self.data_options.free_space_propagation_distance_m,
fft_shift=self.data_options.fft_shift,
save_data_on_device=save_on_device,
valid_pixel_mask=self.data_options.valid_pixel_mask
valid_pixel_mask=self.data_options.valid_pixel_mask,
)

def build_object(self):
Expand Down
54 changes: 45 additions & 9 deletions src/ptychi/forward_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def __init__(
self.intermediate_variables = self.PlanarPtychographyIntermediateVariables()

self.diffraction_pattern_blur_sigma = diffraction_pattern_blur_sigma

self.check_inputs()

def check_inputs(self):
Expand Down Expand Up @@ -760,12 +760,18 @@ def scale_gradients(self, patterns):


class NoiseModel(torch.nn.Module):
def __init__(self, eps=1e-6, valid_pixel_mask: Optional[Tensor] = None) -> None:
def __init__(
self,
eps: float = 1e-6,
valid_pixel_mask: Optional[Tensor] = None,
exclude_measured_pixels_below: Optional[float] = None,
) -> None:
super().__init__()
self.eps = eps
self.noise_statistics = None
self.valid_pixel_mask = valid_pixel_mask

self.exclude_measured_pixels_below = exclude_measured_pixels_below

def nll(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
"""
Calculate the negative log-likelihood.
Expand All @@ -775,6 +781,25 @@ def nll(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
def backward(self, *args, **kwargs):
raise NotImplementedError

@staticmethod
def get_constrained_pixel_mask(
valid_pixel_mask: Optional[Tensor],
exclude_measured_pixels_below: Optional[float],
y_true: Tensor,
) -> Tensor:
constrained_pixel_mask = torch.ones_like(y_true, dtype=torch.bool)
if valid_pixel_mask is not None:
constrained_pixel_mask = valid_pixel_mask.to(y_true.device)
constrained_pixel_mask = constrained_pixel_mask.unsqueeze(0).expand(
y_true.shape[0], -1, -1
)
if exclude_measured_pixels_below is not None:
constrained_pixel_mask = torch.logical_and(
constrained_pixel_mask,
y_true > exclude_measured_pixels_below,
)
return constrained_pixel_mask

@timer()
def conform_to_exit_wave_size(
self,
Expand Down Expand Up @@ -822,17 +847,24 @@ def backward_to_psi_far(self, y_pred, y_true, psi_far):
$g = \frac{\partial L}{\partial \psi_{far}}$.

When `self.valid_pixel_mask` is not None, pixels of the gradient `g` where the
mask is False are set to 0. When `g` is used to update the far-field wavefield
`psi_far`, the invalid pixels are kept unchanged.
mask is False are set to 0. When `self.exclude_measured_pixels_below` is not
None, gradients at pixels with measured intensities less than or equal to that
threshold are also set to 0.
"""
# Shape of g: (batch_size, h, w)
# Shape of psi_far: (batch_size, n_probe_modes, h, w)
y_pred, y_true, valid_pixel_mask = self.conform_to_exit_wave_size(
y_pred, y_true, self.valid_pixel_mask, psi_far.shape[-2:]
)
constrained_pixel_mask = self.get_constrained_pixel_mask(
valid_pixel_mask,
self.exclude_measured_pixels_below,
y_true,
)
g = 1 - torch.sqrt(y_true) / (torch.sqrt(y_pred) + self.eps) # Eq. 12b
if valid_pixel_mask is not None:
g[:, torch.logical_not(valid_pixel_mask)] = 0

g[torch.logical_not(constrained_pixel_mask)] = 0

w = 1 / (2 * self.sigma) ** 2
g = 2 * w * g[:, None, :, :] * psi_far
return g
Expand Down Expand Up @@ -863,8 +895,12 @@ def backward_to_psi_far(self, y_pred: Tensor, y_true: Tensor, psi_far: Tensor):
y_pred, y_true, valid_pixel_mask = self.conform_to_exit_wave_size(
y_pred, y_true, self.valid_pixel_mask, psi_far.shape[-2:]
)
constrained_pixel_mask = self.get_constrained_pixel_mask(
valid_pixel_mask,
self.exclude_measured_pixels_below,
y_true,
)
g = 1 - y_true / (y_pred + self.eps) # Eq. 12b
if valid_pixel_mask is not None:
g[:, torch.logical_not(valid_pixel_mask)] = 0
g[torch.logical_not(constrained_pixel_mask)] = 0
g = g[:, None, :, :] * psi_far
return g
2 changes: 1 addition & 1 deletion src/ptychi/io_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
self.free_space_propagation_distance_m = free_space_propagation_distance_m

self.save_data_on_device = save_data_on_device

def __getitem__(self, index):
if not isinstance(index, torch.Tensor):
index = torch.tensor(index, device=self.patterns.device, dtype=torch.long)
Expand Down
9 changes: 7 additions & 2 deletions src/ptychi/reconstructors/ad_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ def get_retain_graph(self) -> bool:

def run_minibatch(self, input_data, y_true, *args, **kwargs):
y_pred = self.forward_model(*input_data)
constrained_pixel_mask = self.get_constrained_pixel_mask(y_true)
batch_loss = self.loss_function(
y_pred[:, self.dataset.valid_pixel_mask], y_true[:, self.dataset.valid_pixel_mask]
y_pred[constrained_pixel_mask], y_true[constrained_pixel_mask]
)

batch_loss.backward(retain_graph=self.get_retain_graph())
Expand All @@ -149,5 +150,9 @@ def run_minibatch(self, input_data, y_true, *args, **kwargs):
self.forward_model.zero_grad()
self.run_post_update_hooks()

self.loss_tracker.update_batch_loss(y_pred=y_pred, y_true=y_true, loss=batch_loss.item())
self.loss_tracker.update_batch_loss(
y_pred=y_pred[constrained_pixel_mask],
y_true=y_true[constrained_pixel_mask],
loss=batch_loss.item(),
)
self.loss_tracker.update_batch_regularization_loss(reg_loss.item())
22 changes: 19 additions & 3 deletions src/ptychi/reconstructors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,14 @@ def build_dataloader(self):
)
return super().build_dataloader(batch_sampler=batch_sampler)

def get_constrained_pixel_mask(self, y_true: Tensor) -> Tensor:
"""Get the detector pixels that should constrain the update for a batch."""
return fm.NoiseModel.get_constrained_pixel_mask(
valid_pixel_mask=self.dataset.valid_pixel_mask,
exclude_measured_pixels_below=self.options.exclude_measured_pixels_below,
y_true=y_true,
)

def update_preconditioners(self):
# Update preconditioner of the object only if:
# - the preconditioner does not exist, or
Expand Down Expand Up @@ -684,9 +692,11 @@ def prepare_data(self, *args, **kwargs):
self.parameter_group.probe.normalize_eigenmodes()
logger.info("Probe eigenmodes normalized.")

@staticmethod
def replace_propagated_exit_wave_magnitude(
psi: Tensor, actual_pattern_intensity: Tensor
self,
psi: Tensor,
actual_pattern_intensity: Tensor,
constrained_pixel_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Replace the propogated exit wave amplitude.
Expand All @@ -697,6 +707,9 @@ def replace_propagated_exit_wave_magnitude(
Predicted exit wave propagated to the detector plane.
actual_pattern_intensity : Tensor
The measured diffraction pattern at the detector.
constrained_pixel_mask : Tensor, optional
If given, only pixels where this mask is True are constrained. Other pixels
are left unchanged from `psi`.

Returns
-------
Expand All @@ -705,11 +718,14 @@ def replace_propagated_exit_wave_magnitude(
of `actual_pattern_intensity`.
"""

return (
psi_prime = (
psi
/ ((psi.abs() ** 2).sum(1, keepdims=True).sqrt() + 1e-7)
* torch.sqrt(actual_pattern_intensity + 1e-7)[:, None]
)
if constrained_pixel_mask is not None:
return torch.where(constrained_pixel_mask[:, None], psi_prime, psi)
return psi_prime

@timer()
def adjoint_shift_probe_update_direction(self, indices, delta_p, first_mode_only=False):
Expand Down
11 changes: 9 additions & 2 deletions src/ptychi/reconstructors/bh.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def compute_updates(
psi_far = self.forward_model.intermediate_variables["psi_far"]
p = probe.get_opr_mode(0) # to do for multi modes
pos = self.positions
self.current_constrained_pixel_mask = self.get_constrained_pixel_mask(y_true)[:, None]

# sqrt of data
d = torch.sqrt(y_true)[:, torch.newaxis]
Expand Down Expand Up @@ -294,6 +295,7 @@ def gradientF(self, psi_far, d):

td = d * (psi_far / (torch.abs(psi_far) + self.eps))
td = psi_far - td
td = td * self.current_constrained_pixel_mask
# Compensate FFT normalization only for the far-field Fourier propagator.
if isinstance(self.forward_model.free_space_propagator, fm.FourierPropagator):
td *= psi_far.shape[-1] * psi_far.shape[-2]
Expand All @@ -305,8 +307,13 @@ def hessianF(self, psi_far, psi_far1, psi_far2, data):

l0 = psi_far / (torch.abs(psi_far) + self.eps)
d0 = data / (torch.abs(psi_far) + self.eps)
v1 = torch.sum((1 - d0) * reprod(psi_far1, psi_far2))
v2 = torch.sum(d0 * reprod(l0, psi_far1) * reprod(l0, psi_far2))
v1 = torch.sum(self.current_constrained_pixel_mask * (1 - d0) * reprod(psi_far1, psi_far2))
v2 = torch.sum(
self.current_constrained_pixel_mask
* d0
* reprod(l0, psi_far1)
* reprod(l0, psi_far2)
)
return 2 * (v1 + v2)

def gradient_o(self, p, gradF):
Expand Down
15 changes: 6 additions & 9 deletions src/ptychi/reconstructors/dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ def build_dataloader(self):

@timer()
def run_minibatch(self, input_data, y_true, *args, **kwargs):
dm_error_squared = self.compute_updates(y_true, self.dataset.valid_pixel_mask)
dm_error_squared = self.compute_updates(y_true)
self.loss_tracker.update_batch_loss(loss=dm_error_squared.sqrt())

@timer()
def compute_updates(self, y_true: Tensor, valid_pixel_mask: Tensor) -> Tensor:
def compute_updates(self, y_true: Tensor) -> Tensor:
"""
Compute the updates to the object, probe, and exit wave using the procedure
described here: [Probe retrieval in ptychographic coherent diffractive imaging
Expand Down Expand Up @@ -147,7 +147,7 @@ def compute_updates(self, y_true: Tensor, valid_pixel_mask: Tensor) -> Tensor:
dm_error_squared = 0
for i in range(n_chunks):
obj_patches, dm_error_squared, new_psi = self.apply_dm_update_to_exit_wave_chunk(
start_pts[i], end_pts[i], y_true, valid_pixel_mask, dm_error_squared
start_pts[i], end_pts[i], y_true, dm_error_squared
)
if probe.optimization_enabled(self.current_epoch):
self.add_to_probe_update_terms(
Expand Down Expand Up @@ -223,7 +223,6 @@ def apply_dm_update_to_exit_wave_chunk(
start_pt: int,
end_pt: int,
y_true: Tensor,
valid_pixel_mask: Tensor,
dm_error_squared: Tensor,
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Expand All @@ -235,8 +234,6 @@ def apply_dm_update_to_exit_wave_chunk(
# - revised_psi --> 2 * Pi_o(psi_n)- psi_n
# - new_psi --> Pi_o(psi_n)

probe = self.parameter_group.probe

# Get the update exit wave
new_psi, obj_patches = self.calculate_exit_wave_chunk(
start_pt, end_pt, return_obj_patches=True
Expand All @@ -246,10 +243,10 @@ def apply_dm_update_to_exit_wave_chunk(
2 * new_psi - self.psi[start_pt:end_pt]
)
# Replace intensities
revised_psi = torch.where(
valid_pixel_mask.repeat(revised_psi.shape[0], probe.n_modes, 1, 1),
self.replace_propagated_exit_wave_magnitude(revised_psi, y_true[start_pt:end_pt]),
revised_psi = self.replace_propagated_exit_wave_magnitude(
revised_psi,
y_true[start_pt:end_pt],
constrained_pixel_mask=self.get_constrained_pixel_mask(y_true[start_pt:end_pt]),
)
# Propagate back to sample plane
revised_psi = self.forward_model.free_space_propagator.propagate_backward(revised_psi)
Expand Down
4 changes: 3 additions & 1 deletion src/ptychi/reconstructors/lsqml.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def __init__(
"gaussian": fm.PtychographyGaussianNoiseModel,
"poisson": fm.PtychographyPoissonNoiseModel,
}[options.noise_model](
**noise_model_params, valid_pixel_mask=self.dataset.valid_pixel_mask.clone()
**noise_model_params,
valid_pixel_mask=self.dataset.valid_pixel_mask.clone(),
exclude_measured_pixels_below=self.options.exclude_measured_pixels_below,
)

self.alpha_psi_far = 0.5
Expand Down
14 changes: 6 additions & 8 deletions src/ptychi/reconstructors/pie.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,13 @@ def check_inputs(self, *args, **kwargs):
@timer()
def run_minibatch(self, input_data, y_true, *args, **kwargs):
self.parameter_group.probe.initialize_grad()
(delta_o, delta_p_i, delta_pos), y_pred = self.compute_updates(
*input_data, y_true, self.dataset.valid_pixel_mask
)
(delta_o, delta_p_i, delta_pos), y_pred = self.compute_updates(*input_data, y_true)
self.apply_updates(delta_o, delta_p_i, delta_pos)
self.loss_tracker.update_batch_loss_with_metric_function(y_pred, y_true)

@timer()
def compute_updates(
self, indices: torch.Tensor, y_true: torch.Tensor, valid_pixel_mask: torch.Tensor
self, indices: torch.Tensor, y_true: torch.Tensor
) -> tuple[torch.Tensor, ...]:
"""
Calculates the updates of the whole object, the probe, and other parameters.
Expand All @@ -94,10 +92,10 @@ def compute_updates(
psi_far = self.forward_model.intermediate_variables["psi_far"]
unique_probes = self.forward_model.intermediate_variables.shifted_unique_probes

psi_prime = self.replace_propagated_exit_wave_magnitude(psi_far, y_true)
# Do not swap magnitude for bad pixels.
psi_prime = torch.where(
valid_pixel_mask.repeat(psi_prime.shape[0], probe.n_modes, 1, 1), psi_prime, psi_far
psi_prime = self.replace_propagated_exit_wave_magnitude(
psi_far,
y_true,
constrained_pixel_mask=self.get_constrained_pixel_mask(y_true),
)
psi_prime = self.forward_model.free_space_propagator.propagate_backward(psi_prime)

Expand Down
Loading