Skip to content
Open
15 changes: 9 additions & 6 deletions docs/source/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,11 @@ There is no requirement for what the arguments to the initialization function of
self.probe = t.nn.Parameter(probe_guess / self.probe_norm)
self.obj = t.nn.Parameter(obj_guess)


# We register a loss function and an appropriate normalization
self.loss = tools.losses.amplitude_mse
self.loss_normalizer = tools.losses.AmplitudeMSENormalizer()


The first thing to notice about this model is that all the fixed, geometric information is stored with the :code:`module.register_buffer()` function. This is what makes it possible to move all the relevant tensors between devices using a single call to :code:`module.to()`, for example. It stores thetensor as an object attribute, but it also registers it so that pytorch is aware that this attribute helps to encode the state of the model.

The supporting information we need is the wavelength of the illumination, the basis of the probe array in real space, and an offset to define the zero point of the translation.
Expand All @@ -268,6 +272,8 @@ The Adam optimizer is designed so that the learning rate sets the maximum stepsi

This is important to remember when adding additional error models. Rescaling all the parameters to have a typical amplitude near 1 is the best way to get well-behaved reconstructions.

The final two lines assign a loss function and its associated normalizer. The loss function is stored as an instance attribute rather than defined as a method, which allows it to be swapped out at construction time. The normalizer is a stateful object that accumulates statistics over the first epoch and uses them to convert the raw summed loss into a normalized mean value. Here we use :code:`amplitude_mse` and its paired :code:`AmplitudeMSENormalizer`, which computes the mean squared error between the square roots of the simulated and measured intensities.


Initialization from Dataset
+++++++++++++++++++++++++++
Expand Down Expand Up @@ -347,7 +353,7 @@ Here, we take input in the form of an index and a translation. Note that this in

We start by mapping the translation, given in real space, into pixel coordinates. Then, we use an "off-the-shelf" interaction model - :code:`ptycho_2d_round`, which models a standard 2D ptychography interaction, but rounds the translations to the nearest whole pixel (does not attempt subpixel translations).

The next three definitions amount to just choosing an off-the-shelf function to simulate each step in the chain.
The next two definitions amount to just choosing an off-the-shelf function to simulate each step in the chain.

.. code-block:: python

Expand All @@ -357,11 +363,8 @@ The next three definitions amount to just choosing an off-the-shelf function to
def measurement(self, wavefields):
return tools.measurements.intensity(wavefields)

def loss(self, sim_data, real_data):
return tools.losses.amplitude_mse(real_data, sim_data)


The forward propagator maps the exit wave to the wave at the surface of the detector, here using a far-field propagator. The measurement maps that exit wave to a measured pixel value, and the loss defines a loss function to attempt to minimize. The loss function we've chosen - the amplitude mean squared error - is the most reliable one, and can also easily be overridden by an end user.
The forward propagator maps the exit wave to the wave at the surface of the detector, here using a far-field propagator. The measurement maps that wavefield to a measured pixel value. The loss function was already assigned in :code:`__init__` as described above.


Plotting
Expand Down
7 changes: 2 additions & 5 deletions examples/near_field_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)

dataset.inspect()
plt.show()

# Setting near_field equal to True uses an angular spectrum propagator in
# lieu of the default Fourier-transform propagator for far-field ptychography.
Expand All @@ -26,7 +25,8 @@
near_field=True,
propagation_distance=3.65e-3, # 3.65 downstream from focus
units='um', # Set the units for the live plots
obj_view_crop=-35,
obj_view_crop=-35, # Expand the view for the live plots
loss="poisson_nll", # Best option for photon-counting detectors
)

device = 'cuda'
Expand All @@ -48,9 +48,6 @@
if model.epoch % 10 == 0:
model.inspect(dataset)

# This orthogonalizes the recovered probe modes
model.tidy_probes()

model.inspect(dataset)
model.compare(dataset)
plt.show()
7 changes: 4 additions & 3 deletions examples/tutorial_simple_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def __init__(
self.probe = t.nn.Parameter(probe_guess / self.probe_norm)
self.obj = t.nn.Parameter(obj_guess)

# We register a loss function and an appropriate normalization
self.loss = tools.losses.amplitude_mse
self.loss_normalizer = tools.losses.AmplitudeMSENormalizer()


@classmethod
def from_dataset(cls, dataset):
Expand Down Expand Up @@ -102,9 +106,6 @@ def forward_propagator(self, wavefields):
def measurement(self, wavefields):
return tools.measurements.intensity(wavefields)

def loss(self, real_data, sim_data):
return tools.losses.amplitude_mse(real_data, sim_data)


# This lists all the plots to display on a call to model.inspect()
plot_list = [
Expand Down
24 changes: 19 additions & 5 deletions src/cdtools/models/bragg_2d_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
propagate_probe=True,
correct_tilt=True,
lens=False,
loss='amplitude mse',
units='um',
dtype=t.float32,
obj_view_crop=0,
Expand Down Expand Up @@ -235,7 +236,22 @@ def __init__(
# TODO: probably doesn't support non-float-32 dtypes
self.register_buffer('universal_propagator',
universal_propagator)


# Here we set the appropriate loss function
if (loss.lower().strip() == 'amplitude mse'
or loss.lower().strip() == 'amplitude_mse'):
self.loss = tools.losses.amplitude_mse
self.loss_normalizer = tools.losses.AmplitudeMSENormalizer()
elif (loss.lower().strip() == 'poisson nll'
or loss.lower().strip() == 'poisson_nll'):
self.loss = tools.losses.poisson_nll
self.loss_normalizer = tools.losses.SimplePoissonNLLNormalizer()
elif (loss.lower().strip() == 'intensity mse'
or loss.lower().strip() == 'intensity_mse'):
self.loss = tools.losses.intensity_mse
self.loss_normalizer = tools.losses.IntensityMSENormalizer()
else:
raise KeyError('Specified loss function not supported')


@classmethod
Expand All @@ -255,6 +271,7 @@ def from_dataset(
propagate_probe=True,
correct_tilt=True,
lens=False,
loss='amplitude mse',
obj_padding=200,
obj_view_crop=None,
units='um',
Expand Down Expand Up @@ -446,6 +463,7 @@ def from_dataset(
propagate_probe=propagate_probe,
correct_tilt=correct_tilt,
lens=lens,
loss=loss,
obj_view_crop=obj_view_crop,
units=units,
)
Expand Down Expand Up @@ -528,10 +546,6 @@ def measurement(self, wavefields):
)


def loss(self, sim_data, real_data, mask=None):
return tools.losses.amplitude_mse(real_data, sim_data, mask=mask)


def sim_to_dataset(self, args_list):
# In the future, potentially add more control
# over what metadata is saved (names, etc.)
Expand Down
6 changes: 6 additions & 0 deletions src/cdtools/models/fancy_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,15 @@ def __init__(self,
if (loss.lower().strip() == 'amplitude mse'
or loss.lower().strip() == 'amplitude_mse'):
self.loss = tools.losses.amplitude_mse
self.loss_normalizer = tools.losses.AmplitudeMSENormalizer()
elif (loss.lower().strip() == 'poisson nll'
or loss.lower().strip() == 'poisson_nll'):
self.loss = tools.losses.poisson_nll
self.loss_normalizer = tools.losses.SimplePoissonNLLNormalizer()
elif (loss.lower().strip() == 'intensity mse'
or loss.lower().strip() == 'intensity_mse'):
self.loss = tools.losses.intensity_mse
self.loss_normalizer = tools.losses.IntensityMSENormalizer()
else:
raise KeyError('Specified loss function not supported')

Expand Down
27 changes: 20 additions & 7 deletions src/cdtools/models/multislice_2d_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self,
fourier_probe=False,
prevent_aliasing=True,
phase_only=False,
loss='amplitude mse',
units='um',
):

Expand Down Expand Up @@ -152,9 +153,25 @@ def __init__(self,

self.as_prop = tools.propagators.generate_angular_spectrum_propagator(shape, spacing, self.wavelength, self.dz, self.bandlimit)

# Here we set the appropriate loss function
if (loss.lower().strip() == 'amplitude mse'
or loss.lower().strip() == 'amplitude_mse'):
self.loss = tools.losses.amplitude_mse
self.loss_normalizer = tools.losses.AmplitudeMSENormalizer()
elif (loss.lower().strip() == 'poisson nll'
or loss.lower().strip() == 'poisson_nll'):
self.loss = tools.losses.poisson_nll
self.loss_normalizer = tools.losses.SimplePoissonNLLNormalizer()
elif (loss.lower().strip() == 'intensity mse'
or loss.lower().strip() == 'intensity_mse'):
self.loss = tools.losses.intensity_mse
self.loss_normalizer = tools.losses.IntensityMSENormalizer()
else:
raise KeyError('Specified loss function not supported')


@classmethod
def from_dataset(cls, dataset, dz, nz, probe_convergence_semiangle, padding=0, n_modes=1, dm_rank=None, translation_scale=1, saturation=None, propagation_distance=None, scattering_mode=None, oversampling=1, auto_center=True, bandlimit=None, replicate_slice=False, subpixel=True, exponentiate_obj=True, units='um', fourier_probe=False, phase_only=False, prevent_aliasing=True, probe_support_radius=None):
def from_dataset(cls, dataset, dz, nz, probe_convergence_semiangle, padding=0, n_modes=1, dm_rank=None, translation_scale=1, saturation=None, propagation_distance=None, scattering_mode=None, oversampling=1, auto_center=True, bandlimit=None, replicate_slice=False, subpixel=True, exponentiate_obj=True, units='um', fourier_probe=False, phase_only=False, prevent_aliasing=True, probe_support_radius=None, loss='amplitude mse'):

wavelength = dataset.wavelength
det_basis = dataset.detector_geometry['basis']
Expand Down Expand Up @@ -296,7 +313,8 @@ def from_dataset(cls, dataset, dz, nz, probe_convergence_semiangle, padding=0, n
exponentiate_obj=exponentiate_obj,
units=units, fourier_probe=fourier_probe,
phase_only=phase_only,
prevent_aliasing=prevent_aliasing)
prevent_aliasing=prevent_aliasing,
loss=loss)


def interaction(self, index, translations):
Expand Down Expand Up @@ -410,11 +428,6 @@ def measurement(self, wavefields):
oversampling=self.oversampling)


def loss(self, sim_data, real_data, mask=None):
return tools.losses.amplitude_mse(real_data, sim_data, mask=mask)
#return tools.losses.poisson_nll(real_data, sim_data, mask=mask,eps=0.5)


def to(self, *args, **kwargs):
super(Multislice2DPtycho, self).to(*args, **kwargs)
self.wavelength = self.wavelength.to(*args,**kwargs)
Expand Down
8 changes: 7 additions & 1 deletion src/cdtools/models/multislice_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,20 @@ def __init__(self,

self.register_buffer('simulate_finite_pixels',
t.as_tensor(simulate_finite_pixels, dtype=bool))

# Here we set the appropriate loss function
if (loss.lower().strip() == 'amplitude mse'
or loss.lower().strip() == 'amplitude_mse'):
self.loss = tools.losses.amplitude_mse
self.loss_normalizer = tools.losses.AmplitudeMSENormalizer()
elif (loss.lower().strip() == 'poisson nll'
or loss.lower().strip() == 'poisson_nll'):
self.loss = tools.losses.poisson_nll
self.loss_normalizer = tools.losses.SimplePoissonNLLNormalizer()
elif (loss.lower().strip() == 'intensity mse'
or loss.lower().strip() == 'intensity_mse'):
self.loss = tools.losses.intensity_mse
self.loss_normalizer = tools.losses.IntensityMSENormalizer()
else:
raise KeyError('Specified loss function not supported')

Expand Down
25 changes: 23 additions & 2 deletions src/cdtools/models/rpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
exponentiate_obj=False,
phase_only=False,
propagation_distance=0,
loss='amplitude mse',
units='um',
dtype=t.float32,
):
Expand Down Expand Up @@ -145,6 +146,22 @@ def __init__(
self.register_buffer('prop_dir',
t.as_tensor([0, 0, 1], dtype=dtype))

# Here we set the appropriate loss function
if (loss.lower().strip() == 'amplitude mse'
or loss.lower().strip() == 'amplitude_mse'):
self.loss = tools.losses.amplitude_mse
self.loss_normalizer = tools.losses.AmplitudeMSENormalizer()
elif (loss.lower().strip() == 'poisson nll'
or loss.lower().strip() == 'poisson_nll'):
self.loss = tools.losses.poisson_nll
self.loss_normalizer = tools.losses.SimplePoissonNLLNormalizer()
elif (loss.lower().strip() == 'intensity mse'
or loss.lower().strip() == 'intensity_mse'):
self.loss = tools.losses.intensity_mse
self.loss_normalizer = tools.losses.IntensityMSENormalizer()
else:
raise KeyError('Specified loss function not supported')


@classmethod
def from_dataset(
Expand All @@ -163,6 +180,7 @@ def from_dataset(
exponentiate_obj=False,
phase_only=False,
probe_threshold=0,
loss='amplitude mse',
dtype=t.float32,
):
complex_dtype = (t.ones([1], dtype=dtype) +
Expand Down Expand Up @@ -247,14 +265,15 @@ def from_dataset(
obj_support = t.as_tensor(binary_dilation(obj_support))

rpi_object = cls(wavelength, det_geo, ew_basis,
probe, dummy_init_obj,
probe, dummy_init_obj,
background=background, mask=mask,
saturation=saturation,
obj_support=obj_support,
oversampling=oversampling,
exponentiate_obj=exponentiate_obj,
phase_only=phase_only,
weight_matrix=weight_matrix)
weight_matrix=weight_matrix,
loss=loss)

# I don't love this pattern, where I do the "real" obj initialization
# after creating the rpi object. But, I chose this so that I could
Expand Down Expand Up @@ -283,6 +302,7 @@ def from_calibration(
exponentiate_obj=False,
phase_only=False,
initialization='random',
loss='amplitude mse',
dtype=t.float32
):

Expand Down Expand Up @@ -327,6 +347,7 @@ def from_calibration(
mask=mask,
exponentiate_obj=exponentiate_obj,
phase_only=phase_only,
loss=loss,
)

rpi_object.init_obj(initialization)
Expand Down
8 changes: 5 additions & 3 deletions src/cdtools/models/simple_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def __init__(
self.probe = t.nn.Parameter(probe_guess / self.probe_norm)
self.obj = t.nn.Parameter(obj_guess)

# We register a loss function and an appropriate normalization
self.loss = tools.losses.amplitude_mse
self.loss_normalizer = tools.losses.AmplitudeMSENormalizer()


@classmethod
def from_dataset(cls, dataset):
Expand Down Expand Up @@ -99,12 +103,10 @@ def interaction(self, index, translations):
def forward_propagator(self, wavefields):
return tools.propagators.far_field(wavefields)


def measurement(self, wavefields):
return tools.measurements.intensity(wavefields)

def loss(self, real_data, sim_data):
return tools.losses.amplitude_mse(real_data, sim_data)


# This lists all the plots to display on a call to model.inspect()
plot_list = [
Expand Down
10 changes: 7 additions & 3 deletions src/cdtools/reconstructors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def run_epoch(self,
# The data loader is responsible for setting the minibatch
# size, so each set is a minibatch
for inputs, patterns in self.data_loader:
normalization += t.sum(patterns).cpu().numpy()
if hasattr(self.model, 'loss_normalizer') and \
self.model.loss_normalizer is not None:
self.model.loss_normalizer.accumulate(patterns)
N += 1

def closure():
Expand Down Expand Up @@ -217,7 +219,9 @@ def closure():
# This takes the step for this minibatch
loss += self.optimizer.step(closure).detach().cpu().numpy()

loss /= normalization
if hasattr(self.model, 'loss_normalizer') and \
self.model.loss_normalizer is not None:
loss = self.model.loss_normalizer.normalize_loss(loss)

# We step the scheduler after the full epoch
if self.scheduler is not None:
Expand All @@ -232,7 +236,7 @@ def closure():
def optimize(self,
iterations: int,
batch_size: int = 1,
custom_data_loader: torch.utils.data.DataLoader = None,
custom_data_loader: t.utils.data.DataLoader = None,
regularization_factor: Union[float, List[float]] = None,
thread: bool = True,
calculation_width: int = 10,
Expand Down
Loading
Loading