From cbf44a489543e11dc324fda7a84451487b83feb1 Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Wed, 11 Mar 2026 18:01:59 -0500 Subject: [PATCH 1/4] currently we use a fixed unmeasured pixel binary mask to define regions where we do not enforce the measurement constraint. this fix allows us to treat all unmeasured pixels as unconsstrained --- src/ptychi/api/options/data.py | 3 +++ src/ptychi/api/task.py | 3 ++- src/ptychi/forward_models.py | 17 ++++++++++++++--- src/ptychi/io_handles.py | 4 +++- src/ptychi/reconstructors/base.py | 1 + src/ptychi/reconstructors/lsqml.py | 4 +++- 6 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/ptychi/api/options/data.py b/src/ptychi/api/options/data.py index 76ff82a0..5d7a49a9 100644 --- a/src/ptychi/api/options/data.py +++ b/src/ptychi/api/options/data.py @@ -37,6 +37,9 @@ class PtychographyDataOptions(base.Options): valid_pixel_mask: Optional[Union[ndarray, Tensor]] = None """A 2D boolean mask where valid pixels are True.""" + leave_all_measurement_zeros_unconstrained: bool = False + """ Treat ALL unmeasured pixels in the diffraction intensity as unconstrained.""" + save_data_on_device: bool = False """Whether to save the diffraction data on acceleration devices like GPU.""" diff --git a/src/ptychi/api/task.py b/src/ptychi/api/task.py index ffe9ebbf..840757c0 100644 --- a/src/ptychi/api/task.py +++ b/src/ptychi/api/task.py @@ -173,7 +173,8 @@ def build_data(self): free_space_propagation_distance_m=self.data_options.free_space_propagation_distance_m, fft_shift=self.data_options.fft_shift, save_data_on_device=save_on_device, - valid_pixel_mask=self.data_options.valid_pixel_mask + valid_pixel_mask=self.data_options.valid_pixel_mask, + leave_all_measurement_zeros_unconstrained=self.data_options.leave_all_measurement_zeros_unconstrained ) def build_object(self): diff --git a/src/ptychi/forward_models.py b/src/ptychi/forward_models.py index 25b5b673..b50f85ac 100644 --- a/src/ptychi/forward_models.py +++ b/src/ptychi/forward_models.py @@ -156,6 +156,7 @@ def __init__( apply_subpixel_shifts_on_probe: bool = True, diffraction_pattern_blur_sigma: Optional[float] = None, low_memory_mode: bool = False, + leave_all_measurement_zeros_unconstrained: bool = False, *args, **kwargs, ) -> None: @@ -230,6 +231,8 @@ def __init__( self.diffraction_pattern_blur_sigma = diffraction_pattern_blur_sigma + self.leave_all_measurement_zeros_unconstrained = leave_all_measurement_zeros_unconstrained + self.check_inputs() def check_inputs(self): @@ -760,12 +763,16 @@ def scale_gradients(self, patterns): class NoiseModel(torch.nn.Module): - def __init__(self, eps=1e-6, valid_pixel_mask: Optional[Tensor] = None) -> None: + def __init__(self, + eps=1e-6, + valid_pixel_mask: Optional[Tensor] = None, + leave_all_measurement_zeros_unconstrained: bool = False) -> None: super().__init__() self.eps = eps self.noise_statistics = None self.valid_pixel_mask = valid_pixel_mask - + self.leave_all_measurement_zeros_unconstrained = leave_all_measurement_zeros_unconstrained + def nll(self, y_pred: Tensor, y_true: Tensor) -> Tensor: """ Calculate the negative log-likelihood. @@ -831,8 +838,12 @@ def backward_to_psi_far(self, y_pred, y_true, psi_far): y_pred, y_true, self.valid_pixel_mask, psi_far.shape[-2:] ) g = 1 - torch.sqrt(y_true) / (torch.sqrt(y_pred) + self.eps) # Eq. 12b - if valid_pixel_mask is not None: + + if self.leave_all_measurement_zeros_unconstrained: + g[y_true==0] = 0 + elif valid_pixel_mask is not None: g[:, torch.logical_not(valid_pixel_mask)] = 0 + w = 1 / (2 * self.sigma) ** 2 g = 2 * w * g[:, None, :, :] * psi_far return g diff --git a/src/ptychi/io_handles.py b/src/ptychi/io_handles.py index f95c17a7..155f9373 100644 --- a/src/ptychi/io_handles.py +++ b/src/ptychi/io_handles.py @@ -27,6 +27,7 @@ def __init__( self, patterns: Union[Tensor, ndarray], valid_pixel_mask: Optional[Union[Tensor, ndarray]] = None, + leave_all_measurement_zeros_unconstrained: bool = False, wavelength_m: float = None, free_space_propagation_distance_m: float = 1.0, fft_shift: bool = True, @@ -52,7 +53,8 @@ def __init__( self.free_space_propagation_distance_m = free_space_propagation_distance_m self.save_data_on_device = save_data_on_device - + self.leave_all_measurement_zeros_unconstrained = leave_all_measurement_zeros_unconstrained + def __getitem__(self, index): if not isinstance(index, torch.Tensor): index = torch.tensor(index, device=self.patterns.device, dtype=torch.long) diff --git a/src/ptychi/reconstructors/base.py b/src/ptychi/reconstructors/base.py index 1ebe05e6..ae448767 100644 --- a/src/ptychi/reconstructors/base.py +++ b/src/ptychi/reconstructors/base.py @@ -657,6 +657,7 @@ def build_forward_model(self): pad_for_shift=self.options.forward_model_options.pad_for_shift, low_memory_mode=self.options.forward_model_options.low_memory_mode, diffraction_pattern_blur_sigma=self.options.forward_model_options.diffraction_pattern_blur_sigma, + leave_all_measurement_zeros_unconstrained=self.dataset.leave_all_measurement_zeros_unconstrained ) def run_post_epoch_hooks(self) -> None: diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index 7183bfdc..894023d7 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -65,7 +65,9 @@ def __init__( "gaussian": fm.PtychographyGaussianNoiseModel, "poisson": fm.PtychographyPoissonNoiseModel, }[options.noise_model]( - **noise_model_params, valid_pixel_mask=self.dataset.valid_pixel_mask.clone() + **noise_model_params, + valid_pixel_mask=self.dataset.valid_pixel_mask.clone(), + leave_all_measurement_zeros_unconstrained=self.dataset.leave_all_measurement_zeros_unconstrained ) self.alpha_psi_far = 0.5 From 017903f1fa8b1805f4b3b17eef54f6ca0e74b5aa Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 13 Mar 2026 12:01:22 -0500 Subject: [PATCH 2/4] REFACTOR: rename added argument to `exclude_measured_pixels_below`, make it a threshold, move it to `ReconstructorOptions`, and remove unnecessary plumbing --- .gitignore | 10 ++++------ src/ptychi/api/options/base.py | 6 ++++++ src/ptychi/api/options/data.py | 3 --- src/ptychi/api/task.py | 1 - src/ptychi/forward_models.py | 28 ++++++++++++++-------------- src/ptychi/io_handles.py | 2 -- src/ptychi/reconstructors/base.py | 1 - src/ptychi/reconstructors/lsqml.py | 2 +- 8 files changed, 25 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index debe62a7..be239dda 100644 --- a/.gitignore +++ b/.gitignore @@ -157,9 +157,7 @@ workspace/ *.hdf5 *.h5 -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +# IDE +.idea/ +.vscode/ +.loglogin \ No newline at end of file diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py index 4c7ed148..d36c7e63 100644 --- a/src/ptychi/api/options/base.py +++ b/src/ptychi/api/options/base.py @@ -1062,6 +1062,12 @@ class ReconstructorOptions(Options): and is not involved in the reconstruction math. """ + exclude_measured_pixels_below: Optional[float] = None + """ + If not None, gradients corresponding to measured diffraction pixels whose intensity + is less than or equal to this value are set to 0 in reconstructors that support it. + """ + forward_model_options: ForwardModelOptions = dataclasses.field( default_factory=ForwardModelOptions ) diff --git a/src/ptychi/api/options/data.py b/src/ptychi/api/options/data.py index 5d7a49a9..76ff82a0 100644 --- a/src/ptychi/api/options/data.py +++ b/src/ptychi/api/options/data.py @@ -37,9 +37,6 @@ class PtychographyDataOptions(base.Options): valid_pixel_mask: Optional[Union[ndarray, Tensor]] = None """A 2D boolean mask where valid pixels are True.""" - leave_all_measurement_zeros_unconstrained: bool = False - """ Treat ALL unmeasured pixels in the diffraction intensity as unconstrained.""" - save_data_on_device: bool = False """Whether to save the diffraction data on acceleration devices like GPU.""" diff --git a/src/ptychi/api/task.py b/src/ptychi/api/task.py index 840757c0..548ead4f 100644 --- a/src/ptychi/api/task.py +++ b/src/ptychi/api/task.py @@ -174,7 +174,6 @@ def build_data(self): fft_shift=self.data_options.fft_shift, save_data_on_device=save_on_device, valid_pixel_mask=self.data_options.valid_pixel_mask, - leave_all_measurement_zeros_unconstrained=self.data_options.leave_all_measurement_zeros_unconstrained ) def build_object(self): diff --git a/src/ptychi/forward_models.py b/src/ptychi/forward_models.py index b50f85ac..2743e472 100644 --- a/src/ptychi/forward_models.py +++ b/src/ptychi/forward_models.py @@ -156,7 +156,6 @@ def __init__( apply_subpixel_shifts_on_probe: bool = True, diffraction_pattern_blur_sigma: Optional[float] = None, low_memory_mode: bool = False, - leave_all_measurement_zeros_unconstrained: bool = False, *args, **kwargs, ) -> None: @@ -230,8 +229,6 @@ def __init__( self.intermediate_variables = self.PlanarPtychographyIntermediateVariables() self.diffraction_pattern_blur_sigma = diffraction_pattern_blur_sigma - - self.leave_all_measurement_zeros_unconstrained = leave_all_measurement_zeros_unconstrained self.check_inputs() @@ -763,15 +760,17 @@ def scale_gradients(self, patterns): class NoiseModel(torch.nn.Module): - def __init__(self, - eps=1e-6, - valid_pixel_mask: Optional[Tensor] = None, - leave_all_measurement_zeros_unconstrained: bool = False) -> None: + def __init__( + self, + eps: float = 1e-6, + valid_pixel_mask: Optional[Tensor] = None, + exclude_measured_pixels_below: Optional[float] = None, + ) -> None: super().__init__() self.eps = eps self.noise_statistics = None self.valid_pixel_mask = valid_pixel_mask - self.leave_all_measurement_zeros_unconstrained = leave_all_measurement_zeros_unconstrained + self.exclude_measured_pixels_below = exclude_measured_pixels_below def nll(self, y_pred: Tensor, y_true: Tensor) -> Tensor: """ @@ -829,8 +828,9 @@ def backward_to_psi_far(self, y_pred, y_true, psi_far): $g = \frac{\partial L}{\partial \psi_{far}}$. When `self.valid_pixel_mask` is not None, pixels of the gradient `g` where the - mask is False are set to 0. When `g` is used to update the far-field wavefield - `psi_far`, the invalid pixels are kept unchanged. + mask is False are set to 0. When `self.exclude_measured_pixels_below` is not + None, gradients at pixels with measured intensities less than or equal to that + threshold are also set to 0. """ # Shape of g: (batch_size, h, w) # Shape of psi_far: (batch_size, n_probe_modes, h, w) @@ -838,11 +838,11 @@ def backward_to_psi_far(self, y_pred, y_true, psi_far): y_pred, y_true, self.valid_pixel_mask, psi_far.shape[-2:] ) g = 1 - torch.sqrt(y_true) / (torch.sqrt(y_pred) + self.eps) # Eq. 12b - - if self.leave_all_measurement_zeros_unconstrained: - g[y_true==0] = 0 - elif valid_pixel_mask is not None: + + if valid_pixel_mask is not None: g[:, torch.logical_not(valid_pixel_mask)] = 0 + if self.exclude_measured_pixels_below is not None: + g[y_true <= self.exclude_measured_pixels_below] = 0 w = 1 / (2 * self.sigma) ** 2 g = 2 * w * g[:, None, :, :] * psi_far diff --git a/src/ptychi/io_handles.py b/src/ptychi/io_handles.py index 155f9373..be0eaaf7 100644 --- a/src/ptychi/io_handles.py +++ b/src/ptychi/io_handles.py @@ -27,7 +27,6 @@ def __init__( self, patterns: Union[Tensor, ndarray], valid_pixel_mask: Optional[Union[Tensor, ndarray]] = None, - leave_all_measurement_zeros_unconstrained: bool = False, wavelength_m: float = None, free_space_propagation_distance_m: float = 1.0, fft_shift: bool = True, @@ -53,7 +52,6 @@ def __init__( self.free_space_propagation_distance_m = free_space_propagation_distance_m self.save_data_on_device = save_data_on_device - self.leave_all_measurement_zeros_unconstrained = leave_all_measurement_zeros_unconstrained def __getitem__(self, index): if not isinstance(index, torch.Tensor): diff --git a/src/ptychi/reconstructors/base.py b/src/ptychi/reconstructors/base.py index ae448767..1ebe05e6 100644 --- a/src/ptychi/reconstructors/base.py +++ b/src/ptychi/reconstructors/base.py @@ -657,7 +657,6 @@ def build_forward_model(self): pad_for_shift=self.options.forward_model_options.pad_for_shift, low_memory_mode=self.options.forward_model_options.low_memory_mode, diffraction_pattern_blur_sigma=self.options.forward_model_options.diffraction_pattern_blur_sigma, - leave_all_measurement_zeros_unconstrained=self.dataset.leave_all_measurement_zeros_unconstrained ) def run_post_epoch_hooks(self) -> None: diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index 894023d7..20c1e5df 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -67,7 +67,7 @@ def __init__( }[options.noise_model]( **noise_model_params, valid_pixel_mask=self.dataset.valid_pixel_mask.clone(), - leave_all_measurement_zeros_unconstrained=self.dataset.leave_all_measurement_zeros_unconstrained + exclude_measured_pixels_below=self.options.exclude_measured_pixels_below, ) self.alpha_psi_far = 0.5 From 1e83127a771ef6c5d90d030749d99c75d56ca2ed Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 13 Mar 2026 12:53:39 -0500 Subject: [PATCH 3/4] FEAT: extend to other reconstructors --- src/ptychi/forward_models.py | 37 ++++++++++++++++---- src/ptychi/reconstructors/ad_ptychography.py | 9 +++-- src/ptychi/reconstructors/base.py | 22 ++++++++++-- src/ptychi/reconstructors/bh.py | 11 ++++-- src/ptychi/reconstructors/dm.py | 13 ++++--- src/ptychi/reconstructors/pie.py | 14 ++++---- 6 files changed, 78 insertions(+), 28 deletions(-) diff --git a/src/ptychi/forward_models.py b/src/ptychi/forward_models.py index 2743e472..43b26a19 100644 --- a/src/ptychi/forward_models.py +++ b/src/ptychi/forward_models.py @@ -781,6 +781,25 @@ def nll(self, y_pred: Tensor, y_true: Tensor) -> Tensor: def backward(self, *args, **kwargs): raise NotImplementedError + @staticmethod + def get_constrained_pixel_mask( + valid_pixel_mask: Optional[Tensor], + exclude_measured_pixels_below: Optional[float], + y_true: Tensor, + ) -> Tensor: + constrained_pixel_mask = torch.ones_like(y_true, dtype=torch.bool) + if valid_pixel_mask is not None: + constrained_pixel_mask = valid_pixel_mask.to(y_true.device) + constrained_pixel_mask = constrained_pixel_mask.unsqueeze(0).expand( + y_true.shape[0], -1, -1 + ) + if exclude_measured_pixels_below is not None: + constrained_pixel_mask = torch.logical_and( + constrained_pixel_mask, + y_true > exclude_measured_pixels_below, + ) + return constrained_pixel_mask + @timer() def conform_to_exit_wave_size( self, @@ -837,12 +856,14 @@ def backward_to_psi_far(self, y_pred, y_true, psi_far): y_pred, y_true, valid_pixel_mask = self.conform_to_exit_wave_size( y_pred, y_true, self.valid_pixel_mask, psi_far.shape[-2:] ) + constrained_pixel_mask = self.get_constrained_pixel_mask( + valid_pixel_mask, + self.exclude_measured_pixels_below, + y_true, + ) g = 1 - torch.sqrt(y_true) / (torch.sqrt(y_pred) + self.eps) # Eq. 12b - if valid_pixel_mask is not None: - g[:, torch.logical_not(valid_pixel_mask)] = 0 - if self.exclude_measured_pixels_below is not None: - g[y_true <= self.exclude_measured_pixels_below] = 0 + g[torch.logical_not(constrained_pixel_mask)] = 0 w = 1 / (2 * self.sigma) ** 2 g = 2 * w * g[:, None, :, :] * psi_far @@ -874,8 +895,12 @@ def backward_to_psi_far(self, y_pred: Tensor, y_true: Tensor, psi_far: Tensor): y_pred, y_true, valid_pixel_mask = self.conform_to_exit_wave_size( y_pred, y_true, self.valid_pixel_mask, psi_far.shape[-2:] ) + constrained_pixel_mask = self.get_constrained_pixel_mask( + valid_pixel_mask, + self.exclude_measured_pixels_below, + y_true, + ) g = 1 - y_true / (y_pred + self.eps) # Eq. 12b - if valid_pixel_mask is not None: - g[:, torch.logical_not(valid_pixel_mask)] = 0 + g[torch.logical_not(constrained_pixel_mask)] = 0 g = g[:, None, :, :] * psi_far return g diff --git a/src/ptychi/reconstructors/ad_ptychography.py b/src/ptychi/reconstructors/ad_ptychography.py index 7b25d33d..25601f8b 100644 --- a/src/ptychi/reconstructors/ad_ptychography.py +++ b/src/ptychi/reconstructors/ad_ptychography.py @@ -137,8 +137,9 @@ def get_retain_graph(self) -> bool: def run_minibatch(self, input_data, y_true, *args, **kwargs): y_pred = self.forward_model(*input_data) + constrained_pixel_mask = self.get_constrained_pixel_mask(y_true) batch_loss = self.loss_function( - y_pred[:, self.dataset.valid_pixel_mask], y_true[:, self.dataset.valid_pixel_mask] + y_pred[constrained_pixel_mask], y_true[constrained_pixel_mask] ) batch_loss.backward(retain_graph=self.get_retain_graph()) @@ -149,5 +150,9 @@ def run_minibatch(self, input_data, y_true, *args, **kwargs): self.forward_model.zero_grad() self.run_post_update_hooks() - self.loss_tracker.update_batch_loss(y_pred=y_pred, y_true=y_true, loss=batch_loss.item()) + self.loss_tracker.update_batch_loss( + y_pred=y_pred[constrained_pixel_mask], + y_true=y_true[constrained_pixel_mask], + loss=batch_loss.item(), + ) self.loss_tracker.update_batch_regularization_loss(reg_loss.item()) diff --git a/src/ptychi/reconstructors/base.py b/src/ptychi/reconstructors/base.py index 1ebe05e6..d79f39fd 100644 --- a/src/ptychi/reconstructors/base.py +++ b/src/ptychi/reconstructors/base.py @@ -419,6 +419,14 @@ def build_dataloader(self): ) return super().build_dataloader(batch_sampler=batch_sampler) + def get_constrained_pixel_mask(self, y_true: Tensor) -> Tensor: + """Get the detector pixels that should constrain the update for a batch.""" + return fm.NoiseModel.get_constrained_pixel_mask( + valid_pixel_mask=self.dataset.valid_pixel_mask, + exclude_measured_pixels_below=self.options.exclude_measured_pixels_below, + y_true=y_true, + ) + def update_preconditioners(self): # Update preconditioner of the object only if: # - the preconditioner does not exist, or @@ -684,9 +692,11 @@ def prepare_data(self, *args, **kwargs): self.parameter_group.probe.normalize_eigenmodes() logger.info("Probe eigenmodes normalized.") - @staticmethod def replace_propagated_exit_wave_magnitude( - psi: Tensor, actual_pattern_intensity: Tensor + self, + psi: Tensor, + actual_pattern_intensity: Tensor, + constrained_pixel_mask: Optional[Tensor] = None, ) -> Tensor: """ Replace the propogated exit wave amplitude. @@ -697,6 +707,9 @@ def replace_propagated_exit_wave_magnitude( Predicted exit wave propagated to the detector plane. actual_pattern_intensity : Tensor The measured diffraction pattern at the detector. + constrained_pixel_mask : Tensor, optional + If given, only pixels where this mask is True are constrained. Other pixels + are left unchanged from `psi`. Returns ------- @@ -705,11 +718,14 @@ def replace_propagated_exit_wave_magnitude( of `actual_pattern_intensity`. """ - return ( + psi_prime = ( psi / ((psi.abs() ** 2).sum(1, keepdims=True).sqrt() + 1e-7) * torch.sqrt(actual_pattern_intensity + 1e-7)[:, None] ) + if constrained_pixel_mask is not None: + return torch.where(constrained_pixel_mask[:, None], psi_prime, psi) + return psi_prime @timer() def adjoint_shift_probe_update_direction(self, indices, delta_p, first_mode_only=False): diff --git a/src/ptychi/reconstructors/bh.py b/src/ptychi/reconstructors/bh.py index 0b0a6889..2ffc9481 100644 --- a/src/ptychi/reconstructors/bh.py +++ b/src/ptychi/reconstructors/bh.py @@ -137,6 +137,7 @@ def compute_updates( psi_far = self.forward_model.intermediate_variables["psi_far"] p = probe.get_opr_mode(0) # to do for multi modes pos = self.positions + self.current_constrained_pixel_mask = self.get_constrained_pixel_mask(y_true)[:, None] # sqrt of data d = torch.sqrt(y_true)[:, torch.newaxis] @@ -294,6 +295,7 @@ def gradientF(self, psi_far, d): td = d * (psi_far / (torch.abs(psi_far) + self.eps)) td = psi_far - td + td = td * self.current_constrained_pixel_mask # Compensate FFT normalization only for the far-field Fourier propagator. if isinstance(self.forward_model.free_space_propagator, fm.FourierPropagator): td *= psi_far.shape[-1] * psi_far.shape[-2] @@ -305,8 +307,13 @@ def hessianF(self, psi_far, psi_far1, psi_far2, data): l0 = psi_far / (torch.abs(psi_far) + self.eps) d0 = data / (torch.abs(psi_far) + self.eps) - v1 = torch.sum((1 - d0) * reprod(psi_far1, psi_far2)) - v2 = torch.sum(d0 * reprod(l0, psi_far1) * reprod(l0, psi_far2)) + v1 = torch.sum(self.current_constrained_pixel_mask * (1 - d0) * reprod(psi_far1, psi_far2)) + v2 = torch.sum( + self.current_constrained_pixel_mask + * d0 + * reprod(l0, psi_far1) + * reprod(l0, psi_far2) + ) return 2 * (v1 + v2) def gradient_o(self, p, gradF): diff --git a/src/ptychi/reconstructors/dm.py b/src/ptychi/reconstructors/dm.py index 947563fc..9948af43 100644 --- a/src/ptychi/reconstructors/dm.py +++ b/src/ptychi/reconstructors/dm.py @@ -98,11 +98,11 @@ def build_dataloader(self): @timer() def run_minibatch(self, input_data, y_true, *args, **kwargs): - dm_error_squared = self.compute_updates(y_true, self.dataset.valid_pixel_mask) + dm_error_squared = self.compute_updates(y_true) self.loss_tracker.update_batch_loss(loss=dm_error_squared.sqrt()) @timer() - def compute_updates(self, y_true: Tensor, valid_pixel_mask: Tensor) -> Tensor: + def compute_updates(self, y_true: Tensor) -> Tensor: """ Compute the updates to the object, probe, and exit wave using the procedure described here: [Probe retrieval in ptychographic coherent diffractive imaging @@ -147,7 +147,7 @@ def compute_updates(self, y_true: Tensor, valid_pixel_mask: Tensor) -> Tensor: dm_error_squared = 0 for i in range(n_chunks): obj_patches, dm_error_squared, new_psi = self.apply_dm_update_to_exit_wave_chunk( - start_pts[i], end_pts[i], y_true, valid_pixel_mask, dm_error_squared + start_pts[i], end_pts[i], y_true, dm_error_squared ) if probe.optimization_enabled(self.current_epoch): self.add_to_probe_update_terms( @@ -223,7 +223,6 @@ def apply_dm_update_to_exit_wave_chunk( start_pt: int, end_pt: int, y_true: Tensor, - valid_pixel_mask: Tensor, dm_error_squared: Tensor, ) -> Tuple[Tensor, Tensor, Tensor]: """ @@ -246,10 +245,10 @@ def apply_dm_update_to_exit_wave_chunk( 2 * new_psi - self.psi[start_pt:end_pt] ) # Replace intensities - revised_psi = torch.where( - valid_pixel_mask.repeat(revised_psi.shape[0], probe.n_modes, 1, 1), - self.replace_propagated_exit_wave_magnitude(revised_psi, y_true[start_pt:end_pt]), + revised_psi = self.replace_propagated_exit_wave_magnitude( revised_psi, + y_true[start_pt:end_pt], + constrained_pixel_mask=self.get_constrained_pixel_mask(y_true[start_pt:end_pt]), ) # Propagate back to sample plane revised_psi = self.forward_model.free_space_propagator.propagate_backward(revised_psi) diff --git a/src/ptychi/reconstructors/pie.py b/src/ptychi/reconstructors/pie.py index ac032acf..11484ca3 100644 --- a/src/ptychi/reconstructors/pie.py +++ b/src/ptychi/reconstructors/pie.py @@ -66,15 +66,13 @@ def check_inputs(self, *args, **kwargs): @timer() def run_minibatch(self, input_data, y_true, *args, **kwargs): self.parameter_group.probe.initialize_grad() - (delta_o, delta_p_i, delta_pos), y_pred = self.compute_updates( - *input_data, y_true, self.dataset.valid_pixel_mask - ) + (delta_o, delta_p_i, delta_pos), y_pred = self.compute_updates(*input_data, y_true) self.apply_updates(delta_o, delta_p_i, delta_pos) self.loss_tracker.update_batch_loss_with_metric_function(y_pred, y_true) @timer() def compute_updates( - self, indices: torch.Tensor, y_true: torch.Tensor, valid_pixel_mask: torch.Tensor + self, indices: torch.Tensor, y_true: torch.Tensor ) -> tuple[torch.Tensor, ...]: """ Calculates the updates of the whole object, the probe, and other parameters. @@ -94,10 +92,10 @@ def compute_updates( psi_far = self.forward_model.intermediate_variables["psi_far"] unique_probes = self.forward_model.intermediate_variables.shifted_unique_probes - psi_prime = self.replace_propagated_exit_wave_magnitude(psi_far, y_true) - # Do not swap magnitude for bad pixels. - psi_prime = torch.where( - valid_pixel_mask.repeat(psi_prime.shape[0], probe.n_modes, 1, 1), psi_prime, psi_far + psi_prime = self.replace_propagated_exit_wave_magnitude( + psi_far, + y_true, + constrained_pixel_mask=self.get_constrained_pixel_mask(y_true), ) psi_prime = self.forward_model.free_space_propagator.propagate_backward(psi_prime) From f56941a9feff204b40fd4b97e71ddababb9d0cbc Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 13 Mar 2026 13:12:45 -0500 Subject: [PATCH 4/4] FIX: fix lint --- src/ptychi/reconstructors/dm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/ptychi/reconstructors/dm.py b/src/ptychi/reconstructors/dm.py index 9948af43..b7a211b3 100644 --- a/src/ptychi/reconstructors/dm.py +++ b/src/ptychi/reconstructors/dm.py @@ -234,8 +234,6 @@ def apply_dm_update_to_exit_wave_chunk( # - revised_psi --> 2 * Pi_o(psi_n)- psi_n # - new_psi --> Pi_o(psi_n) - probe = self.parameter_group.probe - # Get the update exit wave new_psi, obj_patches = self.calculate_exit_wave_chunk( start_pt, end_pt, return_obj_patches=True