23 commits
f3581b8
A first prototype from claude to test on Ra
allevitan Mar 20, 2026
44fe36d
Kill a bunch of bugs and make the flow more sensible by hand
allevitan Mar 20, 2026
dd744a0
Improve plotting infrastructure and fix various bugs
allevitan Mar 20, 2026
f426083
Add title option to plot_image and wrappers; misc fixes
allevitan Mar 20, 2026
e7f943e
Made a few more updates to some example scripts, to show the panel pl…
allevitan Mar 20, 2026
91db5b3
Update remaining models to new plot registration system
allevitan Mar 20, 2026
f35141d
Check that all the models work (all but Multislice2DPtycho, which was…
allevitan Mar 20, 2026
bef907c
Added a minimum plotting interval to model.inspect(dataset) to make i…
allevitan Mar 20, 2026
dafb204
Fix a bug where closed windows wouldn't reopen at the original size
allevitan Mar 20, 2026
f8ff525
Add example patterns for jupyter notebooks
allevitan Mar 21, 2026
a4a3853
Update the plot_image functions to show sliders
allevitan Mar 21, 2026
f04ac9f
Make the colorized plot look nicer, with a more perceptually uniform …
allevitan Mar 21, 2026
29656fd
Improve the colorized plotting further, and add a colorbar which will…
allevitan Mar 21, 2026
f979388
Fix an annoying warning coming from double-setting the figsize
allevitan Mar 21, 2026
4201b41
Made more adjustments to the plotting system to avoid using constaine…
allevitan Mar 24, 2026
5c5c0be
Reformat plotting function signatures to one-parameter-per-line style
allevitan Mar 24, 2026
25c7b66
Do a final review of all the examples to ensure they work and are wel…
allevitan Mar 24, 2026
29e9fdd
Add test coverage for plot_translations, plot_nanomap, and plot_nanom…
allevitan Mar 24, 2026
ad2b09f
Update the tests to work better when checking the model plotting, and…
allevitan Mar 24, 2026
603486e
Change the default colormap for exponentiated objects to match that f…
allevitan Mar 24, 2026
f55b35f
Stop stopping at each plot to show it
allevitan Mar 24, 2026
4172a9a
A few small changes to revert unimportant edits and fix linting issues
allevitan Mar 24, 2026
0bd8df4
Updated the documentation to reflect the changes to the plotting system
allevitan Mar 24, 2026
2 changes: 1 addition & 1 deletion .flake8
@@ -1,2 +1,2 @@
[flake8]
-ignore = E501, W503
+ignore = E501, W503, E731
3 changes: 2 additions & 1 deletion .gitignore
@@ -10,4 +10,5 @@ build/*
dist
*/example_data/*
*.h5
-.DS_Store
+.DS_Store
+.ipynb_checkpoints
30 changes: 24 additions & 6 deletions docs/source/examples.rst
@@ -31,9 +31,9 @@ When reading this script, note the basic workflow. After the data is loaded, a m

Next, the model is moved to the GPU using the :code:`model.to` function. Any device understood by :code:`torch.Tensor.to` can be specified here. The next line is a bit more subtle - the dataset is told to move patterns to the GPU before passing them to the model using the :code:`dataset.get_as` function. This function does not move the stored patterns to the GPU. If there is sufficient GPU memory, the patterns can also be pre-moved to the GPU using :code:`dataset.to`, but the speedup is empirically quite small.

-Once the device is selected, a reconstruction is run using :code:`model.Adam_optimize`. This is a generator function which will yield at the end of every epoch, to allow some monitoring code to be run.
+Once the device is selected, a reconstruction is run using :code:`model.Adam_optimize`. This is a generator function which will yield at the end of every epoch, to allow some monitoring code to be run. Inside the loop, :code:`model.inspect(dataset)` is called every epoch to live-update a set of plots showing the current state of the model parameters.

-Finally, the results can be studied using :code:`model.inspect(dataset)`, which creates or updates a set of plots showing the current state of the model parameters. :code:`model.compare(dataset)` is also called, which shows how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset.
+Finally, :code:`model.compare(dataset)` is called to show how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset.
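The generator pattern described above, in which the optimizer yields once per epoch so that monitoring code can run between epochs, can be sketched in plain Python. This is an illustrative toy, not the actual cdtools implementation:

```python
def adam_optimize(epochs, step):
    """Toy stand-in for a generator-style optimizer: each pass through
    the loop runs one 'epoch' of updates, then yields the epoch loss so
    the caller can run monitoring code before the next epoch starts."""
    loss = 100.0
    for epoch in range(epochs):
        loss *= step   # pretend each epoch shrinks the loss
        yield loss     # hand control back to the caller once per epoch

losses = []
for loss in adam_optimize(epochs=5, step=0.5):
    # Monitoring code (printing a report, updating plots) runs here,
    # between epochs, exactly as in the reconstruction scripts.
    losses.append(loss)
```

Because the generator suspends at each `yield`, the caller decides how much (or how little) monitoring to do per epoch without the optimizer needing to know about it.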


Fancy Ptycho
@@ -63,11 +63,13 @@ By default, FancyPtycho will also optimize over the following model parameters,

These corrections can be turned off (on) by calling :code:`model.<parameter>.requires_grad = False #(True)`.

-Note as well two other changes that are made in this script, when compared to `simple_ptycho.py`. First, a `Reconstructor` object is explicitly created, in this case an `AdamReconstructor`. This object stores a model, dataset, and pytorch optimizer. It is then used to orchestrate the later reconstruction using a call to `Reconstructor.optimize()`.
+Note as well two other changes that are made in this script, when compared to :code:`simple_ptycho.py`. First, a :code:`Reconstructor` object is explicitly created, in this case an :code:`AdamReconstructor`. This object stores a model, dataset, and pytorch optimizer. It is then used to orchestrate the later reconstruction using a call to :code:`Reconstructor.optimize()`.

-We use this pattern, instead of the simpler call to `model.Adam_optimize()`, because having the reconstructor store the optimizer as well as the model and dataset allows the moment estimates to persist between multiple rounds of optimization. This leads to the second change: In this script, we run two optimization loops. The first loop aggressively refines the probe, with a low minibatch size and a high learning rate. The second loop has a smaller learning rate and a larger batch size, which allow for a more precise final estimation of the object.
+We use this pattern, instead of the simpler call to :code:`model.Adam_optimize()`, because having the reconstructor store the optimizer as well as the model and dataset allows the moment estimates to persist between multiple rounds of optimization. This leads to the second change: In this script, we run two optimization loops. The first loop aggressively refines the probe, with a low minibatch size and a high learning rate. The second loop has a smaller learning rate and a larger batch size, which allow for a more precise final estimation of the object.

-In this case, we used one reconstructor, but it is possible to create additional reconstructors to zero out all the persistant information in the optimizer, if desired, or even to instantiate multiple reconstructors on the same model with different optimization algorithms (e.g. `model.LBFGS_optimize()`).
+In this case, we used one reconstructor, but it is possible to create additional reconstructors to zero out all the persistent information in the optimizer, if desired, or even to instantiate multiple reconstructors on the same model with different optimization algorithms (e.g. :code:`model.LBFGS_optimize()`).

Note also the use of :code:`min_interval=10` in the calls to :code:`model.inspect(dataset)`. Because generating plots can be expensive, passing a minimum interval (in seconds) prevents excessive replots. Finally, the call to :code:`model.inspect(dataset, replot_all=True)` at the end of the script reopens any plot windows that the user may have closed during the reconstruction, so that all results are visible at the end.
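The throttling behavior of a minimum plotting interval can be sketched with a simple time-based guard. This is purely illustrative; the real cdtools logic may differ:

```python
import time

class ThrottledPlotter:
    """Illustrative sketch of a min_interval-style guard: skip a replot
    if the last one happened less than min_interval seconds ago."""

    def __init__(self):
        self.last_plot = None

    def inspect(self, min_interval=0, now=None):
        # 'now' can be injected for testing; defaults to the real clock
        now = time.monotonic() if now is None else now
        if self.last_plot is not None and now - self.last_plot < min_interval:
            return False  # too soon since the last replot; skip
        self.last_plot = now
        return True       # (re)draw the plots

plotter = ThrottledPlotter()
first = plotter.inspect(min_interval=10, now=0.0)   # first call always plots
second = plotter.inspect(min_interval=10, now=5.0)  # only 5 s elapsed: skipped
third = plotter.inspect(min_interval=10, now=12.0)  # 12 s since last plot: drawn
```

Calling such a guard every epoch is cheap; the expensive replot only fires when the interval has elapsed.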


Gold Ball Ptycho
@@ -77,11 +79,27 @@ This script shows how the FancyPtycho model might be used in a realistic situati

.. literalinclude:: ../../examples/gold_ball_ptycho.py

-Note, in particular, the use of :code:`model.save_on_exception` and :code:`model.save_to_h5` to save the results of the reconstruction. If a different file format is required, :code:`model.save_results` will save to a pure-python dictionary.
+Note first the explicit addition of the :code:`plot_level=2` argument in the call to :code:`FancyPtycho.from_dataset`. This value controls which plots are generated. With :code:`plot_level=1`, only the main results are shown - :code:`plot_level=2` shows some more advanced monitoring of the error correction terms (background, position error, etc.), and :code:`plot_level=3` shows all registered plots.
+
+Note also the use of :code:`model.save_on_exception` and :code:`model.save_to_h5` to save the results of the reconstruction. If a different file format is required, :code:`model.save_results` will save to a pure-python dictionary which can be processed further.

Finally, note that there are several small adjustments made to the script to counteract particular sources of error that are present in this dataset, for example the raster grid pathology caused by the scan pattern used. Also note that not every mixin is needed every time - in this case, we turn off optimization of the :code:`weights` parameter.
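The idea behind save-on-exception can be sketched as a context manager. This is purely illustrative, and cdtools' actual :code:`save_on_exception` API may well differ:

```python
from contextlib import contextmanager

@contextmanager
def save_on_exception(save_func, filename):
    """Illustrative sketch: if the wrapped block raises, save the current
    state before re-raising, so a long reconstruction that crashes partway
    through is not lost."""
    try:
        yield
    except Exception:
        save_func(filename)  # snapshot partial results
        raise                # then let the error propagate normally

saved = []  # stand-in for "files written to disk"
try:
    with save_on_exception(saved.append, 'partial_results.h5'):
        raise RuntimeError('simulated crash mid-reconstruction')
except RuntimeError:
    pass
```

The snapshot runs before the exception propagates, so the caller still sees the original error but the partial results survive.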


Near-Field Ptycho
-----------------

This script shows how the FancyPtycho model can be used on a typical near-field ptychography (also known as Fresnel ptychography) dataset.

.. literalinclude:: ../../examples/near_field_ptycho.py

The major change here is the setting of the :code:`near_field=True` argument to :code:`FancyPtycho.from_dataset`. This changes the propagator to a near-field propagator. As noted in the comments, if :code:`propagation_distance` is not set, the model will assume a standard near-field geometry with flat illumination.

If :code:`propagation_distance` is set, it will assume a Fresnel scaling theorem-type geometry, with :code:`propagation_distance` as the focus-to-sample distance, and the distance set in the dataset object as the sample-to-detector distance.
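Under the Fresnel scaling theorem, a cone-beam geometry with focus-to-sample distance z1 and sample-to-detector distance z2 is equivalent to a plane-wave geometry with magnified pixels and a shorter effective propagation distance. A quick sketch of the standard relations (not code from cdtools):

```python
def fresnel_scaling(z1, z2):
    """Map a cone-beam geometry (z1 = focus-to-sample distance,
    z2 = sample-to-detector distance) onto an equivalent plane-wave
    geometry, per the Fresnel scaling theorem."""
    magnification = (z1 + z2) / z1
    z_effective = z2 / magnification  # equals z1 * z2 / (z1 + z2)
    return magnification, z_effective

# e.g. 5 mm focus-to-sample, 100 mm sample-to-detector
M, z_eff = fresnel_scaling(z1=5e-3, z2=0.1)
# M = 21, and the effective propagation distance is much shorter than z2
```

This is why the divergent-beam case can be simulated with the same near-field propagator, just with rescaled geometry parameters.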

Finally, note the addition of the :code:`panel_plot_mode=True` argument. This is the default mode, and returns the plots in a panel format, good for easily monitoring the progress of a reconstruction. If individual plots are needed for use in presentations, papers, or otherwise, setting :code:`panel_plot_mode=False` will plot each output in its own window.


Gold Ball Split
---------------

40 changes: 27 additions & 13 deletions docs/source/tutorial.rst
@@ -367,28 +367,42 @@ The forward propagator maps the exit wave to the wave at the surface of the dete
Plotting
++++++++

-The base CDIModel class has a function, :code:`model.inspect()`, which looks for a class variable called :code:`plot_list` and plots everything contained within. The plot list should be formatted as a list of tuples, with each tuple containing:
+The base CDIModel class has a function, :code:`model.inspect()`, which looks for a class variable called :code:`plot_list` and plots everything contained within. The plot list should be formatted as a list of dictionaries, with each dictionary containing:

-* The title of the plot
-* A function that takes in the model and generates the relevant plot
-* Optional, a function that takes in the model and returns whether or not the plot should be generated
+* :code:`'title'`: the title of the plot
+* :code:`'plot_func'`: a function that takes in the model (and optionally a figure) and generates the relevant plot
+* :code:`'condition'` (optional): a function that takes in the model and returns whether or not the plot should be generated

.. code-block:: python

# This lists all the plots to display on a call to model.inspect()
plot_list = [
-    ('Probe Amplitude',
-     lambda self, fig: p.plot_amplitude(self.probe, fig=fig, basis=self.probe_basis)),
-    ('Probe Phase',
-     lambda self, fig: p.plot_phase(self.probe, fig=fig, basis=self.probe_basis)),
-    ('Object Amplitude',
-     lambda self, fig: p.plot_amplitude(self.obj, fig=fig, basis=self.probe_basis)),
-    ('Object Phase',
-     lambda self, fig: p.plot_phase(self.obj, fig=fig, basis=self.probe_basis))
+    {
+        'title': 'Probe Amplitude',
+        'plot_func': lambda self, fig:
+            p.plot_amplitude(self.probe, fig=fig, basis=self.probe_basis),
+    },
+    {
+        'title': 'Probe Phase',
+        'plot_func': lambda self, fig:
+            p.plot_phase(self.probe, fig=fig, basis=self.probe_basis),
+    },
+    {
+        'title': 'Object Amplitude',
+        'plot_func': lambda self, fig:
+            p.plot_amplitude(self.obj, fig=fig, basis=self.probe_basis),
+    },
+    {
+        'title': 'Object Phase',
+        'plot_func': lambda self, fig:
+            p.plot_phase(self.obj, fig=fig, basis=self.probe_basis),
+    },
]

In this case, we've made use of the convenience plotting functions defined in :code:`tools.plotting`.

More advanced models like :code:`FancyPtycho` also define a :code:`plot_panel_list`, which groups related plots together into multi-subplot figures. The :code:`panel_plot_mode` argument (passed at construction time) controls whether these panels are rendered as combined multi-subplot figures or as individual windows. For a simple model like :code:`SimplePtycho`, :code:`plot_list` is sufficient.
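The dispatch that :code:`model.inspect()` implies can be sketched generically (hypothetical helper names; the real cdtools internals may differ): iterate over the list, skip entries whose :code:`'condition'` returns False, and call each :code:`'plot_func'`.

```python
def run_plot_list(model, plot_list):
    """Illustrative dispatcher for the dict-based plot_list format:
    returns the titles of the plots that were actually generated."""
    drawn = []
    for entry in plot_list:
        condition = entry.get('condition', lambda model: True)
        if not condition(model):
            continue  # this plot doesn't apply to the current model
        entry['plot_func'](model, fig=None)  # a real version would pass a figure
        drawn.append(entry['title'])
    return drawn

class FakeModel:
    has_background = False  # hypothetical attribute, for illustration only

plot_list = [
    {'title': 'Probe Amplitude',
     'plot_func': lambda self, fig: None},
    {'title': 'Background',
     'plot_func': lambda self, fig: None,
     'condition': lambda self: self.has_background},
]

titles = run_plot_list(FakeModel(), plot_list)
# Only 'Probe Amplitude' is drawn; the 'Background' condition returned False
```

The optional :code:`'condition'` key is what lets one :code:`plot_list` serve models with different mixins enabled.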


Saving
++++++
19 changes: 10 additions & 9 deletions examples/fancy_ptycho.py
@@ -1,4 +1,5 @@
import cdtools
import torch as t
from matplotlib import pyplot as plt

filename = 'example_data/lab_ptycho_data.cxi'
@@ -15,9 +16,9 @@
obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix
)

-device = 'cuda'
-model.to(device=device)
-dataset.get_as(device=device)
+if t.cuda.is_available():
+    model.to(device='cuda')
+    dataset.get_as(device='cuda')

# For this script, we use a slightly different pattern where we explicitly
# create a `Reconstructor` class to orchestrate the reconstruction. The
@@ -31,22 +32,22 @@
# The batch size sets the minibatch size
for loss in recon.optimize(50, lr=0.02, batch_size=10):
print(model.report())
-    # Plotting is expensive, so we only do it every tenth epoch
-    if model.epoch % 10 == 0:
-        model.inspect(dataset)
+    # Because plotting can be expensive, setting a minimum plotting interval
+    # (in seconds) can avoid excessive replots.
+    model.inspect(dataset, min_interval=10)

# It's common to chain several different reconstruction loops. Here, we
# started with an aggressive refinement to find the probe in the previous
# loop, and now we polish the reconstruction with a lower learning rate
# and larger minibatch
for loss in recon.optimize(50, lr=0.005, batch_size=50):
print(model.report())
-    if model.epoch % 10 == 0:
-        model.inspect(dataset)
+    model.inspect(dataset, min_interval=10)

# This orthogonalizes the recovered probe modes
model.tidy_probes()

-model.inspect(dataset)
+# Setting replot_all will reopen any windows which were closed earlier
+model.inspect(dataset, replot_all=True)
model.compare(dataset)
plt.show()
151 changes: 151 additions & 0 deletions examples/fancy_ptycho_inline.ipynb
@@ -0,0 +1,151 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "286054ce",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import cdtools\n",
"import torch as t\n",
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "955bb242-e2ed-47c3-919c-1ea690681445",
"metadata": {},
"outputs": [],
"source": [
"# Load and inspect a dataset\n",
"\n",
"filename = 'example_data/lab_ptycho_data.cxi'\n",
"dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)\n",
"\n",
"dataset.inspect();"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82f71fb4-f013-46bb-817b-1973ba23336a",
"metadata": {},
"outputs": [],
"source": [
"# Initialize a model from the dataset and move it to the GPU.\n",
"\n",
"model = cdtools.models.FancyPtycho.from_dataset(\n",
" dataset,\n",
" n_modes=3, # Use 3 incoherently mixing probe modes\n",
" oversampling=2, # Simulate the probe on a 2x larger real-space array\n",
" probe_support_radius=120, # Force the probe to 0 outside a radius of 120 pix\n",
" propagation_distance=5e-3, # Propagate the initial probe guess by 5 mm\n",
" units='mm', # Set the units for the live plots\n",
" obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix,\n",
")\n",
"\n",
"if t.cuda.is_available():\n",
" model.to(device='cuda')\n",
" dataset.get_as(device='cuda')\n",
"\n",
"# Then, create a reconstructor object and view the initialized model\n",
"\n",
"recon = cdtools.reconstructors.AdamReconstructor(model, dataset)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e22a65e-280b-428d-a6f5-f3206d22110b",
"metadata": {},
"outputs": [],
"source": [
"# Workaround reconstruction pattern for interactive plotting in jupyter:\n",
"# First, a standalone cell to plot the current model state\n",
"\n",
"model.inspect(dataset, replot_all=True);"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81a768b5",
"metadata": {},
"outputs": [],
"source": [
"# Second, a cell for running the reconstruction. With this pattern, it is safe\n",
"# to interrupt the kernel. Then, the cell above can be re-run to refresh the plots.\n",
"while model.epoch < 50:\n",
" for loss in recon.optimize(1, lr=0.02, batch_size=10):\n",
" print(model.report())\n",
"\n",
"while model.epoch < 100:\n",
" for loss in recon.optimize(1, lr=0.005, batch_size=10):\n",
" print(model.report())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "237e9286-b6cf-41dc-aeb0-89abfe59b37c",
"metadata": {},
"outputs": [],
"source": [
"# Save out the results\n",
"\n",
"model.save_to_h5('lab_ptycho_reconstruction.h5', dataset)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f1565b6",
"metadata": {},
"outputs": [],
"source": [
"# Finalize the plotting and create the comparison plot\n",
"\n",
"# This orthogonalizes the recovered probe modes. It is best to do so\n",
"# after saving the results, if you intend to initialize any further\n",
"# reconstructions with the probe.\n",
"model.tidy_probes()\n",
"\n",
"# Final plotting\n",
"model.inspect(dataset)\n",
"model.compare(dataset);"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63ec3aa4-0d1c-4775-9fb2-3002d404faa4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}