diff --git a/.gitignore b/.gitignore index debe62a7..be239dda 100644 --- a/.gitignore +++ b/.gitignore @@ -157,9 +157,7 @@ workspace/ *.hdf5 *.h5 -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +# IDE +.idea/ +.vscode/ +.loglogin \ No newline at end of file diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py index 4c7ed148..d36c7e63 100644 --- a/src/ptychi/api/options/base.py +++ b/src/ptychi/api/options/base.py @@ -1062,6 +1062,12 @@ class ReconstructorOptions(Options): and is not involved in the reconstruction math. """ + exclude_measured_pixels_below: Optional[float] = None + """ + If not None, gradients corresponding to measured diffraction pixels whose intensity + is less than or equal to this value are set to 0 in reconstructors that support it. + """ + forward_model_options: ForwardModelOptions = dataclasses.field( default_factory=ForwardModelOptions ) diff --git a/src/ptychi/api/task.py b/src/ptychi/api/task.py index ffe9ebbf..548ead4f 100644 --- a/src/ptychi/api/task.py +++ b/src/ptychi/api/task.py @@ -173,7 +173,7 @@ def build_data(self): free_space_propagation_distance_m=self.data_options.free_space_propagation_distance_m, fft_shift=self.data_options.fft_shift, save_data_on_device=save_on_device, - valid_pixel_mask=self.data_options.valid_pixel_mask + valid_pixel_mask=self.data_options.valid_pixel_mask, ) def build_object(self): diff --git a/src/ptychi/forward_models.py b/src/ptychi/forward_models.py index 25b5b673..43b26a19 100644 --- a/src/ptychi/forward_models.py +++ b/src/ptychi/forward_models.py @@ -229,7 +229,7 @@ def __init__( self.intermediate_variables = self.PlanarPtychographyIntermediateVariables() self.diffraction_pattern_blur_sigma = diffraction_pattern_blur_sigma - + self.check_inputs() def check_inputs(self): @@ -760,12 +760,18 @@ def scale_gradients(self, patterns): class NoiseModel(torch.nn.Module): - def __init__(self, eps=1e-6, valid_pixel_mask: Optional[Tensor] = None) -> None: + def __init__( + self, + eps: float = 1e-6, + valid_pixel_mask: Optional[Tensor] = None, + exclude_measured_pixels_below: Optional[float] = None, + ) -> None: super().__init__() self.eps = eps self.noise_statistics = None self.valid_pixel_mask = valid_pixel_mask - + self.exclude_measured_pixels_below = exclude_measured_pixels_below + def nll(self, y_pred: Tensor, y_true: Tensor) -> Tensor: """ Calculate the negative log-likelihood. @@ -775,6 +781,25 @@ def nll(self, y_pred: Tensor, y_true: Tensor) -> Tensor: def backward(self, *args, **kwargs): raise NotImplementedError + @staticmethod + def get_constrained_pixel_mask( + valid_pixel_mask: Optional[Tensor], + exclude_measured_pixels_below: Optional[float], + y_true: Tensor, + ) -> Tensor: + constrained_pixel_mask = torch.ones_like(y_true, dtype=torch.bool) + if valid_pixel_mask is not None: + constrained_pixel_mask = valid_pixel_mask.to(y_true.device) + constrained_pixel_mask = constrained_pixel_mask.unsqueeze(0).expand( + y_true.shape[0], -1, -1 + ) + if exclude_measured_pixels_below is not None: + constrained_pixel_mask = torch.logical_and( + constrained_pixel_mask, + y_true > exclude_measured_pixels_below, + ) + return constrained_pixel_mask + @timer() def conform_to_exit_wave_size( self, @@ -822,17 +847,24 @@ def backward_to_psi_far(self, y_pred, y_true, psi_far): $g = \frac{\partial L}{\partial \psi_{far}}$. When `self.valid_pixel_mask` is not None, pixels of the gradient `g` where the - mask is False are set to 0. When `g` is used to update the far-field wavefield - `psi_far`, the invalid pixels are kept unchanged. + mask is False are set to 0. When `self.exclude_measured_pixels_below` is not + None, gradients at pixels with measured intensities less than or equal to that + threshold are also set to 0. """ # Shape of g: (batch_size, h, w) # Shape of psi_far: (batch_size, n_probe_modes, h, w) y_pred, y_true, valid_pixel_mask = self.conform_to_exit_wave_size( y_pred, y_true, self.valid_pixel_mask, psi_far.shape[-2:] ) + constrained_pixel_mask = self.get_constrained_pixel_mask( + valid_pixel_mask, + self.exclude_measured_pixels_below, + y_true, + ) g = 1 - torch.sqrt(y_true) / (torch.sqrt(y_pred) + self.eps) # Eq. 12b - if valid_pixel_mask is not None: - g[:, torch.logical_not(valid_pixel_mask)] = 0 + + g[torch.logical_not(constrained_pixel_mask)] = 0 + w = 1 / (2 * self.sigma) ** 2 g = 2 * w * g[:, None, :, :] * psi_far return g @@ -863,8 +895,12 @@ def backward_to_psi_far(self, y_pred: Tensor, y_true: Tensor, psi_far: Tensor): y_pred, y_true, valid_pixel_mask = self.conform_to_exit_wave_size( y_pred, y_true, self.valid_pixel_mask, psi_far.shape[-2:] ) + constrained_pixel_mask = self.get_constrained_pixel_mask( + valid_pixel_mask, + self.exclude_measured_pixels_below, + y_true, + ) g = 1 - y_true / (y_pred + self.eps) # Eq. 12b - if valid_pixel_mask is not None: - g[:, torch.logical_not(valid_pixel_mask)] = 0 + g[torch.logical_not(constrained_pixel_mask)] = 0 g = g[:, None, :, :] * psi_far return g diff --git a/src/ptychi/io_handles.py b/src/ptychi/io_handles.py index f95c17a7..be0eaaf7 100644 --- a/src/ptychi/io_handles.py +++ b/src/ptychi/io_handles.py @@ -52,7 +52,7 @@ def __init__( self.free_space_propagation_distance_m = free_space_propagation_distance_m self.save_data_on_device = save_data_on_device - + def __getitem__(self, index): if not isinstance(index, torch.Tensor): index = torch.tensor(index, device=self.patterns.device, dtype=torch.long) diff --git a/src/ptychi/reconstructors/ad_ptychography.py b/src/ptychi/reconstructors/ad_ptychography.py index 7b25d33d..25601f8b 100644 --- a/src/ptychi/reconstructors/ad_ptychography.py +++ b/src/ptychi/reconstructors/ad_ptychography.py @@ -137,8 +137,9 @@ def get_retain_graph(self) -> bool: def run_minibatch(self, input_data, y_true, *args, **kwargs): y_pred = self.forward_model(*input_data) + constrained_pixel_mask = self.get_constrained_pixel_mask(y_true) batch_loss = self.loss_function( - y_pred[:, self.dataset.valid_pixel_mask], y_true[:, self.dataset.valid_pixel_mask] + y_pred[constrained_pixel_mask], y_true[constrained_pixel_mask] ) batch_loss.backward(retain_graph=self.get_retain_graph()) @@ -149,5 +150,9 @@ def run_minibatch(self, input_data, y_true, *args, **kwargs): self.forward_model.zero_grad() self.run_post_update_hooks() - self.loss_tracker.update_batch_loss(y_pred=y_pred, y_true=y_true, loss=batch_loss.item()) + self.loss_tracker.update_batch_loss( + y_pred=y_pred[constrained_pixel_mask], + y_true=y_true[constrained_pixel_mask], + loss=batch_loss.item(), + ) self.loss_tracker.update_batch_regularization_loss(reg_loss.item()) diff --git a/src/ptychi/reconstructors/base.py b/src/ptychi/reconstructors/base.py index 1ebe05e6..d79f39fd 100644 --- a/src/ptychi/reconstructors/base.py +++ b/src/ptychi/reconstructors/base.py @@ -419,6 +419,14 @@ def build_dataloader(self): ) return super().build_dataloader(batch_sampler=batch_sampler) + def get_constrained_pixel_mask(self, y_true: Tensor) -> Tensor: + """Get the detector pixels that should constrain the update for a batch.""" + return fm.NoiseModel.get_constrained_pixel_mask( + valid_pixel_mask=self.dataset.valid_pixel_mask, + exclude_measured_pixels_below=self.options.exclude_measured_pixels_below, + y_true=y_true, + ) + def update_preconditioners(self): # Update preconditioner of the object only if: # - the preconditioner does not exist, or @@ -684,9 +692,11 @@ def prepare_data(self, *args, **kwargs): self.parameter_group.probe.normalize_eigenmodes() logger.info("Probe eigenmodes normalized.") - @staticmethod def replace_propagated_exit_wave_magnitude( - psi: Tensor, actual_pattern_intensity: Tensor + self, + psi: Tensor, + actual_pattern_intensity: Tensor, + constrained_pixel_mask: Optional[Tensor] = None, ) -> Tensor: """ Replace the propogated exit wave amplitude. @@ -697,6 +707,9 @@ def replace_propagated_exit_wave_magnitude( Predicted exit wave propagated to the detector plane. actual_pattern_intensity : Tensor The measured diffraction pattern at the detector. + constrained_pixel_mask : Tensor, optional + If given, only pixels where this mask is True are constrained. Other pixels + are left unchanged from `psi`. Returns ------- @@ -705,11 +718,14 @@ def replace_propagated_exit_wave_magnitude( of `actual_pattern_intensity`. """ - return ( + psi_prime = ( psi / ((psi.abs() ** 2).sum(1, keepdims=True).sqrt() + 1e-7) * torch.sqrt(actual_pattern_intensity + 1e-7)[:, None] ) + if constrained_pixel_mask is not None: + return torch.where(constrained_pixel_mask[:, None], psi_prime, psi) + return psi_prime @timer() def adjoint_shift_probe_update_direction(self, indices, delta_p, first_mode_only=False): diff --git a/src/ptychi/reconstructors/bh.py b/src/ptychi/reconstructors/bh.py index 0b0a6889..2ffc9481 100644 --- a/src/ptychi/reconstructors/bh.py +++ b/src/ptychi/reconstructors/bh.py @@ -137,6 +137,7 @@ def compute_updates( psi_far = self.forward_model.intermediate_variables["psi_far"] p = probe.get_opr_mode(0) # to do for multi modes pos = self.positions + self.current_constrained_pixel_mask = self.get_constrained_pixel_mask(y_true)[:, None] # sqrt of data d = torch.sqrt(y_true)[:, torch.newaxis] @@ -294,6 +295,7 @@ def gradientF(self, psi_far, d): td = d * (psi_far / (torch.abs(psi_far) + self.eps)) td = psi_far - td + td = td * self.current_constrained_pixel_mask # Compensate FFT normalization only for the far-field Fourier propagator. if isinstance(self.forward_model.free_space_propagator, fm.FourierPropagator): td *= psi_far.shape[-1] * psi_far.shape[-2] @@ -305,8 +307,13 @@ def hessianF(self, psi_far, psi_far1, psi_far2, data): l0 = psi_far / (torch.abs(psi_far) + self.eps) d0 = data / (torch.abs(psi_far) + self.eps) - v1 = torch.sum((1 - d0) * reprod(psi_far1, psi_far2)) - v2 = torch.sum(d0 * reprod(l0, psi_far1) * reprod(l0, psi_far2)) + v1 = torch.sum(self.current_constrained_pixel_mask * (1 - d0) * reprod(psi_far1, psi_far2)) + v2 = torch.sum( + self.current_constrained_pixel_mask + * d0 + * reprod(l0, psi_far1) + * reprod(l0, psi_far2) + ) return 2 * (v1 + v2) def gradient_o(self, p, gradF): diff --git a/src/ptychi/reconstructors/dm.py b/src/ptychi/reconstructors/dm.py index 947563fc..b7a211b3 100644 --- a/src/ptychi/reconstructors/dm.py +++ b/src/ptychi/reconstructors/dm.py @@ -98,11 +98,11 @@ def build_dataloader(self): @timer() def run_minibatch(self, input_data, y_true, *args, **kwargs): - dm_error_squared = self.compute_updates(y_true, self.dataset.valid_pixel_mask) + dm_error_squared = self.compute_updates(y_true) self.loss_tracker.update_batch_loss(loss=dm_error_squared.sqrt()) @timer() - def compute_updates(self, y_true: Tensor, valid_pixel_mask: Tensor) -> Tensor: + def compute_updates(self, y_true: Tensor) -> Tensor: """ Compute the updates to the object, probe, and exit wave using the procedure described here: [Probe retrieval in ptychographic coherent diffractive imaging @@ -147,7 +147,7 @@ def compute_updates(self, y_true: Tensor, valid_pixel_mask: Tensor) -> Tensor: dm_error_squared = 0 for i in range(n_chunks): obj_patches, dm_error_squared, new_psi = self.apply_dm_update_to_exit_wave_chunk( - start_pts[i], end_pts[i], y_true, valid_pixel_mask, dm_error_squared + start_pts[i], end_pts[i], y_true, dm_error_squared ) if probe.optimization_enabled(self.current_epoch): self.add_to_probe_update_terms( @@ -223,7 +223,6 @@ def apply_dm_update_to_exit_wave_chunk( start_pt: int, end_pt: int, y_true: Tensor, - valid_pixel_mask: Tensor, dm_error_squared: Tensor, ) -> Tuple[Tensor, Tensor, Tensor]: """ @@ -235,8 +234,6 @@ def apply_dm_update_to_exit_wave_chunk( # - revised_psi --> 2 * Pi_o(psi_n)- psi_n # - new_psi --> Pi_o(psi_n) - probe = self.parameter_group.probe - # Get the update exit wave new_psi, obj_patches = self.calculate_exit_wave_chunk( start_pt, end_pt, return_obj_patches=True @@ -246,10 +243,10 @@ def apply_dm_update_to_exit_wave_chunk( 2 * new_psi - self.psi[start_pt:end_pt] ) # Replace intensities - revised_psi = torch.where( - valid_pixel_mask.repeat(revised_psi.shape[0], probe.n_modes, 1, 1), - self.replace_propagated_exit_wave_magnitude(revised_psi, y_true[start_pt:end_pt]), + revised_psi = self.replace_propagated_exit_wave_magnitude( revised_psi, + y_true[start_pt:end_pt], + constrained_pixel_mask=self.get_constrained_pixel_mask(y_true[start_pt:end_pt]), ) # Propagate back to sample plane revised_psi = self.forward_model.free_space_propagator.propagate_backward(revised_psi) diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index 7183bfdc..20c1e5df 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -65,7 +65,9 @@ def __init__( "gaussian": fm.PtychographyGaussianNoiseModel, "poisson": fm.PtychographyPoissonNoiseModel, }[options.noise_model]( - **noise_model_params, valid_pixel_mask=self.dataset.valid_pixel_mask.clone() + **noise_model_params, + valid_pixel_mask=self.dataset.valid_pixel_mask.clone(), + exclude_measured_pixels_below=self.options.exclude_measured_pixels_below, ) self.alpha_psi_far = 0.5 diff --git a/src/ptychi/reconstructors/pie.py b/src/ptychi/reconstructors/pie.py index ac032acf..11484ca3 100644 --- a/src/ptychi/reconstructors/pie.py +++ b/src/ptychi/reconstructors/pie.py @@ -66,15 +66,13 @@ def check_inputs(self, *args, **kwargs): @timer() def run_minibatch(self, input_data, y_true, *args, **kwargs): self.parameter_group.probe.initialize_grad() - (delta_o, delta_p_i, delta_pos), y_pred = self.compute_updates( - *input_data, y_true, self.dataset.valid_pixel_mask - ) + (delta_o, delta_p_i, delta_pos), y_pred = self.compute_updates(*input_data, y_true) self.apply_updates(delta_o, delta_p_i, delta_pos) self.loss_tracker.update_batch_loss_with_metric_function(y_pred, y_true) @timer() def compute_updates( - self, indices: torch.Tensor, y_true: torch.Tensor, valid_pixel_mask: torch.Tensor + self, indices: torch.Tensor, y_true: torch.Tensor ) -> tuple[torch.Tensor, ...]: """ Calculates the updates of the whole object, the probe, and other parameters. @@ -94,10 +92,10 @@ def compute_updates( psi_far = self.forward_model.intermediate_variables["psi_far"] unique_probes = self.forward_model.intermediate_variables.shifted_unique_probes - psi_prime = self.replace_propagated_exit_wave_magnitude(psi_far, y_true) - # Do not swap magnitude for bad pixels. - psi_prime = torch.where( - valid_pixel_mask.repeat(psi_prime.shape[0], probe.n_modes, 1, 1), psi_prime, psi_far + psi_prime = self.replace_propagated_exit_wave_magnitude( + psi_far, + y_true, + constrained_pixel_mask=self.get_constrained_pixel_mask(y_true), ) psi_prime = self.forward_model.free_space_propagator.propagate_backward(psi_prime)