diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index 1533e838..535d075e 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -251,7 +251,11 @@ There is no requirement for what the arguments to the initialization function of self.probe = t.nn.Parameter(probe_guess / self.probe_norm) self.obj = t.nn.Parameter(obj_guess) - + # We register a loss function and an appropriate normalization + self.loss = tools.losses.amplitude_mse + self.loss_normalizer = tools.losses.AmplitudeMSENormalizer() + + The first thing to notice about this model is that all the fixed, geometric information is stored with the :code:`module.register_buffer()` function. This is what makes it possible to move all the relevant tensors between devices using a single call to :code:`module.to()`, for example. It stores thetensor as an object attribute, but it also registers it so that pytorch is aware that this attribute helps to encode the state of the model. The supporting information we need is the wavelength of the illumination, the basis of the probe array in real space, and an offset to define the zero point of the translation. @@ -268,6 +272,8 @@ The Adam optimizer is designed so that the learning rate sets the maximum stepsi This is important to remember when adding additional error models. Rescaling all the parameters to have a typical amplitude near 1 is the best way to get well-behaved reconstructions. +The final two lines assign a loss function and its associated normalizer. The loss function is stored as an instance attribute rather than defined as a method, which allows it to be swapped out at construction time. The normalizer is a stateful object that accumulates statistics over the first epoch and uses them to convert the raw summed loss into a normalized mean value. Here we use :code:`amplitude_mse` and its paired :code:`AmplitudeMSENormalizer`, which computes the mean squared error between the square roots of the simulated and measured intensities. + Initialization from Dataset +++++++++++++++++++++++++++ @@ -347,7 +353,7 @@ Here, we take input in the form of an index and a translation. Note that this in We start by mapping the translation, given in real space, into pixel coordinates. Then, we use an "off-the-shelf" interaction model - :code:`ptycho_2d_round`, which models a standard 2D ptychography interaction, but rounds the translations to the nearest whole pixel (does not attempt subpixel translations). -The next three definitions amount to just choosing an off-the-shelf function to simulate each step in the chain. +The next two definitions amount to just choosing an off-the-shelf function to simulate each step in the chain. .. code-block:: python @@ -357,11 +363,8 @@ The next three definitions amount to just choosing an off-the-shelf function to def measurement(self, wavefields): return tools.measurements.intensity(wavefields) - def loss(self, sim_data, real_data): - return tools.losses.amplitude_mse(real_data, sim_data) - -The forward propagator maps the exit wave to the wave at the surface of the detector, here using a far-field propagator. The measurement maps that exit wave to a measured pixel value, and the loss defines a loss function to attempt to minimize. The loss function we've chosen - the amplitude mean squared error - is the most reliable one, and can also easily be overridden by an end user. +The forward propagator maps the exit wave to the wave at the surface of the detector, here using a far-field propagator. The measurement maps that wavefield to a measured pixel value. The loss function was already assigned in :code:`__init__` as described above. Plotting diff --git a/examples/near_field_ptycho.py b/examples/near_field_ptycho.py index c91076f8..e87e3c37 100644 --- a/examples/near_field_ptycho.py +++ b/examples/near_field_ptycho.py @@ -5,7 +5,6 @@ dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename) dataset.inspect() -plt.show() # Setting near_field equal to True uses an angular spectrum propagator in # lieu of the default Fourier-transform propagator for far-field ptychography. @@ -26,7 +25,8 @@ near_field=True, propagation_distance=3.65e-3, # 3.65 downstream from focus units='um', # Set the units for the live plots - obj_view_crop=-35, + obj_view_crop=-35, # Expand the view for the live plots + loss="poisson_nll", # Best option for photon-counting detectors ) device = 'cuda' @@ -48,9 +48,6 @@ if model.epoch % 10 == 0: model.inspect(dataset) -# This orthogonalizes the recovered probe modes -model.tidy_probes() - model.inspect(dataset) model.compare(dataset) plt.show() diff --git a/examples/tutorial_simple_ptycho.py b/examples/tutorial_simple_ptycho.py index 7e8a3e5d..dbce73c2 100644 --- a/examples/tutorial_simple_ptycho.py +++ b/examples/tutorial_simple_ptycho.py @@ -41,6 +41,10 @@ def __init__( self.probe = t.nn.Parameter(probe_guess / self.probe_norm) self.obj = t.nn.Parameter(obj_guess) + # We register a loss function and an appropriate normalization + self.loss = tools.losses.amplitude_mse + self.loss_normalizer = tools.losses.AmplitudeMSENormalizer() + @classmethod def from_dataset(cls, dataset): @@ -102,9 +106,6 @@ def forward_propagator(self, wavefields): def measurement(self, wavefields): return tools.measurements.intensity(wavefields) - def loss(self, real_data, sim_data): - return tools.losses.amplitude_mse(real_data, sim_data) - # This lists all the plots to display on a call to model.inspect() plot_list = [ diff --git a/src/cdtools/models/bragg_2d_ptycho.py b/src/cdtools/models/bragg_2d_ptycho.py index 507323d0..83fc6c88 100644 --- a/src/cdtools/models/bragg_2d_ptycho.py +++ b/src/cdtools/models/bragg_2d_ptycho.py @@ -77,6 +77,7 @@ def __init__( propagate_probe=True, correct_tilt=True, lens=False, + loss='amplitude mse', units='um', dtype=t.float32, obj_view_crop=0, @@ -235,7 +236,22 @@ def __init__( # TODO: probably doesn't support non-float-32 dtypes self.register_buffer('universal_propagator', universal_propagator) - + + # Here we set the appropriate loss function + if (loss.lower().strip() == 'amplitude mse' + or loss.lower().strip() == 'amplitude_mse'): + self.loss = tools.losses.amplitude_mse + self.loss_normalizer = tools.losses.AmplitudeMSENormalizer() + elif (loss.lower().strip() == 'poisson nll' + or loss.lower().strip() == 'poisson_nll'): + self.loss = tools.losses.poisson_nll + self.loss_normalizer = tools.losses.SimplePoissonNLLNormalizer() + elif (loss.lower().strip() == 'intensity mse' + or loss.lower().strip() == 'intensity_mse'): + self.loss = tools.losses.intensity_mse + self.loss_normalizer = tools.losses.IntensityMSENormalizer() + else: + raise KeyError('Specified loss function not supported') @classmethod @@ -255,6 +271,7 @@ def from_dataset( propagate_probe=True, correct_tilt=True, lens=False, + loss='amplitude mse', obj_padding=200, obj_view_crop=None, units='um', @@ -446,6 +463,7 @@ def from_dataset( propagate_probe=propagate_probe, correct_tilt=correct_tilt, lens=lens, + loss=loss, obj_view_crop=obj_view_crop, units=units, ) @@ -528,10 +546,6 @@ def measurement(self, wavefields): ) - def loss(self, sim_data, real_data, mask=None): - return tools.losses.amplitude_mse(real_data, sim_data, mask=mask) - - def sim_to_dataset(self, args_list): # In the future, potentially add more control # over what metadata is saved (names, etc.) diff --git a/src/cdtools/models/fancy_ptycho.py b/src/cdtools/models/fancy_ptycho.py index 0b3d4997..662f601f 100644 --- a/src/cdtools/models/fancy_ptycho.py +++ b/src/cdtools/models/fancy_ptycho.py @@ -219,9 +219,15 @@ def __init__(self, if (loss.lower().strip() == 'amplitude mse' or loss.lower().strip() == 'amplitude_mse'): self.loss = tools.losses.amplitude_mse + self.loss_normalizer = tools.losses.AmplitudeMSENormalizer() elif (loss.lower().strip() == 'poisson nll' or loss.lower().strip() == 'poisson_nll'): self.loss = tools.losses.poisson_nll + self.loss_normalizer = tools.losses.SimplePoissonNLLNormalizer() + elif (loss.lower().strip() == 'intensity mse' + or loss.lower().strip() == 'intensity_mse'): + self.loss = tools.losses.intensity_mse + self.loss_normalizer = tools.losses.IntensityMSENormalizer() else: raise KeyError('Specified loss function not supported') diff --git a/src/cdtools/models/multislice_2d_ptycho.py b/src/cdtools/models/multislice_2d_ptycho.py index 1d6b3cd2..122ce34f 100644 --- a/src/cdtools/models/multislice_2d_ptycho.py +++ b/src/cdtools/models/multislice_2d_ptycho.py @@ -46,6 +46,7 @@ def __init__(self, fourier_probe=False, prevent_aliasing=True, phase_only=False, + loss='amplitude mse', units='um', ): @@ -152,9 +153,25 @@ def __init__(self, self.as_prop = tools.propagators.generate_angular_spectrum_propagator(shape, spacing, self.wavelength, self.dz, self.bandlimit) + # Here we set the appropriate loss function + if (loss.lower().strip() == 'amplitude mse' + or loss.lower().strip() == 'amplitude_mse'): + self.loss = tools.losses.amplitude_mse + self.loss_normalizer = tools.losses.AmplitudeMSENormalizer() + elif (loss.lower().strip() == 'poisson nll' + or loss.lower().strip() == 'poisson_nll'): + self.loss = tools.losses.poisson_nll + self.loss_normalizer = tools.losses.SimplePoissonNLLNormalizer() + elif (loss.lower().strip() == 'intensity mse' + or loss.lower().strip() == 'intensity_mse'): + self.loss = tools.losses.intensity_mse + self.loss_normalizer = tools.losses.IntensityMSENormalizer() + else: + raise KeyError('Specified loss function not supported') + @classmethod - def from_dataset(cls, dataset, dz, nz, probe_convergence_semiangle, padding=0, n_modes=1, dm_rank=None, translation_scale=1, saturation=None, propagation_distance=None, scattering_mode=None, oversampling=1, auto_center=True, bandlimit=None, replicate_slice=False, subpixel=True, exponentiate_obj=True, units='um', fourier_probe=False, phase_only=False, prevent_aliasing=True, probe_support_radius=None): + def from_dataset(cls, dataset, dz, nz, probe_convergence_semiangle, padding=0, n_modes=1, dm_rank=None, translation_scale=1, saturation=None, propagation_distance=None, scattering_mode=None, oversampling=1, auto_center=True, bandlimit=None, replicate_slice=False, subpixel=True, exponentiate_obj=True, units='um', fourier_probe=False, phase_only=False, prevent_aliasing=True, probe_support_radius=None, loss='amplitude mse'): wavelength = dataset.wavelength det_basis = dataset.detector_geometry['basis'] @@ -296,7 +313,8 @@ def from_dataset(cls, dataset, dz, nz, probe_convergence_semiangle, padding=0, n exponentiate_obj=exponentiate_obj, units=units, fourier_probe=fourier_probe, phase_only=phase_only, - prevent_aliasing=prevent_aliasing) + prevent_aliasing=prevent_aliasing, + loss=loss) def interaction(self, index, translations): @@ -410,11 +428,6 @@ def measurement(self, wavefields): oversampling=self.oversampling) - def loss(self, sim_data, real_data, mask=None): - return tools.losses.amplitude_mse(real_data, sim_data, mask=mask) - #return tools.losses.poisson_nll(real_data, sim_data, mask=mask,eps=0.5) - - def to(self, *args, **kwargs): super(Multislice2DPtycho, self).to(*args, **kwargs) self.wavelength = self.wavelength.to(*args,**kwargs) diff --git a/src/cdtools/models/multislice_ptycho.py b/src/cdtools/models/multislice_ptycho.py index afcd83e2..81ea0a41 100644 --- a/src/cdtools/models/multislice_ptycho.py +++ b/src/cdtools/models/multislice_ptycho.py @@ -164,14 +164,20 @@ def __init__(self, self.register_buffer('simulate_finite_pixels', t.as_tensor(simulate_finite_pixels, dtype=bool)) - + # Here we set the appropriate loss function if (loss.lower().strip() == 'amplitude mse' or loss.lower().strip() == 'amplitude_mse'): self.loss = tools.losses.amplitude_mse + self.loss_normalizer = tools.losses.AmplitudeMSENormalizer() elif (loss.lower().strip() == 'poisson nll' or loss.lower().strip() == 'poisson_nll'): self.loss = tools.losses.poisson_nll + self.loss_normalizer = tools.losses.SimplePoissonNLLNormalizer() + elif (loss.lower().strip() == 'intensity mse' + or loss.lower().strip() == 'intensity_mse'): + self.loss = tools.losses.intensity_mse + self.loss_normalizer = tools.losses.IntensityMSENormalizer() else: raise KeyError('Specified loss function not supported') diff --git a/src/cdtools/models/rpi.py b/src/cdtools/models/rpi.py index 339b6236..1ecc0f1e 100644 --- a/src/cdtools/models/rpi.py +++ b/src/cdtools/models/rpi.py @@ -56,6 +56,7 @@ def __init__( exponentiate_obj=False, phase_only=False, propagation_distance=0, + loss='amplitude mse', units='um', dtype=t.float32, ): @@ -145,6 +146,22 @@ def __init__( self.register_buffer('prop_dir', t.as_tensor([0, 0, 1], dtype=dtype)) + # Here we set the appropriate loss function + if (loss.lower().strip() == 'amplitude mse' + or loss.lower().strip() == 'amplitude_mse'): + self.loss = tools.losses.amplitude_mse + self.loss_normalizer = tools.losses.AmplitudeMSENormalizer() + elif (loss.lower().strip() == 'poisson nll' + or loss.lower().strip() == 'poisson_nll'): + self.loss = tools.losses.poisson_nll + self.loss_normalizer = tools.losses.SimplePoissonNLLNormalizer() + elif (loss.lower().strip() == 'intensity mse' + or loss.lower().strip() == 'intensity_mse'): + self.loss = tools.losses.intensity_mse + self.loss_normalizer = tools.losses.IntensityMSENormalizer() + else: + raise KeyError('Specified loss function not supported') + @classmethod def from_dataset( @@ -163,6 +180,7 @@ def from_dataset( exponentiate_obj=False, phase_only=False, probe_threshold=0, + loss='amplitude mse', dtype=t.float32, ): complex_dtype = (t.ones([1], dtype=dtype) + @@ -247,14 +265,15 @@ def from_dataset( obj_support = t.as_tensor(binary_dilation(obj_support)) rpi_object = cls(wavelength, det_geo, ew_basis, - probe, dummy_init_obj, + probe, dummy_init_obj, background=background, mask=mask, saturation=saturation, obj_support=obj_support, oversampling=oversampling, exponentiate_obj=exponentiate_obj, phase_only=phase_only, - weight_matrix=weight_matrix) + weight_matrix=weight_matrix, + loss=loss) # I don't love this pattern, where I do the "real" obj initialization # after creating the rpi object. But, I chose this so that I could @@ -283,6 +302,7 @@ def from_calibration( exponentiate_obj=False, phase_only=False, initialization='random', + loss='amplitude mse', dtype=t.float32 ): @@ -327,6 +347,7 @@ def from_calibration( mask=mask, exponentiate_obj=exponentiate_obj, phase_only=phase_only, + loss=loss, ) rpi_object.init_obj(initialization) diff --git a/src/cdtools/models/simple_ptycho.py b/src/cdtools/models/simple_ptycho.py index 5435c662..9bda5e65 100644 --- a/src/cdtools/models/simple_ptycho.py +++ b/src/cdtools/models/simple_ptycho.py @@ -41,6 +41,10 @@ def __init__( self.probe = t.nn.Parameter(probe_guess / self.probe_norm) self.obj = t.nn.Parameter(obj_guess) + # We register a loss function and an appropriate normalization + self.loss = tools.losses.amplitude_mse + self.loss_normalizer = tools.losses.AmplitudeMSENormalizer() + @classmethod def from_dataset(cls, dataset): @@ -99,12 +103,10 @@ def interaction(self, index, translations): def forward_propagator(self, wavefields): return tools.propagators.far_field(wavefields) + def measurement(self, wavefields): return tools.measurements.intensity(wavefields) - def loss(self, real_data, sim_data): - return tools.losses.amplitude_mse(real_data, sim_data) - # This lists all the plots to display on a call to model.inspect() plot_list = [ diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py index 16668149..4f94d8b8 100644 --- a/src/cdtools/reconstructors/base.py +++ b/src/cdtools/reconstructors/base.py @@ -161,7 +161,9 @@ def run_epoch(self, # The data loader is responsible for setting the minibatch # size, so each set is a minibatch for inputs, patterns in self.data_loader: - normalization += t.sum(patterns).cpu().numpy() + if hasattr(self.model, 'loss_normalizer') and \ + self.model.loss_normalizer is not None: + self.model.loss_normalizer.accumulate(patterns) N += 1 def closure(): @@ -217,7 +219,9 @@ def closure(): # This takes the step for this minibatch loss += self.optimizer.step(closure).detach().cpu().numpy() - loss /= normalization + if hasattr(self.model, 'loss_normalizer') and \ + self.model.loss_normalizer is not None: + loss = self.model.loss_normalizer.normalize_loss(loss) # We step the scheduler after the full epoch if self.scheduler is not None: @@ -232,7 +236,7 @@ def closure(): def optimize(self, iterations: int, batch_size: int = 1, - custom_data_loader: torch.utils.data.DataLoader = None, + custom_data_loader: t.utils.data.DataLoader = None, regularization_factor: Union[float, List[float]] = None, thread: bool = True, calculation_width: int = 10, diff --git a/src/cdtools/tools/losses/losses.py b/src/cdtools/tools/losses/losses.py index 44e5c5d3..5b304fc9 100644 --- a/src/cdtools/tools/losses/losses.py +++ b/src/cdtools/tools/losses/losses.py @@ -8,7 +8,14 @@ import torch as t -__all__ = ['amplitude_mse', 'intensity_mse', 'poisson_nll'] +__all__ = [ + 'amplitude_mse', + 'AmplitudeMSENormalizer', + 'intensity_mse', + 'IntensityMSENormalizer', + 'poisson_nll', + 'SimplePoissonNLLNormalizer', +] def amplitude_mse(intensities, sim_intensities, mask=None): @@ -17,20 +24,23 @@ def amplitude_mse(intensities, sim_intensities, mask=None): Calculates the mean squared error between a given set of measured diffraction intensities and a simulated set. + This function calculates the mean squared error between their associated amplitudes. Because this is not well defined for negative numbers, make sure that all the intensities are >0 before using this - loss. Note that this is actually a sum-squared error, because this - formulation makes it vastly simpler to compare error calculations - between reconstructions with different minibatch size. I hope to - find a better way to do this that is more honest with this - cost function, though. + loss. + + Note that this is actually, by defauly, a sum-squared error. In this + case, it is intended to be used with the loss normalization strategy + in the base CDIModel class, which works well if the minibatch size + is not fixed. It can accept intensity and simulated intensity tensors of any shape as long as their shapes match, and the provided mask array can be broadcast correctly along them. - This is empirically the most useful loss function for most cases + This is empirically the most useful loss function for most cases where + a photon counting detector cannot be used. Parameters ---------- @@ -40,6 +50,8 @@ def amplitude_mse(intensities, sim_intensities, mask=None): A tensor of simulated detector intensities mask : torch.Tensor A mask with ones for pixels to include and zeros for pixels to exclude + use_sum : bool + Default is True. If set to True, actually performs the sum squared error Returns ------- @@ -60,6 +72,48 @@ def amplitude_mse(intensities, sim_intensities, mask=None): t.sqrt(masked_intensities))**2) +class AmplitudeMSENormalizer(object): + """ Normalizer for the amplitude MSE loss, used with recon.optimize + + This is a normalizer designed for use with the recon.optimize function. The + normalization is done separately from the loss, in order to make it simple to + use different normalization strategies for different loss metrics and to make it + easier to work with different minibatch sizes. + + This normalizer accumulates the total number of pixels across all patterns + during the first epoch, then divides the summed loss by this count to + convert from sum-squared error to mean-squared error. + + The normalizer is stateful: it completes its accumulation phase on the + first epoch and then applies the same normalization factor for all + subsequent epochs. + + Methods + ------- + accumulate(patterns, mask=None) + Accumulate the normalization factor (called once per minibatch). + normalize_loss(loss) + Apply the accumulated normalization (called once per epoch). + + """ + + def __init__(self): + self.first_pass_complete = False + self.num_pix = 0 + + def accumulate(self, patterns, mask=None): + if not self.first_pass_complete: + if mask is None: + self.num_pix += patterns.numel() + else: + self.num_pix += patterns.masked_select(mask).numel() + + def normalize_loss(self, loss): + if not self.first_pass_complete: + self.first_pass_complete = True + + return loss / self.num_pix + def intensity_mse(intensities, sim_intensities, mask=None): """ Returns the mean squared error of a simulated dataset's intensities @@ -72,6 +126,9 @@ def intensity_mse(intensities, sim_intensities, mask=None): It can accept intensity and simulated intensity tensors of any shape as long as their shapes match, and the provided mask array can be broadcast correctly along them. + + This is rarely a good loss function for ptychography, but can occasionally + be useful. Parameters ---------- @@ -98,6 +155,72 @@ def intensity_mse(intensities, sim_intensities, mask=None): / masked_intensities.shape[0] +class IntensityMSENormalizer(object): + """ Normalizer for the intensity MSE loss, used with recon.optimize + + This is a normalizer designed for use with the recon.optimize function. The + normalization is done separately from the loss, in order to make it simple to + use different normalization strategies for different loss metrics and to make it + easier to work with different minibatch sizes. + + This normalizer accumulates the total number of pixels across all patterns + during the first epoch, then divides the summed loss by this count to + convert from sum-squared error to mean-squared error. + + The normalizer is stateful: it completes its accumulation phase on the + first epoch and then applies the same normalization factor for all + subsequent epochs. + + Methods + ------- + accumulate(patterns, mask=None) + Accumulate the normalization factor (called once per minibatch). + normalize_loss(loss) + Apply the accumulated normalization (called once per epoch). + + """ + + def __init__(self): + self.first_pass_complete = False + self.num_pix = 0 + + def accumulate(self, patterns, mask=None): + """Accumulate pixel counts from a batch of patterns. + + Parameters + ---------- + patterns : torch.Tensor + A tensor of measured detector patterns + mask : torch.Tensor, optional + A mask with ones for pixels to include and zeros for pixels to + exclude. If provided, only masked pixels are counted. + + """ + if not self.first_pass_complete: + if mask is None: + self.num_pix += patterns.numel() + else: + self.num_pix += patterns.masked_select(mask).numel() + + def normalize_loss(self, loss): + """Convert summed loss to mean loss by dividing by pixel count. + + Parameters + ---------- + loss : torch.Tensor + The accumulated summed loss across minibatches in an epoch + + Returns + ------- + normalized_loss : torch.Tensor + The loss divided by the total number of pixels + + """ + if not self.first_pass_complete: + self.first_pass_complete = True + + return loss / self.num_pix + def poisson_nll( intensities, @@ -124,6 +247,9 @@ def poisson_nll( The default value of eps is 1e-6 - a nonzero value here helps avoid divergence of the log function near zero. + + This is generally the best loss metric to use for ptychography when + a photon counting detector is used. Parameters ---------- @@ -135,6 +261,8 @@ def poisson_nll( A mask with ones for pixels to include and zeros for pixels to exclude eps : float Optional, a small number to add to the simulated intensities + subtract_min : bool + Default is False, whether to subtract a min to produce a nonnegative output Returns ------- @@ -144,8 +272,7 @@ def poisson_nll( """ if mask is None: nll = t.sum(sim_intensities+eps - - t.xlogy(intensities,sim_intensities+eps)) \ - / intensities.view(-1).shape[0] + t.xlogy(intensities,sim_intensities+eps)) if subtract_min: nll -= t.sum(intensities - t.xlogy(intensities,intensities)) @@ -155,17 +282,108 @@ def poisson_nll( masked_sims = sim_intensities.masked_select(mask) nll = t.sum(masked_sims + eps - \ - t.xlogy(masked_intensities, masked_sims+eps)) \ - / masked_intensities.shape[0] + t.xlogy(masked_intensities, masked_sims+eps)) if subtract_min: nll -= t.nansum(masked_intensities - \ - t.xlogy(masked_intensities, masked_intensities)) \ - / masked_intensities.shape[0] + t.xlogy(masked_intensities, masked_intensities)) return nll +class SimplePoissonNLLNormalizer(object): + """ Normalizer for the intensity MSE loss, used with recon.optimize + + This is a normalizer designed for use with the recon.optimize function. The + normalization is done separately from the loss, in order to make it simple to + use different normalization strategies for different loss metrics and to make it + easier to work with different minibatch sizes. + + This normalizer converts raw Poisson negative log likelihood values into + a statistic that is more interpretable for comparing reconstructions. It + performs two operations: + + 1. **Offset subtraction**: Subtracts the NLL calculated when comparing + measured patterns to themselves (i.e., poisson_nll(data, data)). This + represents the best-case scenario and makes the loss non-negative. + + 2. **Normalization scaling**: Divides by 0.5 times the count of non-zero + pixels in the measured patterns. This is because, roughly, each non-zero + pixel is expected to contribute to the Poisson NLL, if Poisson noise were + the only relevant source of noise in the data. + + The normalizer is stateful: it completes its accumulation phase on the + first epoch by processing all patterns in the data, then applies the + same normalization factors for all subsequent epochs. + + Methods + ------- + accumulate(patterns, mask=None) + Accumulate the normalization factor (called once per minibatch). + normalize_loss(loss) + Apply the accumulated normalization (called once per epoch). + + """ + + def __init__(self): + self.first_pass_complete = False + self.sum_nonzero = 0 + self.offset = 0 + + def accumulate(self, patterns, mask=None): + """Accumulate statistics needed for normalization from a batch. + + During the first epoch, this method counts non-zero pixels and + computes the Poisson NLL comparing patterns to themselves, which + defines the offset baseline for the loss. + + Parameters + ---------- + patterns : torch.Tensor + A tensor of measured detector patterns + mask : torch.Tensor, optional + A mask with ones for pixels to include and zeros for pixels to + exclude. If provided, only masked pixels are counted. + + """ + if not self.first_pass_complete: + if mask is None: + self.sum_nonzero += t.sum(patterns >= 1) + self.offset += poisson_nll(patterns, patterns) + else: + masked_pats = patterns.masked_select(mask) + self.sum_nonzero += t.sum(masked_pats >= 1) + self.offset += poisson_nll(masked_pats, masked_pats) + + + def normalize_loss(self, loss): + """Normalize the Poisson NLL for interpretability across datasets. + + Parameters + ---------- + loss : torch.Tensor + The accumulated Poisson NLL across minibatches in an epoch + + Returns + ------- + normalized_loss : torch.Tensor + The offset-corrected and scaled loss value + + """ + if not self.first_pass_complete: + self.normalization = 0.5 * self.sum_nonzero + self.first_pass_complete = True + + return (loss - self.offset) / self.normalization + + +# +# Note: I have two other ideas for how to normalize the Poisson NLL +# +# Idea 2: Use the mean pattern to estimate the expected error +# Idea 3: Use the simulated intensities to estimate it, but use detach +# so it doesn't hit the backward pass +# def poisson_plus_fixed_nll( intensities, diff --git a/tests/models/test_fancy_ptycho.py b/tests/models/test_fancy_ptycho.py index db05a75e..df8dc51f 100644 --- a/tests/models/test_fancy_ptycho.py +++ b/tests/models/test_fancy_ptycho.py @@ -96,7 +96,7 @@ def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): model.compare(dataset) # If this fails, the reconstruction has gotten worse - assert model.loss_history[-1] < 0.0013 + assert model.loss_history[-1] < 0.38 @pytest.mark.slow @@ -110,6 +110,7 @@ def test_near_field_ptycho(near_field_ptycho_cxi, reconstruction_device, show_pl n_modes=1, near_field=True, propagation_distance=3.65e-3, # 3.65 downstream from focus + loss='poisson_nll', ) print('Running reconstruction on provided reconstruction_device,', @@ -134,4 +135,4 @@ def test_near_field_ptycho(near_field_ptycho_cxi, reconstruction_device, show_pl model.compare(dataset) # If this fails, the reconstruction has gotten worse - assert model.loss_history[-1] < 0.005 + assert model.loss_history[-1] < 17 diff --git a/tests/models/test_simple_ptycho.py b/tests/models/test_simple_ptycho.py index b6b18680..b0061fc2 100644 --- a/tests/models/test_simple_ptycho.py +++ b/tests/models/test_simple_ptycho.py @@ -26,4 +26,4 @@ def test_simple_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): model.compare(dataset) # If this fails, the reconstruction got worse - assert model.loss_history[-1] < 0.013 + assert model.loss_history[-1] < 6.5 diff --git a/tests/test_reconstructors.py b/tests/test_reconstructors.py index 132a8c27..e1038788 100644 --- a/tests/test_reconstructors.py +++ b/tests/test_reconstructors.py @@ -3,7 +3,6 @@ import torch as t import numpy as np import pickle -from matplotlib import pyplot as plt from copy import deepcopy @@ -36,7 +35,7 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): probe_support_radius=50, propagation_distance=2e-6, units='um', - probe_fourier_crop=pad + probe_fourier_crop=pad, ) model.translation_offsets.data += 0.7 * \ @@ -116,7 +115,40 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): # comes from running a reconstruction when it was working well and # choosing a rough value. If it triggers this assertion error, something # changed to make the final quality worse! - assert model_recon.loss_history[-1] < 0.0001 + assert model_recon.loss_history[-1] < 0.09 + + +@pytest.mark.slow +def test_intensity_MSE(gold_ball_cxi, reconstruction_device, show_plot): + """ + This test checks that the intensity_mse loss function works end-to-end + with the AdamReconstructor, using the Au particle dataset. + """ + + print('\nTesting performance on the standard gold balls dataset ' + + 'with intensity_mse loss') + + dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(gold_ball_cxi) + model = cdtools.models.FancyPtycho.from_dataset( + dataset, + n_modes=1, + propagation_distance=-3e-6, + units='nm', + loss='intensity_mse', + ) + + model.to(device=reconstruction_device) + dataset.get_as(device=reconstruction_device) + + recon = cdtools.reconstructors.AdamReconstructor(model=model, + dataset=dataset) + t.manual_seed(0) + + for loss in recon.optimize(5, lr=.05, batch_size=10): + print(model.report()) + + # Threshold to be updated after running on a GPU machine + assert model.loss_history[-1] < 91 @pytest.mark.slow @@ -207,7 +239,7 @@ def test_LBFGS_RPI(optical_data_ss_cxi, # The final loss when testing this was 2.28607e-3. Based on this, we set # a threshold of 2.3e-3 for the tested loss. If this value has been # exceeded, the reconstructions have gotten worse. - assert model_recon.loss_history[-1] < 0.0023 + assert model_recon.loss_history[-1] < 0.14 @pytest.mark.slow @@ -316,4 +348,4 @@ def test_SGD_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): # The final loss when testing this was 7.12188e-4. Based on this, we set # a threshold of 7.2e-4 for the tested loss. If this value has been # exceeded, the reconstructions have gotten worse. - assert model.loss_history[-1] < 0.00072 + assert model.loss_history[-1] < 0.65 diff --git a/tests/tools/test_losses.py b/tests/tools/test_losses.py index b36b85df..35e98c83 100644 --- a/tests/tools/test_losses.py +++ b/tests/tools/test_losses.py @@ -1,5 +1,6 @@ import numpy as np import torch as t +from scipy.special import xlogy from cdtools.tools import losses @@ -52,22 +53,24 @@ def test_intensity_mse(): def test_poisson_nll(): - # Make some fake data - data = np.random.rand(10, 100, 100) - # And add some noise to it + # Make some fake data spread over a realistic photon-count range, + # with ~5% of pixels set to zero + data = 10 * np.random.rand(10, 100, 100) + data[np.random.rand(10, 100, 100) < 0.05] = 0 + # Add some noise, but set ~5% of sim pixels to exactly match data sim = data + 0.1 * np.random.rand(10, 100, 100) + exact_match = np.random.rand(10, 100, 100) < 0.05 + sim[exact_match] = data[exact_match] # and define a simple mask that needs to be broadcast mask = (np.random.rand(100, 100) > 0.1).astype(bool) # First, test without a mask - np_result = np.sum(sim - data * np.log(sim)) - np_result /= data.size + np_result = np.sum(sim - xlogy(data, sim)) torch_result = losses.poisson_nll(t.from_numpy(data), t.from_numpy(sim), eps=0) assert np.isclose(np_result, np.take(torch_result.numpy(), 0)) # Then, test with a mask - np_result = np.sum(mask * (sim - data * np.log(sim))) - np_result /= np.count_nonzero(mask * np.ones_like(data)) + np_result = np.sum(mask * (sim - xlogy(data, sim))) torch_result = losses.poisson_nll(t.from_numpy(data), t.from_numpy(sim), mask=t.from_numpy(mask), eps=0) assert np.isclose(np_result, np.take(torch_result.numpy(), 0))