diff --git a/.flake8 b/.flake8 index 9214cb9e..3e08f47c 100644 --- a/.flake8 +++ b/.flake8 @@ -1,2 +1,2 @@ [flake8] -ignore = E501, W503 \ No newline at end of file +ignore = E501, W503, E731 \ No newline at end of file diff --git a/.gitignore b/.gitignore index ea4c46c9..ec5e6feb 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ build/* dist */example_data/* *.h5 -.DS_Store \ No newline at end of file +.DS_Store +.ipynb_checkpoints \ No newline at end of file diff --git a/docs/source/examples.rst b/docs/source/examples.rst index a3ef41eb..4c101c90 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -31,9 +31,9 @@ When reading this script, note the basic workflow. After the data is loaded, a m Next, the model is moved to the GPU using the :code:`model.to` function. Any device understood by :code:`torch.Tensor.to` can be specified here. The next line is a bit more subtle - the dataset is told to move patterns to the GPU before passing them to the model using the :code:`dataset.get_as` function. This function does not move the stored patterns to the GPU. If there is sufficient GPU memory, the patterns can also be pre-moved to the GPU using :code:`dataset.to`, but the speedup is empirically quite small. -Once the device is selected, a reconstruction is run using :code:`model.Adam_optimize`. This is a generator function which will yield at the end of every epoch, to allow some monitoring code to be run. +Once the device is selected, a reconstruction is run using :code:`model.Adam_optimize`. This is a generator function which will yield at the end of every epoch, to allow some monitoring code to be run. Inside the loop, :code:`model.inspect(dataset)` is called every epoch to live-update a set of plots showing the current state of the model parameters. -Finally, the results can be studied using :code:`model.inspect(dataset)`, which creates or updates a set of plots showing the current state of the model parameters. 
:code:`model.compare(dataset)` is also called, which shows how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset. +Finally, :code:`model.compare(dataset)` is called to show how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset. Fancy Ptycho @@ -63,11 +63,13 @@ By default, FancyPtycho will also optimize over the following model parameters, These corrections can be turned off (on) by calling :code:`model..requires_grad = False #(True)`. -Note as well two other changes that are made in this script, when compared to `simple_ptycho.py`. First, a `Reconstructor` object is explicitly created, in this case an `AdamReconstructor`. This object stores a model, dataset, and pytorch optimizer. It is then used to orchestrate the later reconstruction using a call to `Reconstructor.optimize()`. +Note as well two other changes that are made in this script, when compared to :code:`simple_ptycho.py`. First, a :code:`Reconstructor` object is explicitly created, in this case an :code:`AdamReconstructor`. This object stores a model, dataset, and pytorch optimizer. It is then used to orchestrate the later reconstruction using a call to :code:`Reconstructor.optimize()`. -We use this pattern, instead of the simpler call to `model.Adam_optimize()`, because having the reconstructor store the optimizer as well as the model and dataset allows the moment estimates to persist between multiple rounds of optimization. This leads to the second change: In this script, we run two optimization loops. The first loop aggressively refines the probe, with a low minibatch size and a high learning rate. The second loop has a smaller learning rate and a larger batch size, which allow for a more precise final estimation of the object. 
+We use this pattern, instead of the simpler call to :code:`model.Adam_optimize()`, because having the reconstructor store the optimizer as well as the model and dataset allows the moment estimates to persist between multiple rounds of optimization. This leads to the second change: In this script, we run two optimization loops. The first loop aggressively refines the probe, with a low minibatch size and a high learning rate. The second loop has a smaller learning rate and a larger batch size, which allow for a more precise final estimation of the object. -In this case, we used one reconstructor, but it is possible to create additional reconstructors to zero out all the persistant information in the optimizer, if desired, or even to instantiate multiple reconstructors on the same model with different optimization algorithms (e.g. `model.LBFGS_optimize()`). +In this case, we used one reconstructor, but it is possible to create additional reconstructors to zero out all the persistent information in the optimizer, if desired, or even to instantiate multiple reconstructors on the same model with different optimization algorithms (e.g. :code:`model.LBFGS_optimize()`). + +Note also the use of :code:`min_interval=10` in the calls to :code:`model.inspect(dataset)`. Because generating plots can be expensive, passing a minimum interval (in seconds) prevents excessive replots. Finally, the call to :code:`model.inspect(dataset, replot_all=True)` at the end of the script reopens any plot windows that the user may have closed during the reconstruction, so that all results are visible at the end. Gold Ball Ptycho @@ -77,11 +79,27 @@ This script shows how the FancyPtycho model might be used in a realistic situati .. literalinclude:: ../../examples/gold_ball_ptycho.py -Note, in particular, the use of :code:`model.save_on_exception` and :code:`model.save_to_h5` to save the results of the reconstruction. 
If a different file format is required, :code:`model.save_results` will save to a pure-python dictionary. +Note first the explicit addition of the :code:`plot_level=2` argument in the call to :code:`FancyPtycho.from_dataset`. This value controls which plots are generated. With :code:`plot_level=1`, only the main results are shown - :code:`plot_level=2` shows some more advanced monitoring of the error correction terms (background, position error, etc.), and :code:`plot_level=3` shows all registered plots. + +Note also the use of :code:`model.save_on_exception` and :code:`model.save_to_h5` to save the results of the reconstruction. If a different file format is required, :code:`model.save_results` will save to a pure-python dictionary which can be processed further. Finally, note that there are several small adjustments made to the script to counteract particular sources of error that are present in this dataset, for example the raster grid pathology caused by the scan pattern used. Also note that not every mixin is needed every time - in this case, we turn off optimization of the :code:`weights` parameter. +Near-Field Ptycho +----------------- + +This script shows how the FancyPtycho model can be used on a typical near-field ptychography (also known as Fresnel ptychography) dataset. + +.. literalinclude:: ../../examples/near_field_ptycho.py + +The major change here is the setting of the :code:`near_field=True` argument to :code:`FancyPtycho.from_dataset`. This changes the propagator to a near-field propagator. As noted in the comments, if :code:`propagation_distance` is not set, the model will assume a standard near-field geometry with flat illumination. + +If :code:`propagation_distance` is set, it will assume a Fresnel scaling theorem-type geometry, with :code:`propagation_distance` as the focus-to-sample distance, and the distance set in the dataset object as the sample-to-detector distance. 
+ +Finally, note the addition of the :code:`panel_plot_mode=True` argument. This is the default mode, and returns the plots in a panel format, good for easily monitoring the progress of a reconstruction. If individual plots are needed for use in presentations, papers, or otherwise, setting :code:`panel_plot_mode=False` will plot each output in its own window. + + Gold Ball Split --------------- diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index 1533e838..8a2fa863 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -367,28 +367,42 @@ The forward propagator maps the exit wave to the wave at the surface of the dete Plotting ++++++++ -The base CDIModel class has a function, :code:`model.inspect()`, which looks for a class variable called :code:`plot_list` and plots everything contained within. The plot list should be formatted as a list of tuples, with each tuple containing: +The base CDIModel class has a function, :code:`model.inspect()`, which looks for a class variable called :code:`plot_list` and plots everything contained within. The plot list should be formatted as a list of dictionaries, with each dictionary containing: + +* :code:`'title'`: the title of the plot +* :code:`'plot_func'`: a function that takes in the model (and optionally a figure) and generates the relevant plot +* :code:`'condition'` (optional): a function that takes in the model and returns whether or not the plot should be generated -* The title of the plot -* A function that takes in the model and generates the relevant plot -* Optional, a function that takes in the model and returns whether or not the plot should be generated - .. 
code-block:: python # This lists all the plots to display on a call to model.inspect() plot_list = [ - ('Probe Amplitude', - lambda self, fig: p.plot_amplitude(self.probe, fig=fig, basis=self.probe_basis)), - ('Probe Phase', - lambda self, fig: p.plot_phase(self.probe, fig=fig, basis=self.probe_basis)), - ('Object Amplitude', - lambda self, fig: p.plot_amplitude(self.obj, fig=fig, basis=self.probe_basis)), - ('Object Phase', - lambda self, fig: p.plot_phase(self.obj, fig=fig, basis=self.probe_basis)) + { + 'title': 'Probe Amplitude', + 'plot_func': lambda self, fig: + p.plot_amplitude(self.probe, fig=fig, basis=self.probe_basis), + }, + { + 'title': 'Probe Phase', + 'plot_func': lambda self, fig: + p.plot_phase(self.probe, fig=fig, basis=self.probe_basis), + }, + { + 'title': 'Object Amplitude', + 'plot_func': lambda self, fig: + p.plot_amplitude(self.obj, fig=fig, basis=self.probe_basis), + }, + { + 'title': 'Object Phase', + 'plot_func': lambda self, fig: + p.plot_phase(self.obj, fig=fig, basis=self.probe_basis), + }, ] In this case, we've made use of the convenience plotting functions defined in :code:`tools.plotting`. +More advanced models like :code:`FancyPtycho` also define a :code:`plot_panel_list`, which groups related plots together into multi-subplot figures. The :code:`panel_plot_mode` argument (passed at construction time) controls whether these panels are rendered as combined multi-subplot figures or as individual windows. For a simple model like :code:`SimplePtycho`, :code:`plot_list` is sufficient. 
+ Saving ++++++ diff --git a/examples/fancy_ptycho.py b/examples/fancy_ptycho.py index ebee68f0..d6dfc748 100644 --- a/examples/fancy_ptycho.py +++ b/examples/fancy_ptycho.py @@ -1,4 +1,5 @@ import cdtools +import torch as t from matplotlib import pyplot as plt filename = 'example_data/lab_ptycho_data.cxi' @@ -15,9 +16,9 @@ obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix ) -device = 'cuda' -model.to(device=device) -dataset.get_as(device=device) +if t.cuda.is_available(): + model.to(device='cuda') + dataset.get_as(device='cuda') # For this script, we use a slightly different pattern where we explicitly # create a `Reconstructor` class to orchestrate the reconstruction. The @@ -31,9 +32,9 @@ # The batch size sets the minibatch size for loss in recon.optimize(50, lr=0.02, batch_size=10): print(model.report()) - # Plotting is expensive, so we only do it every tenth epoch - if model.epoch % 10 == 0: - model.inspect(dataset) + # Because plotting can be expensive, setting a minimum plotting interval + # (in seconds) can avoid excessive replots. + model.inspect(dataset, min_interval=10) # It's common to chain several different reconstruction loops. 
Here, we # started with an aggressive refinement to find the probe in the previous @@ -41,12 +42,12 @@ # and larger minibatch for loss in recon.optimize(50, lr=0.005, batch_size=50): print(model.report()) - if model.epoch % 10 == 0: - model.inspect(dataset) + model.inspect(dataset, min_interval=10) # This orthogonalizes the recovered probe modes model.tidy_probes() -model.inspect(dataset) +# Setting replot_all will reopen any windows which were closed earlier +model.inspect(dataset, replot_all=True) model.compare(dataset) plt.show() diff --git a/examples/fancy_ptycho_inline.ipynb b/examples/fancy_ptycho_inline.ipynb new file mode 100644 index 00000000..f5cab0fe --- /dev/null +++ b/examples/fancy_ptycho_inline.ipynb @@ -0,0 +1,151 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "286054ce", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "import cdtools\n", + "import torch as t\n", + "from matplotlib import pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "955bb242-e2ed-47c3-919c-1ea690681445", + "metadata": {}, + "outputs": [], + "source": [ + "# Load and inspect a dataset\n", + "\n", + "filename = 'example_data/lab_ptycho_data.cxi'\n", + "dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)\n", + "\n", + "dataset.inspect();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82f71fb4-f013-46bb-817b-1973ba23336a", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize a model from the dataset and move it to the GPU.\n", + "\n", + "model = cdtools.models.FancyPtycho.from_dataset(\n", + " dataset,\n", + " n_modes=3, # Use 3 incoherently mixing probe modes\n", + " oversampling=2, # Simulate the probe on a 2xlarger real-space array\n", + " probe_support_radius=120, # Force the probe to 0 outside a radius of 120 pix\n", + " propagation_distance=5e-3, # Propagate the initial probe guess by 5 mm\n", + " units='mm', # Set the units for 
the live plots\n", + " obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix,\n", + ")\n", + "\n", + "if t.cuda.is_available():\n", + " model.to(device='cuda')\n", + " dataset.get_as(device='cuda')\n", + "\n", + "# Then, create a reconstructor object and view the initialized model\n", + "\n", + "recon = cdtools.reconstructors.AdamReconstructor(model, dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e22a65e-280b-428d-a6f5-f3206d22110b", + "metadata": {}, + "outputs": [], + "source": [ + "# Workaround reconstruction pattern for interactive plotting in jupyter:\n", + "# First, a standalone cell to plot the current model state\n", + "\n", + "model.inspect(dataset, replot_all=True);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81a768b5", + "metadata": {}, + "outputs": [], + "source": [ + "# Second, a cell for running the reconstruction. With this pattern, it is safe\n", + "# to interrupt the kernel. Then, the cell above can be re-run to refresh the plots.\n", + "while model.epoch < 50:\n", + " for loss in recon.optimize(1, lr=0.02, batch_size=10):\n", + " print(model.report())\n", + "\n", + "while model.epoch < 100:\n", + " for loss in recon.optimize(1, lr=0.005, batch_size=10):\n", + " print(model.report())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "237e9286-b6cf-41dc-aeb0-89abfe59b37c", + "metadata": {}, + "outputs": [], + "source": [ + "# Save out the results\n", + "\n", + "model.save_to_h5('lab_ptycho_reconstruction.h5', dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f1565b6", + "metadata": {}, + "outputs": [], + "source": [ + "# Finalize the plotting and create the comparison plot\n", + "\n", + "# This orthogonalizes the recovered probe modes. 
It is best to do so\n", + "# after saving the results, if you intend to initialize any further\n", + "# reconstructions with the probe.\n", + "model.tidy_probes()\n", + "\n", + "# Final plotting\n", + "model.inspect(dataset)\n", + "model.compare(dataset);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63ec3aa4-0d1c-4775-9fb2-3002d404faa4", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/fancy_ptycho_interactive.ipynb b/examples/fancy_ptycho_interactive.ipynb new file mode 100644 index 00000000..0e9adad6 --- /dev/null +++ b/examples/fancy_ptycho_interactive.ipynb @@ -0,0 +1,144 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "286054ce", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib widget\n", + "import cdtools\n", + "import torch as t\n", + "from matplotlib import pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "955bb242-e2ed-47c3-919c-1ea690681445", + "metadata": {}, + "outputs": [], + "source": [ + "# Load and inspect a dataset\n", + "\n", + "filename = 'example_data/lab_ptycho_data.cxi'\n", + "dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)\n", + "\n", + "dataset.inspect();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82f71fb4-f013-46bb-817b-1973ba23336a", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize a model from the dataset and move it to the GPU.\n", + "\n", + "model = cdtools.models.FancyPtycho.from_dataset(\n", + " dataset,\n", + " 
n_modes=3, # Use 3 incoherently mixing probe modes\n", + " oversampling=2, # Simulate the probe on a 2xlarger real-space array\n", + " probe_support_radius=120, # Force the probe to 0 outside a radius of 120 pix\n", + " propagation_distance=5e-3, # Propagate the initial probe guess by 5 mm\n", + " units='mm', # Set the units for the live plots\n", + " obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix,\n", + ")\n", + "\n", + "if t.cuda.is_available():\n", + " model.to(device='cuda')\n", + " dataset.get_as(device='cuda')\n", + "\n", + "# Then, create a reconstructor object and view the initialized model\n", + "\n", + "recon = cdtools.reconstructors.AdamReconstructor(model, dataset)\n", + "model.inspect(dataset);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81a768b5", + "metadata": {}, + "outputs": [], + "source": [ + "# Workaround reconstruction pattern for interactive plotting in jupyter:\n", + "\n", + "# With this pattern, it is safe to interrupt the kernel. Doing so will\n", + "# trigger an update of the plots, at which point the current state can\n", + "# be viewed. 
Then this cell can be re-run to continue the reconstruction\n", + "while model.epoch < 50:\n", + " for loss in recon.optimize(1, lr=0.02, batch_size=10):\n", + " print(model.report())\n", + " model.inspect(dataset, min_interval=10)\n", + "\n", + "while model.epoch < 100:\n", + " for loss in recon.optimize(1, lr=0.005, batch_size=10):\n", + " print(model.report())\n", + " model.inspect(dataset, min_interval=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "237e9286-b6cf-41dc-aeb0-89abfe59b37c", + "metadata": {}, + "outputs": [], + "source": [ + "# Save out the results\n", + "\n", + "model.save_to_h5('lab_ptycho_reconstruction.h5', dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f1565b6", + "metadata": {}, + "outputs": [], + "source": [ + "# Finalize the plotting and create the comparison plot\n", + "\n", + "# This orthogonalizes the recovered probe modes. It is best to do so\n", + "# after saving the results, if you intend to initialize any further\n", + "# reconstructions with the probe.\n", + "model.tidy_probes()\n", + "\n", + "# Final plotting\n", + "model.inspect(dataset)\n", + "model.compare(dataset);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63ec3aa4-0d1c-4775-9fb2-3002d404faa4", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/gold_ball_ptycho.py b/examples/gold_ball_ptycho.py index 49719751..aeb510c1 100644 --- a/examples/gold_ball_ptycho.py +++ b/examples/gold_ball_ptycho.py @@ -1,6 +1,6 @@ import cdtools -from 
matplotlib import pyplot as plt import torch as t +from matplotlib import pyplot as plt filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi' dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename) @@ -26,7 +26,8 @@ probe_support_radius=50, propagation_distance=2e-6, units='um', - probe_fourier_crop=pad + probe_fourier_crop=pad, + plot_level=2, ) @@ -39,9 +40,9 @@ # Not much probe intensity instability in this dataset, no need for this model.weights.requires_grad = False -device = 'cuda' -model.to(device=device) -dataset.get_as(device=device) +if t.cuda.is_available(): + model.to(device='cuda') + dataset.get_as(device='cuda') # Create the reconstructor recon = cdtools.reconstructors.AdamReconstructor(model, dataset) @@ -53,13 +54,11 @@ for loss in recon.optimize(20, lr=0.005, batch_size=50): print(model.report()) - if model.epoch % 10 == 0: - model.inspect(dataset) + model.inspect(dataset, min_interval=5) for loss in recon.optimize(50, lr=0.002, batch_size=100): print(model.report()) - if model.epoch % 10 == 0: - model.inspect(dataset) + model.inspect(dataset, min_interval=5) # We can often reset our guess of the probe positions once we have a # good guess of probe and object, but in this case it causes the @@ -70,8 +69,7 @@ # the loss fails to improve after 10 epochs for loss in recon.optimize(100, lr=0.001, batch_size=100, schedule=True): print(model.report()) - if model.epoch % 10 == 0: - model.inspect(dataset) + model.inspect(dataset, min_interval=5) model.tidy_probes() @@ -79,6 +77,6 @@ # This saves the final result model.save_to_h5('example_reconstructions/gold_balls.h5', dataset) -model.inspect(dataset) +model.inspect(dataset, replot_all=True) model.compare(dataset) plt.show() diff --git a/examples/gold_ball_split.py b/examples/gold_ball_split.py index 9fc5b083..fde19c6a 100644 --- a/examples/gold_ball_split.py +++ b/examples/gold_ball_split.py @@ -32,9 +32,9 @@ model.weights.requires_grad = False - device = 'cuda' - model.to(device=device) 
- dataset.get_as(device=device) + if t.cuda.is_available(): + model.to(device='cuda') + dataset.get_as(device='cuda') # Create the reconstructor recon = cdtools.reconstructors.AdamReconstructor(model, dataset) diff --git a/examples/gold_ball_synthesize.py b/examples/gold_ball_synthesize.py index 3a2c68aa..41319056 100644 --- a/examples/gold_ball_synthesize.py +++ b/examples/gold_ball_synthesize.py @@ -5,11 +5,11 @@ # We load all three reconstructions half_1 = cdtools.tools.data.h5_to_nested_dict( - f'example_reconstructions/gold_balls_half_1.h5') + 'example_reconstructions/gold_balls_half_1.h5') half_2 = cdtools.tools.data.h5_to_nested_dict( - f'example_reconstructions/gold_balls_half_2.h5') + 'example_reconstructions/gold_balls_half_2.h5') full = cdtools.tools.data.h5_to_nested_dict( - f'example_reconstructions/gold_balls_full.h5') + 'example_reconstructions/gold_balls_full.h5') # This defines the region of recovered object to use for the analysis. pad = 260 diff --git a/examples/near_field_ptycho.py b/examples/near_field_ptycho.py index c91076f8..af0012d2 100644 --- a/examples/near_field_ptycho.py +++ b/examples/near_field_ptycho.py @@ -1,11 +1,11 @@ import cdtools +import torch as t from matplotlib import pyplot as plt filename = 'example_data/PETRAIII_P25_Near_Field_Ptycho.cxi' dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename) dataset.inspect() -plt.show() # Setting near_field equal to True uses an angular spectrum propagator in # lieu of the default Fourier-transform propagator for far-field ptychography. 
@@ -27,11 +27,12 @@ propagation_distance=3.65e-3, # 3.65 downstream from focus units='um', # Set the units for the live plots obj_view_crop=-35, + panel_plot_mode=True, # Set to False to get individual figures ) -device = 'cuda' -model.to(device=device) -dataset.get_as(device=device) +if t.cuda.is_available(): + model.to(device='cuda') + dataset.get_as(device='cuda') model.inspect(dataset) @@ -39,18 +40,15 @@ for loss in recon.optimize(100, lr=0.04, batch_size=10): print(model.report()) - # Plotting is expensive, so we only do it every tenth epoch - if model.epoch % 10 == 0: - model.inspect(dataset) + model.inspect(dataset, min_interval=5) for loss in recon.optimize(50, lr=0.005, batch_size=50): print(model.report()) - if model.epoch % 10 == 0: - model.inspect(dataset) + model.inspect(dataset, min_interval=5) # This orthogonalizes the recovered probe modes model.tidy_probes() -model.inspect(dataset) +model.inspect(dataset, replot_all=True) model.compare(dataset) plt.show() diff --git a/examples/simple_ptycho.py b/examples/simple_ptycho.py index 217d96b0..41c0c6c2 100644 --- a/examples/simple_ptycho.py +++ b/examples/simple_ptycho.py @@ -8,6 +8,7 @@ correct for common sources of error. 
""" import cdtools +import torch as t from matplotlib import pyplot as plt # We load an example dataset from a .cxi file @@ -17,10 +18,12 @@ # We create a ptychography model from the dataset model = cdtools.models.SimplePtycho.from_dataset(dataset) -# We move the model to the GPU -device = 'cuda' -model.to(device=device) -dataset.get_as(device=device) +# We move the model to the GPU, if possible +if t.cuda.is_available(): + model.to(device='cuda') + dataset.get_as(device='cuda') + +model.inspect(dataset) # We run the reconstruction for loss in model.Adam_optimize(100, dataset, batch_size=10): @@ -29,7 +32,6 @@ # And liveplot the updates to the model as they happen model.inspect(dataset) -# We study the results -model.inspect(dataset) +# We open a comparison of the simulated and measured data model.compare(dataset) plt.show() diff --git a/examples/transmission_RPI.py b/examples/transmission_RPI.py index feff19e8..32c02d55 100644 --- a/examples/transmission_RPI.py +++ b/examples/transmission_RPI.py @@ -1,5 +1,6 @@ import cdtools import pickle +import torch as t from matplotlib import pyplot as plt # First, we load an example dataset from a .cxi file @@ -22,26 +23,27 @@ # Let's do this reconstruction on the GPU, shall we? 
-model.to(device='cuda') -dataset.get_as(device='cuda') +if t.cuda.is_available(): + model.to(device='cuda') + dataset.get_as(device='cuda') # Note that the inspect step takes the vast majority of the time # The regularization is an L2 regularizer that empirically helps accelerate # convergence for loss in model.LBFGS_optimize(30, dataset, lr=0.4, regularization_factor=[0.05,0.05]): - model.inspect(dataset) + model.inspect(dataset, min_interval=5) print(model.report()) # Now we use the regularizer to damp all but the top modes for loss in model.LBFGS_optimize(50, dataset, lr=0.4, regularization_factor=[0.001,0.1]): - #model.inspect(dataset) + model.inspect(dataset, min_interval=5) print(model.report()) -# Save results to a python dictionary -results = model.save_results() +# Save results to an h5 file +model.save_to_h5('example_reconstructions/transmission_RPI.h5', dataset) # Finally, we plot the results -model.inspect(dataset) +model.inspect(dataset, replot_all=True) model.compare(dataset) plt.show() diff --git a/examples/tutorial_finale.py b/examples/tutorial_finale.py index 9f698b76..4e95ec63 100644 --- a/examples/tutorial_finale.py +++ b/examples/tutorial_finale.py @@ -1,5 +1,6 @@ from tutorial_basic_ptycho_dataset import BasicPtychoDataset from tutorial_simple_ptycho import SimplePtycho +import torch as t from h5py import File from matplotlib import pyplot as plt @@ -11,8 +12,9 @@ model = SimplePtycho.from_dataset(dataset) -model.to(device='cuda') -dataset.get_as(device='cuda') +if t.cuda.is_available(): + model.to(device='cuda') + dataset.get_as(device='cuda') for loss in model.Adam_optimize(10, dataset): model.inspect(dataset) diff --git a/examples/tutorial_simple_ptycho.py b/examples/tutorial_simple_ptycho.py index 7e8a3e5d..67544d74 100644 --- a/examples/tutorial_simple_ptycho.py +++ b/examples/tutorial_simple_ptycho.py @@ -108,14 +108,26 @@ def loss(self, real_data, sim_data): # This lists all the plots to display on a call to model.inspect() plot_list = 
[ - ('Probe Amplitude', - lambda self, fig: p.plot_amplitude(self.probe, fig=fig, basis=self.probe_basis)), - ('Probe Phase', - lambda self, fig: p.plot_phase(self.probe, fig=fig, basis=self.probe_basis)), - ('Object Amplitude', - lambda self, fig: p.plot_amplitude(self.obj, fig=fig, basis=self.probe_basis)), - ('Object Phase', - lambda self, fig: p.plot_phase(self.obj, fig=fig, basis=self.probe_basis)) + { + 'title': 'Probe Amplitude', + 'plot_func': lambda self, fig: + p.plot_amplitude(self.probe, fig, basis=self.probe_basis), + }, + { + 'title': 'Probe Phase', + 'plot_func': lambda self, fig: + p.plot_phase(self.probe, fig, basis=self.probe_basis) + }, + { + 'title': 'Object Amplitude', + 'plot_func': lambda self, fig: + p.plot_amplitude(self.obj, fig, basis=self.probe_basis) + }, + { + 'title': 'Object Phase', + 'plot_func': lambda self, fig: + p.plot_phase(self.obj, fig, basis=self.probe_basis) + }, ] def save_results(self, dataset): diff --git a/pyproject.toml b/pyproject.toml index 9ba1438c..aacaecfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,4 +50,5 @@ docs = [ ] [tool.ruff] -line-length = 79 \ No newline at end of file +line-length = 79 +ignore = ["E501", "E731"] \ No newline at end of file diff --git a/src/cdtools/datasets/ptycho_2d_dataset.py b/src/cdtools/datasets/ptycho_2d_dataset.py index adfe866e..5a84ca14 100644 --- a/src/cdtools/datasets/ptycho_2d_dataset.py +++ b/src/cdtools/datasets/ptycho_2d_dataset.py @@ -245,8 +245,6 @@ def get_images(idx): return np.log10((meas_data * mask) + log_offset) else: return meas_data * mask - - translations = self.translations.detach().cpu().numpy() # This takes about twice as long as it would to just do it all at # once, but it avoids creating another self.patterns-sized array diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index df347b63..e7516593 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -28,8 +28,9 @@ """ import torch as t
from torch.utils import data as torchdata +import matplotlib from matplotlib import pyplot as plt from matplotlib.widgets import Slider from matplotlib import ticker @@ -58,12 +60,26 @@ class CDIModel(t.nn.Module): functions. """ - def __init__(self): + def __init__(self, panel_plot_mode=False, plot_level=np.inf): + """Initializes the CDIModel base class. + + Parameters + ---------- + panel_plot_mode : bool, default: False + If True, plot_panel_list entries are rendered as multi-subplot + figures. If False, each subplot is rendered as its own figure. + plot_level : float, default: np.inf + Only plots whose plot_level <= this value are shown. + """ super(CDIModel, self).__init__() self.loss_history = [] self.training_history = '' self.epoch = 0 + self.panel_plot_mode = panel_plot_mode + self.plot_level = plot_level + self.has_inspect_been_called = False + self.last_inspected_time = None def from_dataset(self, dataset): raise NotImplementedError() @@ -547,97 +563,331 @@ def report(self): return msg - # By default, the plot_list is empty + # By default, the plot lists are empty + plot_panel_list = [] plot_list = [] - def inspect(self, dataset=None, update=True): - """Plots all the plots defined in the model's plot_list attribute + def inspect(self, dataset=None, replot_all=False, min_interval=None): + """Plots all the plots defined in the model's plot_panel_list and plot_list attributes - If update is set to True, it will update any previously plotted set - of plots, if one exists, and then redraw them. Otherwise, it will - plot a new set, and any subsequent updates will update the new set + Updates any previously plotted figures that are still open. Figures + that have been closed are left closed unless replot_all=True. Optionally, a dataset can be passed, which will allow plotting of any registered plots which need to incorporate some information from the dataset (such as geometry or a comparison with measured data). 
- Plots can be registered in any subclass by defining the plot_list - attribute. This should be a list of tuples in the following format: - ( 'Plot Title', function_to_generate_plot(self), - function_to_determine_whether_to_plot(self)) + Plots can be registered in any subclass by defining plot_panel_list + and/or plot_list class attributes. See the CDIModel documentation for + the expected dict-based format of each. + + When panel_plot_mode=True (set in __init__), plot_panel_list entries + are rendered as multi-subplot figures. When False (the default), + each subplot in plot_panel_list is rendered as its own figure, + prepended to any standalone plot_list entries. - Where the third element in the tuple (a function that returns - True if the plot is relevant) is not required. + The plot_level attribute (set in __init__, default np.inf) controls + which plots are shown: a panel or standalone plot is only shown when + its plot_level <= self.plot_level. Parameters ---------- dataset : CDataset Optional, a dataset matched to the model type - update : bool, default: True - Whether to update existing plots or plot new ones + replot_all : bool, default: False + If True, recreate figures that were previously closed by the user. + min_interval : float, optional + If set, skip updating plots if fewer than this many seconds have + elapsed since the last call to inspect(). The time of the last + update is stored in self.last_inspected_time. 
""" - # We find or create all the figures - first_update = False - if update and hasattr(self, 'figs') and self.figs: - figs = self.figs - elif update: - figs = None - self.figs = [] - first_update = True + if (min_interval is not None + and self.last_inspected_time is not None + and time.time() - self.last_inspected_time < min_interval): + return + + plot_panel_list = getattr(self, 'plot_panel_list', None) or [] + plot_list = getattr(self, 'plot_list', None) or [] + + if self.panel_plot_mode: + # First we plot all the panels + panel_figs = self._inspect_panel( + plot_panel_list, dataset=dataset, replot_all=replot_all) + # And then we plot all the individual figures + individual_figs = self._inspect_individual_figures( + plot_list, dataset=dataset, replot_all=replot_all + ) + self.figs = panel_figs + individual_figs else: - figs = None - self.figs = [] + # If not in panel plot mode, we first flatten the figures + # from the panels + flat = [] + for panel in plot_panel_list: + panel_level = panel.get('plot_level', 1) + for plot in panel['plots']: + # We add the plot level from the larger panel + flat.append({**plot, 'plot_level': panel_level}) + + all_plots = flat + list(plot_list) - idx = 0 - for plots in self.plot_list: - # If a conditional is included in the plot, we check whether - # it is True - try: - if len(plots) >=3 and not plots[2](self): - continue - except TypeError as e: - if len(plots) >= 3 and not plots[2](self, dataset): - continue - - name = plots[0] - plotter = plots[1] - - if figs is None: - fig = plt.figure() - self.figs.append(fig) + # We make sure to keep a reference to the open figs around + self.figs = self._inspect_individual_figures( + all_plots, dataset=dataset, replot_all=replot_all) + + if not self.has_inspect_been_called or replot_all: + # Somehow, this is needed for new figures to appear + if self._is_backend_interactive(): + plt.pause(0.05 * len(self.figs)) + for fig in self.figs: + fig.canvas.flush_events() + 
self.has_inspect_been_called = True + + self.last_inspected_time = time.time() + + + def _is_backend_interactive( + self + ): + """Returns True if the current matplotlib backend is interactive.""" + backend = matplotlib.get_backend().lower() + try: + # matplotlib >= 3.9 + interactive_bk = matplotlib.backends.backend_registry.list_builtin( + matplotlib.backends.BackendFilter.INTERACTIVE + ) + except AttributeError: + # older matplotlib + interactive_bk = matplotlib.rcsetup.interactive_bk + return backend in [b.lower() for b in interactive_bk] + + + def _inspect_individual_figures( + self, + plot_list, + dataset=None, + replot_all=False + ): + """Core one-figure-per-plot rendering logic. + + fig_map is a dict {title: figure} owned by the caller and updated + in-place. It tracks which figures are open across calls. + + Behaviour: + replot_all=False — closed figures are skipped (left closed). + replot_all=True — closed figures are recreated. + + Returns the list of figures that were rendered this call. + """ + + rendered = [] + + for plot in plot_list: + # Level filter + if plot.get('plot_level', 1) > self.plot_level: + continue + + # Condition check + condition = plot.get('condition', None) + if condition is not None: + try: + if not condition(self): + continue + except TypeError: + if not condition(self, dataset): + continue + + figsize = plot.get('figure_size', None) + if self.has_inspect_been_called and \ + not replot_all and \ + not plt.fignum_exists(plot['title']): + continue + + if not self.has_inspect_been_called: + fig = plt.figure(plot['title'], + figsize=figsize) else: - fig = figs[idx] + with plt.rc_context({'figure.raise_window': False}): + fig = plt.figure(plot['title'], + figsize = figsize) - - try: # We try just plotting using the simplest allowed signature - plotter(self,fig) - plt.title(name) - except TypeError as e: - # TypeError implies it wanted another argument, i.e. 
a dataset + try: + plot['plot_func'](self, fig) + if plt.gca().get_title().strip() == '': + plt.title(plot['title']) + except TypeError: if dataset is not None: try: - plotter(self, fig, dataset) - plt.title(name) - except Exception as e: # Don't raise errors: it's just plots + plot['plot_func'](self, fig, dataset) + if plt.gca().get_title().strip() == '': + plt.title(plot['title']) + except KeyboardInterrupt: + raise + except Exception: pass - - except Exception as e: # Don't raise errors, it's just a plot + except KeyboardInterrupt: + raise + except Exception: pass - idx += 1 + rendered.append(fig) + if self._is_backend_interactive(): + plt.draw() + + return rendered + + + def _inspect_panel(self, plot_panel_list, dataset=None, replot_all=False): + """Multi-subplot panel rendering. + + Creates one figure per plot_panel_list entry, placing each subplot's + plot_func output into the appropriate axes. Closed panels stay closed + on subsequent calls unless replot_all=True. Standalone plot_list + entries are then rendered via _do_inspect and appended to self.figs. + """ - if update: - # This seems to update the figure without blocking. 
+ rendered = [] + + for panel_def in plot_panel_list[::-1]: # Flip so first ones show on top + panel_level = panel_def.get('plot_level', 1) + if panel_level > self.plot_level: + continue # skip entire panel + + panel_condition = panel_def.get('condition', None) + if panel_condition is not None: + try: + if not panel_condition(self): + continue + except TypeError: + if not panel_condition(self, dataset): + continue + + nrows, ncols = panel_def['grid'] + figsize = panel_def.get('figure_size', None) + + + if self.has_inspect_been_called and \ + not replot_all and \ + not plt.fignum_exists(panel_def['title']): + continue + + if not self.has_inspect_been_called: + fig = plt.figure(panel_def['title'], figsize=figsize) + else: + with plt.rc_context({'figure.raise_window': False}): + fig = plt.figure(panel_def['title'], figsize=figsize) + + for subfig in fig.subfigs: + if hasattr(subfig, '_sliders'): + for slider in subfig._sliders: + slider.disconnect_events() + fig.clear() + + gs = fig.add_gridspec( + nrows, ncols, + width_ratios=[1]*ncols, + height_ratios=[1]*nrows, + ) + + for plot in panel_def['plots']: + condition = plot.get('condition', None) + if condition is not None: + try: + if not condition(self): + continue + except TypeError: + if not condition(self, dataset): + continue + subfig = fig.add_subfigure(gs[plot['subplot'][0], + plot['subplot'][1]]) + + try: + plot['plot_func'](self, subfig) + if plt.gca().get_title().strip() == '': + plt.title(plot['title']) + except TypeError: + if dataset is not None: + try: + plot['plot_func'](self, subfig, dataset) + if plt.gca().get_title().strip() == '': + plt.title(plot['title']) + except KeyboardInterrupt: + raise + except Exception: + pass + except KeyboardInterrupt: + raise + except Exception: + raise + + rendered.append(fig) + + if self._is_backend_interactive(): plt.draw() - fig.canvas.start_event_loop(0.001) - if first_update: - # But this is needed the first time the figures update, or - # they won't get drawn at 
all - plt.pause(0.05 * len(self.figs)) + return rendered + + def plot_loss_history(self, fig=None, clear_fig=True): + """Plots the loss history on a semilogy axis + + Parameters + ---------- + fig : matplotlib.figure.Figure + Default is a new figure, a matplotlib figure to use to plot + clear_fig : bool + Default is True. Whether to clear the figure before plotting. + + Returns + ------- + used_fig : matplotlib.figure.Figure + The figure object that was actually plotted to. + """ + + if fig is None: + fig = plt.figure() + + if clear_fig: + fig.clear() + + if len(fig.axes) >= 1: + ax = fig.axes[0] + else: + try: + total_width, total_height = fig.get_size_inches() + except AttributeError: + # Only support one layer of nested subfigures + main_fig = fig.figure # get enclosing figure + bbox = fig.bbox + main_fig_bbox = main_fig.bbox + fig_w, fig_h = main_fig.get_size_inches() + total_width = fig.bbox.width * fig_w / main_fig.bbox.width + total_height = fig.bbox.height * fig_h / main_fig.bbox.height + except AttributeError: + # Fall back to default figsize + total_width, total_height = (6.4, 4.8) + + pad_left = 0.6 / total_height + # De-adjusts for an ad-hoc offset introduced by matplotlib + pad_right = 0.6 / total_width - 0.05 + + pad_bottom = 0.5 / total_height + pad_top = 0.4 / total_height + + im_ax_bottom = pad_bottom + im_ax_height = 1 - pad_top - im_ax_bottom + + ax = fig.add_axes( + [pad_left, im_ax_bottom, 1-pad_left-pad_right, im_ax_height] + ) + + ax.semilogy(self.loss_history) + plt.title('Loss History') + + ax.set_xlabel('Epoch') + ax.set_ylabel('Loss Metric') + return fig def save_figures(self, prefix='', extension='.pdf'): """Saves all currently open inspection figures. @@ -661,14 +911,15 @@ def save_figures(self, prefix='', extension='.pdf'): Default is .eps, the file extension to save with. 
""" - if hasattr(self, 'figs') and self.figs: - figs = self.figs - else: - return # No figures to save + if not (hasattr(self, 'figs') and self.figs): + return # No figures to save for fig in self.figs: - fig.savefig(prefix + fig.axes[0].get_title() + extension, - bbox_inches = 'tight') + if hasattr(fig, '_panel_label') and fig._panel_label: + label = fig._panel_label + else: + label = fig.axes[0].get_title() if fig.axes else 'figure' + fig.savefig(prefix + label + extension, bbox_inches='tight') def compare(self, dataset, logarithmic=False): diff --git a/src/cdtools/models/bragg_2d_ptycho.py b/src/cdtools/models/bragg_2d_ptycho.py index 507323d0..bd8b5803 100644 --- a/src/cdtools/models/bragg_2d_ptycho.py +++ b/src/cdtools/models/bragg_2d_ptycho.py @@ -80,6 +80,8 @@ def __init__( units='um', dtype=t.float32, obj_view_crop=0, + panel_plot_mode=False, + plot_level=1, ): # We need the detector geometry @@ -91,7 +93,8 @@ def __init__( # translation_offsets can stay 2D for now # propagate_probe and correct_tilt are important! 
- super(Bragg2DPtycho, self).__init__() + super(Bragg2DPtycho, self).__init__(panel_plot_mode=panel_plot_mode, + plot_level=plot_level) self.register_buffer('wavelength', t.as_tensor(wavelength, dtype=dtype)) self.store_detector_geometry(detector_geometry, @@ -258,7 +261,9 @@ def from_dataset( obj_padding=200, obj_view_crop=None, units='um', - surface_normal=None + surface_normal=None, + panel_plot_mode=False, + plot_level=1, ): wavelength = dataset.wavelength det_basis = dataset.detector_geometry['basis'] @@ -436,8 +441,8 @@ def from_dataset( return cls(wavelength, det_geo, obj_basis, probe, obj, min_translation=min_translation, probe_basis=probe_basis, - median_propagation =median_propagation, - translation_offsets = translation_offsets, + median_propagation=median_propagation, + translation_offsets=translation_offsets, weights=weights, mask=mask, background=background, translation_scale=translation_scale, saturation=saturation, @@ -448,6 +453,8 @@ def from_dataset( lens=lens, obj_view_crop=obj_view_crop, units=units, + panel_plot_mode=panel_plot_mode, + plot_level=plot_level, ) @@ -578,90 +585,90 @@ def corrected_translations(self,dataset): plot_list = [ - ('Basis Probe Fourier Space Amplitudes', - lambda self, fig: p.plot_amplitude(tools.propagators.inverse_far_field(self.probe), fig=fig)), - ('Basis Probe Fourier Space Phases', - lambda self, fig: p.plot_phase(tools.propagators.inverse_far_field(self.probe), fig=fig)), - ('Basis Probe Real Space Amplitudes, Surface Normal View', - lambda self, fig: p.plot_amplitude( + {'title': 'Basis Probe Fourier Space Amplitudes', + 'plot_func': lambda self, fig: p.plot_amplitude(tools.propagators.inverse_far_field(self.probe), fig=fig)}, + {'title': 'Basis Probe Fourier Space Phases', + 'plot_func': lambda self, fig: p.plot_phase(tools.propagators.inverse_far_field(self.probe), fig=fig)}, + {'title': 'Basis Probe Real Space Amplitudes, Surface Normal View', + 'plot_func': lambda self, fig: p.plot_amplitude( self.probe, 
fig=fig, basis=self.probe_basis, units=self.units, - )), - ('Basis Probe Real Space Phases, Surface Normal View', - lambda self, fig: p.plot_phase( + )}, + {'title': 'Basis Probe Real Space Phases, Surface Normal View', + 'plot_func': lambda self, fig: p.plot_phase( self.probe, fig=fig, basis=self.probe_basis, units=self.units, - )), - ('Basis Probe Real Space Amplitudes, Beam View', - lambda self, fig: p.plot_amplitude( + )}, + {'title': 'Basis Probe Real Space Amplitudes, Beam View', + 'plot_func': lambda self, fig: p.plot_amplitude( self.probe, fig=fig, basis=self.probe_basis, view_basis=beam_basis, units=self.units, - )), - ('Basis Probe Real Space Phases, Beam View', - lambda self, fig: p.plot_phase( + )}, + {'title': 'Basis Probe Real Space Phases, Beam View', + 'plot_func': lambda self, fig: p.plot_phase( self.probe, fig=fig, basis=self.probe_basis, view_basis=beam_basis, units=self.units, - )), - ('Object Amplitude, Surface Normal View', - lambda self, fig: p.plot_amplitude( + )}, + {'title': 'Object Amplitude, Surface Normal View', + 'plot_func': lambda self, fig: p.plot_amplitude( self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units, - )), - ('Object Phase, Surface Normal View', - lambda self, fig: p.plot_phase( + )}, + {'title': 'Object Phase, Surface Normal View', + 'plot_func': lambda self, fig: p.plot_phase( self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units, - )), - ('Object Amplitude, Beam View', - lambda self, fig: p.plot_amplitude( + )}, + {'title': 'Object Amplitude, Beam View', + 'plot_func': lambda self, fig: p.plot_amplitude( self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, view_basis=beam_basis, units=self.units, - )), - ('Object Phase, Beam View', - lambda self, fig: p.plot_phase( + )}, + {'title': 'Object Phase, Beam View', + 'plot_func': lambda self, fig: p.plot_phase( self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, view_basis=beam_basis, units=self.units, - 
)), - ('Object Amplitude, Detector View', - lambda self, fig: p.plot_amplitude( + )}, + {'title': 'Object Amplitude, Detector View', + 'plot_func': lambda self, fig: p.plot_amplitude( self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, view_basis=self.det_basis, units=self.units, - )), - ('Object Phase, Detector View', - lambda self, fig: p.plot_phase( + )}, + {'title': 'Object Phase, Detector View', + 'plot_func': lambda self, fig: p.plot_phase( self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, view_basis=self.det_basis, units=self.units, - )), - ('Corrected Translations', - lambda self, fig, dataset: p.plot_translations(self.corrected_translations(dataset), fig=fig, units=self.units)), - ('Background', - lambda self, fig: plt.figure(fig.number) and plt.imshow(self.background.detach().cpu().numpy()**2)) + )}, + {'title': 'Corrected Translations', + 'plot_func': lambda self, fig, dataset: p.plot_translations(self.corrected_translations(dataset), fig=fig, units=self.units)}, + {'title': 'Background', + 'plot_func': lambda self, fig: plt.figure(fig.number) and plt.imshow(self.background.detach().cpu().numpy()**2)}, ] diff --git a/src/cdtools/models/fancy_ptycho.py b/src/cdtools/models/fancy_ptycho.py index 0b3d4997..3387b15a 100644 --- a/src/cdtools/models/fancy_ptycho.py +++ b/src/cdtools/models/fancy_ptycho.py @@ -7,8 +7,6 @@ from matplotlib import pyplot as plt from datetime import datetime import numpy as np -from scipy import linalg as sla -from copy import copy __all__ = ['FancyPtycho'] @@ -45,9 +43,12 @@ def __init__(self, near_field=False, angular_spectrum_propagator=None, inv_angular_spectrum_propagator=None, + panel_plot_mode=True, + plot_level=2, ): - super(FancyPtycho, self).__init__() + super(FancyPtycho, self).__init__(panel_plot_mode=panel_plot_mode, + plot_level=plot_level) self.register_buffer('wavelength', t.as_tensor(wavelength, dtype=dtype)) self.store_detector_geometry(detector_geometry, @@ -252,6 +253,8 @@ def 
from_dataset(cls, obj_view_crop=None, obj_padding=200, near_field=False, + panel_plot_mode=True, + plot_level=2, ): wavelength = dataset.wavelength @@ -517,6 +520,8 @@ def from_dataset(cls, near_field=near_field, angular_spectrum_propagator=angular_spectrum_propagator, inv_angular_spectrum_propagator=inv_angular_spectrum_propagator, + panel_plot_mode=panel_plot_mode, + plot_level=plot_level, ) @@ -550,7 +555,7 @@ def interaction(self, index, translations, *args): else: try: Ws = t.ones(len(index)) # I'm positive this introduced a bug - except: + except TypeError: Ws = 1 if self.weights is None or len(self.weights[0].shape) == 0: @@ -775,7 +780,6 @@ def tidy_probes(self): # First we treat the incoherent but stable case, where the weights are # just one per-shot overall weight if self.weights.dim() == 1: - probe = self.probe.detach().cpu().numpy() ortho_probes = analysis.orthogonalize_probes(self.probe.detach()) self.probe.data = ortho_probes return @@ -839,6 +843,53 @@ def tidy_probes(self): # We discard the U matrix and re-multiply S & Vh self.weights.data = S[:,:,None] * (Vh / probe_sqrt_intensities) + + def get_probe_intensities(self): + """Returns the effective probe intensity at each scan position. + + Handles both the simple (1D weights) and OPRP (2D weights) cases. + + Returns + ------- + probe_intensities : np.ndarray + Array of probe intensities, one per scan position. 
+ """ + if not hasattr(self, 'weights'): + raise NotImplementedError( + "I don't know how to handle having no weights") + elif self.weights.ndim == 1: + probe_intensities = self.weights.detach().cpu().numpy()**2 + else: + # The big case, with OPRP + probe_matrix = np.zeros([self.probe.shape[0]]*2, + dtype=np.complex64) + np_probes = self.probe.detach().cpu().numpy() + for i in range(probe_matrix.shape[0]): + for j in range(probe_matrix.shape[0]): + probe_matrix[i,j] = np.sum(np_probes[i]*np_probes[j].conj()) + + weights = self.weights.detach().cpu().numpy() + + # The outer one is a sum, because the tensordot is what broadcasts + # the probe matrix along the shot dimension - the second one + # doesn't have to. + weighted_probe_matrices = np.sum( + np.tensordot(weights, probe_matrix, axes=1)[...,None] + * weights.conj().transpose((0,2,1))[...,None,:,:], + axis=-2 + ) + + basis_probe_intensities = np.trace( + probe_matrix, axis1=-2, axis2=-1) + probe_intensities = np.trace( + weighted_probe_matrices, axis1=-2, axis2=-1) + + # Imaginary part is already essentially zero up to rounding error + probe_intensities = np.real( + probe_intensities / basis_probe_intensities) + + return probe_intensities + def plot_wavefront_variation(self, dataset, fig=None, mode='amplitude', **kwargs): def get_probes(idx): @@ -855,22 +906,8 @@ def get_probes(idx): if mode.lower() == 'phase': return np.angle(ortho_probes.detach().cpu().numpy()) - probe_matrix = np.zeros([self.probe.shape[0]]*2, - dtype=np.complex64) - np_probes = self.probe.detach().cpu().numpy() - for i in range(probe_matrix.shape[0]): - for j in range(probe_matrix.shape[0]): - probe_matrix[i,j] = np.sum(np_probes[i]*np_probes[j].conj()) - - weights = self.weights.detach().cpu().numpy() - - probe_intensities = np.sum(np.tensordot(weights, probe_matrix, axes=1) - * weights.conj(), axis=2) - - # Imaginary part is already essentially zero up to rounding error - probe_intensities = np.real(probe_intensities) - - values = 
np.sum(probe_intensities, axis=1) + values = self.get_probe_intensities() + if mode.lower() == 'amplitude' or mode.lower() == 'root_sum_intensity': cmap = 'viridis' else: @@ -887,7 +924,22 @@ def get_probes(idx): cmap=cmap, **kwargs), - + + def plot_illumination_intensity(self, fig, dataset): + """Plots the probe intensity nanomap. Only used to make a plot for the plot list.""" + p.plot_nanomap( + self.corrected_translations(dataset), + self.get_probe_intensities(), + fig=fig, + cmap='viridis', + cmap_label='Intensity (a.u.)', + units=self.units, + convention='probe', + invert_xaxis=True + ) + plt.gca().set_aspect('equal') + + def plot_translations_and_originals(self, fig, dataset): """Only used to make a plot for the plot list.""" p.plot_translations( @@ -907,112 +959,212 @@ def plot_translations_and_originals(self, fig, dataset): color='k', marker='.' ) - plt.legend() + plt.gca().set_aspect('equal') + plt.legend(loc='upper right') + plot_panel_list = [ + { + 'title': 'Main Results', + 'plot_level': 1, + 'grid': (2,2), + 'figure_size': (8.4,6.8), + 'plots': [ + { + 'title': 'Object Phase', + 'subplot': (0,0), + 'plot_func': lambda self, fig: p.plot_phase( + self.obj[self.obj_view_slice], + fig=fig, + basis=self.obj_basis, + additional_axis_labels=['Mode #',], + units=self.units), + 'condition': lambda self: not self.exponentiate_obj, + }, + { + 'title': 'Object Amplitude', + 'subplot': (1,0), + 'plot_func': lambda self, fig: p.plot_amplitude( + self.obj[self.obj_view_slice], + fig=fig, + basis=self.obj_basis, + additional_axis_labels=['Mode #',], + units=self.units), + 'condition': lambda self: not self.exponentiate_obj, + }, + { + 'title': 'Real Part of T', + 'subplot': (0,0), + 'plot_func': lambda self, fig: p.plot_real( + self.obj[self.obj_view_slice], + fig=fig, + basis=self.obj_basis, + additional_axis_labels=['Mode #',], + units=self.units, + cmap='cividis', + ), + 'condition': lambda self: self.exponentiate_obj, + }, + { + 'title': 'Imaginary Part of T', + 
'subplot': (1,0), + 'plot_func': lambda self, fig: p.plot_imag( + self.obj[self.obj_view_slice], + fig=fig, + basis=self.obj_basis, + additional_axis_labels=['Mode #',], + units=self.units, + cmap='viridis_r', + ), + 'condition': lambda self: self.exponentiate_obj, + }, + { + 'title': 'Probe Modes, Colorized', + 'subplot': (0,1), + 'plot_func': lambda self, fig: p.plot_colorized( + (self.probe if not self.fourier_probe + else tools.propagators.inverse_far_field(self.probe)), + fig=fig, + title='Probe Modes, Real Space', + basis=self.probe_basis, + additional_axis_labels=['Mode #',], + amplitude_scaling=np.sqrt, + units=self.units), + }, + { + 'title': 'Probe Modes, Amplitude', + 'subplot': (1,1), + 'plot_func': lambda self, fig: p.plot_amplitude( + (self.probe if not self.fourier_probe + else tools.propagators.inverse_far_field(self.probe)), + fig=fig, + title='Probe Modes, Real Space', + basis=self.probe_basis, + additional_axis_labels=['Mode #',], + units=self.units), + }, + ], + }, + { + 'title': 'Advanced Monitoring', + 'plot_level': 2, + 'figure_size': (12.6,6.8), + 'grid': (2,3), + 'plots': [ + { + 'title': 'Probe Modes, Fourier Colorized', + 'subplot': (0,0), + 'plot_func': lambda self, fig: p.plot_colorized( + (self.probe if self.fourier_probe + else tools.propagators.far_field(self.probe)), + fig=fig, + title='Probe Modes, Fourier Space', + additional_axis_labels=['Mode #',], + amplitude_scaling = np.sqrt, + ), + }, + { + 'title': 'Probe Modes, Fourier Amplitude', + 'subplot': (1,0), + 'plot_func': lambda self, fig: p.plot_amplitude( + (self.probe if self.fourier_probe + else tools.propagators.far_field(self.probe)), + fig=fig, + title='Probe Modes, Fourier Space', + additional_axis_labels=['Mode #',], + ), + }, + { + 'title': 'Illumination Intensity', + 'subplot': (0,1), + 'plot_func': lambda self, fig, dataset: self.plot_illumination_intensity(fig, dataset), + }, + { + 'title': 'Detector Background', + 'subplot': (1,1), + 'plot_func': lambda self, fig: 
p.plot_amplitude(self.background**2, fig=fig, cmap='viridis', cmap_label='Intensity (detector units)'), + }, + { + 'title': 'Corrected Translations', + 'subplot': (0,2), + 'plot_func': lambda self, fig, dataset: self.plot_translations_and_originals(fig, dataset), + }, + { + 'title': 'Loss History', + 'subplot': (1,2), + 'plot_func': lambda self, fig: self.plot_loss_history(fig), + }, + ], + }, + { + 'title': 'Unstable Probe Refinement Details', + 'plot_level': 2, + 'figure_size': (8.4,3.4), + 'grid': (1,2), + 'condition': lambda self: len(self.weights.shape) >= 2, + 'plots': [ + { + 'title': '% of Power in Top Mode', + 'subplot': (0,0), + 'plot_func': lambda self, fig, dataset: p.plot_nanomap( + self.corrected_translations(dataset), + 100 * t.stack([ + analysis.calc_mode_power_fractions( + self.probe.data, + weight_matrix=self.weights.data[i])[0] + for i in range(self.weights.shape[0]) + ], dim=0), + fig=fig, + units=self.units), + 'condition': lambda self: len(self.weights.shape) >= 2 + }, + { + 'title': 'Mean Weight Matrix Amplitudes', + 'subplot': (0,1), + 'plot_func': lambda self, fig: p.plot_amplitude( + np.nanmean(np.abs(self.weights.data.cpu().numpy()), axis=0), + fig=fig), + 'condition': lambda self: len(self.weights.shape) >= 2 + }, + ] + } + ] + plot_list = [ - ('', - lambda self, fig, dataset: self.plot_wavefront_variation( + {'title': 'Quantum Efficiency Mask', + 'plot_level': 2, + 'plot_func': lambda self, fig: p.plot_amplitude(self.qe_mask, fig=fig), + 'condition': lambda self: (hasattr(self, 'qe_mask') and self.qe_mask is not None)}, + {'title': 'Per-Exposure Probe Intensity', + 'plot_level': 3, + 'figure_size': (8,5.3), + 'plot_func': lambda self, fig, dataset: self.plot_wavefront_variation( dataset, fig=fig, mode='root_sum_intensity', image_title='Root Summed Probe Intensities', image_colorbar_title='Square Root of Intensity'), - lambda self: len(self.weights.shape) >= 2), - ('', - lambda self, fig, dataset: self.plot_wavefront_variation( + 
'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': 'Per-Exposure Probe Amplitudes', + 'plot_level': 3, + 'figure_size': (8,5.3), + 'plot_func': lambda self, fig, dataset: self.plot_wavefront_variation( dataset, fig=fig, mode='amplitude', image_title='Probe Amplitudes (scroll to view modes)', image_colorbar_title='Probe Amplitude'), - lambda self: len(self.weights.shape) >= 2), - ('', - lambda self, fig, dataset: self.plot_wavefront_variation( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': 'Per-Exposure Probe Phases', + 'plot_level': 3, + 'figure_size': (8,5.3), + 'plot_func': lambda self, fig, dataset: self.plot_wavefront_variation( dataset, fig=fig, mode='phase', image_title='Probe Phases (scroll to view modes)', image_colorbar_title='Probe Phase'), - lambda self: len(self.weights.shape) >= 2), - ('Basis Probe Fourier Space Amplitudes', - lambda self, fig: p.plot_amplitude( - (self.probe if self.fourier_probe - else tools.propagators.far_field(self.probe)), - fig=fig)), - ('Basis Probe Fourier Space Colorized', - lambda self, fig: p.plot_colorized( - (self.probe if self.fourier_probe - else tools.propagators.far_field(self.probe)) - , fig=fig)), - ('Basis Probe Real Space Amplitudes', - lambda self, fig: p.plot_amplitude( - (self.probe if not self.fourier_probe - else tools.propagators.inverse_far_field(self.probe)), - fig=fig, - basis=self.probe_basis, - units=self.units)), - ('Basis Probe Real Space Colorized', - lambda self, fig: p.plot_colorized( - (self.probe if not self.fourier_probe - else tools.propagators.inverse_far_field(self.probe)), - fig=fig, - basis=self.probe_basis, - units=self.units)), - ('Average Weight Matrix Amplitudes', - lambda self, fig: p.plot_amplitude( - np.nanmean(np.abs(self.weights.data.cpu().numpy()), axis=0), - fig=fig), - lambda self: len(self.weights.shape) >= 2), - ('% of Power in Top Mode', - lambda self, fig, dataset: p.plot_nanomap( - self.corrected_translations(dataset), - 100 * 
t.stack([ - analysis.calc_mode_power_fractions( - self.probe.data, - weight_matrix=self.weights.data[i])[0] - for i in range(self.weights.shape[0]) - ], dim=0), - fig=fig, - units=self.units), - lambda self: len(self.weights.shape) >= 2), - ('Object Amplitude', - lambda self, fig: p.plot_amplitude( - self.obj[self.obj_view_slice], - fig=fig, - basis=self.obj_basis, - units=self.units), - lambda self: not self.exponentiate_obj), - ('Object Phase', - lambda self, fig: p.plot_phase( - self.obj[self.obj_view_slice], - fig=fig, - basis=self.obj_basis, - units=self.units), - lambda self: not self.exponentiate_obj), - ('Real Part of T', - lambda self, fig: p.plot_real( - self.obj[self.obj_view_slice], - fig=fig, - basis=self.obj_basis, - units=self.units, - cmap='cividis'), - lambda self: self.exponentiate_obj), - ('Imaginary Part of T', - lambda self, fig: p.plot_imag( - self.obj[self.obj_view_slice], - fig=fig, - basis=self.obj_basis, - units=self.units), - lambda self: self.exponentiate_obj), - - ('Corrected Translations', - lambda self, fig, dataset: self.plot_translations_and_originals(fig, dataset)), - ('Background', - lambda self, fig: p.plot_amplitude(self.background**2, fig=fig)), - ('Quantum Efficiency Mask', - lambda self, fig: p.plot_amplitude(self.qe_mask, fig=fig), - lambda self: (hasattr(self, 'qe_mask') and self.qe_mask is not None)) + 'condition': lambda self: len(self.weights.shape) >= 2}, ] diff --git a/src/cdtools/models/multislice_2d_ptycho.py b/src/cdtools/models/multislice_2d_ptycho.py index 1d6b3cd2..d6663be0 100644 --- a/src/cdtools/models/multislice_2d_ptycho.py +++ b/src/cdtools/models/multislice_2d_ptycho.py @@ -47,9 +47,12 @@ def __init__(self, prevent_aliasing=True, phase_only=False, units='um', + panel_plot_mode=False, + plot_level=1, ): - super(Multislice2DPtycho, self).__init__() + super(Multislice2DPtycho, self).__init__(panel_plot_mode=panel_plot_mode, + plot_level=plot_level) self.wavelength = t.tensor(wavelength) self.detector_geometry 
= copy(detector_geometry) self.dz = dz @@ -154,7 +157,7 @@ def __init__(self, @classmethod - def from_dataset(cls, dataset, dz, nz, probe_convergence_semiangle, padding=0, n_modes=1, dm_rank=None, translation_scale=1, saturation=None, propagation_distance=None, scattering_mode=None, oversampling=1, auto_center=True, bandlimit=None, replicate_slice=False, subpixel=True, exponentiate_obj=True, units='um', fourier_probe=False, phase_only=False, prevent_aliasing=True, probe_support_radius=None): + def from_dataset(cls, dataset, dz, nz, probe_convergence_semiangle, padding=0, n_modes=1, dm_rank=None, translation_scale=1, saturation=None, propagation_distance=None, scattering_mode=None, oversampling=1, auto_center=True, bandlimit=None, replicate_slice=False, subpixel=True, exponentiate_obj=True, units='um', fourier_probe=False, phase_only=False, prevent_aliasing=True, probe_support_radius=None, panel_plot_mode=False, plot_level=1): wavelength = dataset.wavelength det_basis = dataset.detector_geometry['basis'] @@ -286,7 +289,7 @@ def from_dataset(cls, dataset, dz, nz, probe_convergence_semiangle, padding=0, n surface_normal=surface_normal, probe_support=probe_support, min_translation=min_translation, - translation_offsets = translation_offsets, + translation_offsets=translation_offsets, weights=Ws, mask=mask, background=background, translation_scale=translation_scale, saturation=saturation, @@ -296,7 +299,9 @@ def from_dataset(cls, dataset, dz, nz, probe_convergence_semiangle, padding=0, n exponentiate_obj=exponentiate_obj, units=units, fourier_probe=fourier_probe, phase_only=phase_only, - prevent_aliasing=prevent_aliasing) + prevent_aliasing=prevent_aliasing, + panel_plot_mode=panel_plot_mode, + plot_level=plot_level) def interaction(self, index, translations): @@ -581,22 +586,22 @@ def tidy_probes(self): # Needs to be updated to allow for plotting to an existing figure - plot_list = [ - ('Probe Fourier Space Amplitude', - lambda self, fig: p.plot_amplitude(self.probe if 
self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig)), - ('Probe Fourier Space Phase', - lambda self, fig: p.plot_phase(self.probe if self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig)), - ('Probe Real Space Amplitude', - lambda self, fig: p.plot_amplitude(self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig, basis=self.probe_basis, units=self.units)), - ('Probe Real Space Phase', - lambda self, fig: p.plot_phase(self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig, basis=self.probe_basis, units=self.units)), - ('Average Weight Matrix Amplitudes', - lambda self, fig: p.plot_amplitude( + plot_list = [ + {'title': 'Probe Fourier Space Amplitude', + 'plot_func': lambda self, fig: p.plot_amplitude(self.probe if self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig)}, + {'title': 'Probe Fourier Space Phase', + 'plot_func': lambda self, fig: p.plot_phase(self.probe if self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig)}, + {'title': 'Probe Real Space Amplitude', + 'plot_func': lambda self, fig: p.plot_amplitude(self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig, basis=self.probe_basis, units=self.units)}, + {'title': 'Probe Real Space Phase', + 'plot_func': lambda self, fig: p.plot_phase(self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig, basis=self.probe_basis, units=self.units)}, + {'title': 'Average Weight Matrix Amplitudes', + 'plot_func': lambda self, fig: p.plot_amplitude( np.nanmean(np.abs(self.weights.data.cpu().numpy()), axis=0), fig=fig), - lambda self: len(self.weights.shape) >= 2), - ('% of Power in Top Mode', - lambda self, fig, dataset: p.plot_nanomap( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': '% of Power in Top Mode', + 
'plot_func': lambda self, fig, dataset: p.plot_nanomap( self.corrected_translations(dataset), 100 * t.stack([ analysis.calc_mode_power_fractions( @@ -606,35 +611,35 @@ def tidy_probes(self): ], dim=0), fig=fig, units=self.units), - lambda self: len(self.weights.shape) >= 2), - ('Slice by Slice Real Part of T', - lambda self, fig: p.plot_real(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units, cmap='cividis'), - lambda self: self.exponentiate_obj), - ('Slice by Slice Imaginary Part of T', - lambda self, fig: p.plot_imag(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units), - lambda self: self.exponentiate_obj), - ('Integrated Real Part of T', - lambda self, fig: p.plot_real(t.sum(self.obj.detach().cpu(),dim=0), fig=fig, basis=self.probe_basis, units=self.units, cmap='cividis'), - lambda self: (self.exponentiate_obj) and self.obj.dim() >= 3), - ('Integrated Imaginary Part of T', - lambda self, fig: p.plot_imag(t.sum(self.obj.detach().cpu(),dim=0), fig=fig, basis=self.probe_basis, units=self.units), - lambda self: (self.exponentiate_obj) and self.obj.dim() >= 3), - ('Slice by Slice Amplitude of Object Function', - lambda self, fig: p.plot_amplitude(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units), - lambda self: not self.exponentiate_obj), - ('Slice by Slice Phase of Object Function', - lambda self, fig: p.plot_phase(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units,cmap='cividis'), - lambda self: not self.exponentiate_obj), - ('Amplitude of Stacked Object Function', - lambda self, fig: p.plot_amplitude(reduce(t.mul, self.obj.detach().cpu()), fig=fig, basis=self.probe_basis, units=self.units), - lambda self: (not self.exponentiate_obj) and self.obj.dim() >=3), - ('Phase of Stacked Object Function', - lambda self, fig: p.plot_phase(reduce(t.mul, self.obj.detach().cpu()), fig=fig, basis=self.probe_basis, units=self.units, cmap='cividis'), - lambda self: (not 
self.exponentiate_obj) and self.obj.dim() >= 3), - ('Corrected Translations', - lambda self, fig, dataset: p.plot_translations(self.corrected_translations(dataset), fig=fig, units=self.units)), - ('Background', - lambda self, fig: plt.figure(fig.number) and plt.imshow(self.background.detach().cpu().numpy()**2)) + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': 'Slice by Slice Real Part of T', + 'plot_func': lambda self, fig: p.plot_real(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units, cmap='cividis'), + 'condition': lambda self: self.exponentiate_obj}, + {'title': 'Slice by Slice Imaginary Part of T', + 'plot_func': lambda self, fig: p.plot_imag(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units), + 'condition': lambda self: self.exponentiate_obj}, + {'title': 'Integrated Real Part of T', + 'plot_func': lambda self, fig: p.plot_real(t.sum(self.obj.detach().cpu(),dim=0), fig=fig, basis=self.probe_basis, units=self.units, cmap='cividis'), + 'condition': lambda self: self.exponentiate_obj and self.obj.dim() >= 3}, + {'title': 'Integrated Imaginary Part of T', + 'plot_func': lambda self, fig: p.plot_imag(t.sum(self.obj.detach().cpu(),dim=0), fig=fig, basis=self.probe_basis, units=self.units), + 'condition': lambda self: self.exponentiate_obj and self.obj.dim() >= 3}, + {'title': 'Slice by Slice Amplitude of Object Function', + 'plot_func': lambda self, fig: p.plot_amplitude(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units), + 'condition': lambda self: not self.exponentiate_obj}, + {'title': 'Slice by Slice Phase of Object Function', + 'plot_func': lambda self, fig: p.plot_phase(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units, cmap='cividis'), + 'condition': lambda self: not self.exponentiate_obj}, + {'title': 'Amplitude of Stacked Object Function', + 'plot_func': lambda self, fig: p.plot_amplitude(reduce(t.mul, self.obj.detach().cpu()), fig=fig, 
basis=self.probe_basis, units=self.units), + 'condition': lambda self: (not self.exponentiate_obj) and self.obj.dim() >= 3}, + {'title': 'Phase of Stacked Object Function', + 'plot_func': lambda self, fig: p.plot_phase(reduce(t.mul, self.obj.detach().cpu()), fig=fig, basis=self.probe_basis, units=self.units, cmap='cividis'), + 'condition': lambda self: (not self.exponentiate_obj) and self.obj.dim() >= 3}, + {'title': 'Corrected Translations', + 'plot_func': lambda self, fig, dataset: p.plot_translations(self.corrected_translations(dataset), fig=fig, units=self.units)}, + {'title': 'Background', + 'plot_func': lambda self, fig: plt.figure(fig.number) and plt.imshow(self.background.detach().cpu().numpy()**2)}, ] diff --git a/src/cdtools/models/multislice_ptycho.py b/src/cdtools/models/multislice_ptycho.py index afcd83e2..cec773cc 100644 --- a/src/cdtools/models/multislice_ptycho.py +++ b/src/cdtools/models/multislice_ptycho.py @@ -39,10 +39,13 @@ def __init__(self, simulate_finite_pixels=False, dtype=t.float32, exponentiate_obj=False, - obj_view_crop=0 + obj_view_crop=0, + panel_plot_mode=False, + plot_level=1, ): - super(MultislicePtycho, self).__init__() + super(MultislicePtycho, self).__init__(panel_plot_mode=panel_plot_mode, + plot_level=plot_level) self.register_buffer('wavelength', t.as_tensor(wavelength, dtype=dtype)) self.store_detector_geometry(detector_geometry, @@ -202,6 +205,8 @@ def from_dataset(cls, obj_view_crop=None, obj_padding=200, exponentiate_obj=False, + panel_plot_mode=False, + plot_level=1, ): wavelength = dataset.wavelength @@ -409,6 +414,8 @@ def from_dataset(cls, simulate_finite_pixels=simulate_finite_pixels, exponentiate_obj=exponentiate_obj, obj_view_crop=obj_view_crop, + panel_plot_mode=panel_plot_mode, + plot_level=plot_level, ) @@ -750,61 +757,61 @@ def get_probes(idx): plot_list = [ - ('', - lambda self, fig, dataset: self.plot_wavefront_variation( + {'title': '', + 'plot_func': lambda self, fig, dataset: self.plot_wavefront_variation( 
dataset, fig=fig, mode='root_sum_intensity', image_title='Root Summed Probe Intensities', image_colorbar_title='Square Root of Intensity'), - lambda self: len(self.weights.shape) >= 2), - ('', - lambda self, fig, dataset: self.plot_wavefront_variation( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': '', + 'plot_func': lambda self, fig, dataset: self.plot_wavefront_variation( dataset, fig=fig, mode='amplitude', image_title='Probe Amplitudes (scroll to view modes)', image_colorbar_title='Probe Amplitude'), - lambda self: len(self.weights.shape) >= 2), - ('', - lambda self, fig, dataset: self.plot_wavefront_variation( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': '', + 'plot_func': lambda self, fig, dataset: self.plot_wavefront_variation( dataset, fig=fig, mode='phase', image_title='Probe Phases (scroll to view modes)', image_colorbar_title='Probe Phase'), - lambda self: len(self.weights.shape) >= 2), - ('Basis Probe Fourier Space Amplitudes', - lambda self, fig: p.plot_amplitude( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': 'Basis Probe Fourier Space Amplitudes', + 'plot_func': lambda self, fig: p.plot_amplitude( (self.probe if self.fourier_probe else tools.propagators.inverse_far_field(self.probe)), - fig=fig)), - ('Basis Probe Fourier Space Phases', - lambda self, fig: p.plot_phase( + fig=fig)}, + {'title': 'Basis Probe Fourier Space Phases', + 'plot_func': lambda self, fig: p.plot_phase( (self.probe if self.fourier_probe - else tools.propagators.inverse_far_field(self.probe)) - , fig=fig)), - ('Basis Probe Real Space Amplitudes', - lambda self, fig: p.plot_amplitude( + else tools.propagators.inverse_far_field(self.probe)), + fig=fig)}, + {'title': 'Basis Probe Real Space Amplitudes', + 'plot_func': lambda self, fig: p.plot_amplitude( (self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe)), fig=fig, basis=self.probe_basis, - units=self.units)), - ('Basis 
Probe Real Space Phases', - lambda self, fig: p.plot_phase( + units=self.units)}, + {'title': 'Basis Probe Real Space Phases', + 'plot_func': lambda self, fig: p.plot_phase( (self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe)), fig=fig, basis=self.probe_basis, - units=self.units)), - ('Average Weight Matrix Amplitudes', - lambda self, fig: p.plot_amplitude( + units=self.units)}, + {'title': 'Average Weight Matrix Amplitudes', + 'plot_func': lambda self, fig: p.plot_amplitude( np.nanmean(np.abs(self.weights.data.cpu().numpy()), axis=0), fig=fig), - lambda self: len(self.weights.shape) >= 2), - ('% of Power in Top Mode', - lambda self, fig, dataset: p.plot_nanomap( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': '% of Power in Top Mode', + 'plot_func': lambda self, fig, dataset: p.plot_nanomap( self.corrected_translations(dataset), 100 * t.stack([ analysis.calc_mode_power_fractions( @@ -814,69 +821,69 @@ def get_probes(idx): ], dim=0), fig=fig, units=self.units), - lambda self: len(self.weights.shape) >= 2), - ('Object Amplitude', - lambda self, fig: p.plot_amplitude( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': 'Object Amplitude', + 'plot_func': lambda self, fig: p.plot_amplitude( self.obj[(np.s_[:],) + self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units), - lambda self: not self.exponentiate_obj), - ('Object (T) Imaginary Part', - lambda self, fig: p.plot_imag( + 'condition': lambda self: not self.exponentiate_obj}, + {'title': 'Object (T) Imaginary Part', + 'plot_func': lambda self, fig: p.plot_imag( self.obj[(np.s_[:],) + self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units), - lambda self: self.exponentiate_obj), - ('Object Phase', - lambda self, fig: p.plot_phase( + 'condition': lambda self: self.exponentiate_obj}, + {'title': 'Object Phase', + 'plot_func': lambda self, fig: p.plot_phase( self.obj[(np.s_[:],) + self.obj_view_slice], 
fig=fig, basis=self.obj_basis, units=self.units), - lambda self: not self.exponentiate_obj), - ('Object (T) Real Part', - lambda self, fig: p.plot_real( + 'condition': lambda self: not self.exponentiate_obj}, + {'title': 'Object (T) Real Part', + 'plot_func': lambda self, fig: p.plot_real( self.obj[(np.s_[:],) + self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units, cmap='cividis'), - lambda self: self.exponentiate_obj), - ('Object Product Amplitude', - lambda self, fig: p.plot_amplitude( + 'condition': lambda self: self.exponentiate_obj}, + {'title': 'Object Product Amplitude', + 'plot_func': lambda self, fig: p.plot_amplitude( t.prod(self.obj, dim=0)[self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units), - lambda self: not self.exponentiate_obj), - ('Object (T) Sum Imaginary Part', - lambda self, fig: p.plot_imag( + 'condition': lambda self: not self.exponentiate_obj}, + {'title': 'Object (T) Sum Imaginary Part', + 'plot_func': lambda self, fig: p.plot_imag( t.sum(self.obj, dim=0)[self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units), - lambda self: self.exponentiate_obj), - ('Object Product Phase', - lambda self, fig: p.plot_phase( + 'condition': lambda self: self.exponentiate_obj}, + {'title': 'Object Product Phase', + 'plot_func': lambda self, fig: p.plot_phase( t.prod(self.obj, dim=0)[self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units), - lambda self: not self.exponentiate_obj), - ('Object (T) Sum Real Part', - lambda self, fig: p.plot_real( + 'condition': lambda self: not self.exponentiate_obj}, + {'title': 'Object (T) Sum Real Part', + 'plot_func': lambda self, fig: p.plot_real( t.sum(self.obj, dim=0)[self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units, cmap='cividis'), - lambda self: self.exponentiate_obj), - ('Corrected Translations', - lambda self, fig, dataset: p.plot_translations(self.corrected_translations(dataset), fig=fig, units=self.units)), - ('Background', - 
lambda self, fig: p.plot_amplitude(self.background**2, fig=fig)) + 'condition': lambda self: self.exponentiate_obj}, + {'title': 'Corrected Translations', + 'plot_func': lambda self, fig, dataset: p.plot_translations(self.corrected_translations(dataset), fig=fig, units=self.units)}, + {'title': 'Background', + 'plot_func': lambda self, fig: p.plot_amplitude(self.background**2, fig=fig)}, ] diff --git a/src/cdtools/models/rpi.py b/src/cdtools/models/rpi.py index 339b6236..0cf0b688 100644 --- a/src/cdtools/models/rpi.py +++ b/src/cdtools/models/rpi.py @@ -58,9 +58,12 @@ def __init__( propagation_distance=0, units='um', dtype=t.float32, + panel_plot_mode=True, + plot_level=1, ): - - super(RPI, self).__init__() + + super(RPI, self).__init__(panel_plot_mode=panel_plot_mode, + plot_level=plot_level) complex_dtype = (t.ones([1], dtype=dtype) + 1j * t.ones([1], dtype=dtype)).dtype @@ -164,6 +167,8 @@ def from_dataset( phase_only=False, probe_threshold=0, dtype=t.float32, + panel_plot_mode=True, + plot_level=1, ): complex_dtype = (t.ones([1], dtype=dtype) + 1j * t.ones([1], dtype=dtype)).dtype @@ -227,7 +232,7 @@ def from_dataset( # This will be superceded later by a call to init_obj, but it sets # the shape if obj_size is None: - obj_size = (np.array(self.probe.shape[-2:]) // 2).astype(int) + obj_size = (np.array(probe.shape[-2:]) // 2).astype(int) dummy_init_obj = t.ones([n_modes, obj_size[0], obj_size[1]], dtype=complex_dtype) @@ -247,14 +252,16 @@ def from_dataset( obj_support = t.as_tensor(binary_dilation(obj_support)) rpi_object = cls(wavelength, det_geo, ew_basis, - probe, dummy_init_obj, + probe, dummy_init_obj, background=background, mask=mask, saturation=saturation, obj_support=obj_support, oversampling=oversampling, exponentiate_obj=exponentiate_obj, phase_only=phase_only, - weight_matrix=weight_matrix) + weight_matrix=weight_matrix, + panel_plot_mode=panel_plot_mode, + plot_level=plot_level) # I don't love this pattern, where I do the "real" obj initialization # 
after creating the rpi object. But, I chose this so that I could @@ -283,7 +290,9 @@ def from_calibration( exponentiate_obj=False, phase_only=False, initialization='random', - dtype=t.float32 + dtype=t.float32, + panel_plot_mode=True, + plot_level=1, ): complex_dtype = (t.ones([1], dtype=dtype) + @@ -308,7 +317,7 @@ def from_calibration( # This will be superceded later by a call to init_obj, but it sets # the shape if obj_size is None: - obj_size = (np.array(self.probe.shape[-2:]) // 2).astype(int) + obj_size = (np.array(probe.shape[-2:]) // 2).astype(int) dummy_init_obj = t.ones([n_modes, obj_size[0], obj_size[1]], dtype=complex_dtype) @@ -327,6 +336,8 @@ def from_calibration( mask=mask, exponentiate_obj=exponentiate_obj, phase_only=phase_only, + panel_plot_mode=panel_plot_mode, + plot_level=plot_level, ) rpi_object.init_obj(initialization) @@ -365,7 +376,7 @@ def init_obj( obj_shape=obj_shape, n_modes=n_modes) else: - raise KeyError('Initialization "' + str(initialization) + \ + raise KeyError('Initialization "' + str(initialization_type) + \ '" invalid - use "spectral", "uniform", or "random"') @@ -531,40 +542,72 @@ def regularizer(self, factors): def sim_to_dataset(self, args_list): raise NotImplementedError('No sim to dataset yet, sorry!') - plot_list = [ - ('Root Sum Squared Amplitude of all Probes', - lambda self, fig: p.plot_amplitude( - np.sqrt(np.sum((t.abs(t.sum(self.weights[..., None, None].detach() * self.probe, axis=-3))**2).cpu().numpy(),axis=0)), - fig=fig, basis=self.probe_basis)), - ('Object Amplitude', - lambda self, fig: p.plot_amplitude( - self.obj, - fig=fig, - basis=self.obj_basis, - units=self.units), - lambda self: not self.exponentiate_obj), - ('Object Phase', - lambda self, fig: p.plot_phase( - self.obj, - fig=fig, - basis=self.obj_basis, - units=self.units), - lambda self: not self.exponentiate_obj), - ('Real Part of T', - lambda self, fig: p.plot_real( - self.obj, - fig=fig, - basis=self.obj_basis, - units=self.units, - cmap='cividis'), 
- lambda self: self.exponentiate_obj), - ('Imaginary Part of T', - lambda self, fig: p.plot_imag( - self.obj, - fig=fig, - basis=self.obj_basis, - units=self.units), - lambda self: self.exponentiate_obj), + plot_panel_list = [ + { + 'title': 'RPI Results', + 'plot_level': 1, + 'figure_size': (12, 3.5), + 'grid': (1, 3), + 'plots': [ + { + 'title': 'Object Phase', + 'subplot': (0, 0), + 'plot_func': lambda self, fig: p.plot_phase( + self.obj, + fig=fig, + title='Object Phase', + basis=self.obj_basis, + units=self.units), + 'condition': lambda self: not self.exponentiate_obj, + }, + { + 'title': 'Real Part of T', + 'subplot': (0, 0), + 'plot_func': lambda self, fig: p.plot_real( + self.obj, + fig=fig, + title='Real Part of T', + basis=self.obj_basis, + units=self.units, + cmap='cividis'), + 'condition': lambda self: self.exponentiate_obj, + }, + { + 'title': 'Object Amplitude', + 'subplot': (0, 1), + 'plot_func': lambda self, fig: p.plot_amplitude( + self.obj, + fig=fig, + title='Object Amplitude', + basis=self.obj_basis, + units=self.units), + 'condition': lambda self: not self.exponentiate_obj, + }, + { + 'title': 'Imaginary Part of T', + 'subplot': (0, 1), + 'plot_func': lambda self, fig: p.plot_imag( + self.obj, + fig=fig, + title='Imaginary Part of T', + basis=self.obj_basis, + units=self.units), + 'condition': lambda self: self.exponentiate_obj, + }, + { + 'title': 'Root Sum Squared Amplitude of all Probes', + 'subplot': (0, 2), + 'plot_func': lambda self, fig: p.plot_amplitude( + np.sqrt(np.sum( + (t.abs(t.sum(self.weights[..., None, None].detach() + * self.probe, axis=-3))**2 + ).cpu().numpy(), axis=0)), + fig=fig, + basis=self.probe_basis, + units=self.units), + }, + ], + }, ] diff --git a/src/cdtools/models/simple_ptycho.py b/src/cdtools/models/simple_ptycho.py index 5435c662..734ba20c 100644 --- a/src/cdtools/models/simple_ptycho.py +++ b/src/cdtools/models/simple_ptycho.py @@ -76,7 +76,7 @@ def from_dataset(cls, dataset): probe_basis, probe, obj, - 
min_translation=min_translation + min_translation=min_translation, ) @@ -108,14 +108,26 @@ def loss(self, real_data, sim_data): # This lists all the plots to display on a call to model.inspect() plot_list = [ - ('Probe Amplitude', - lambda self, fig: p.plot_amplitude(self.probe, fig=fig, basis=self.probe_basis)), - ('Probe Phase', - lambda self, fig: p.plot_phase(self.probe, fig=fig, basis=self.probe_basis)), - ('Object Amplitude', - lambda self, fig: p.plot_amplitude(self.obj, fig=fig, basis=self.probe_basis)), - ('Object Phase', - lambda self, fig: p.plot_phase(self.obj, fig=fig, basis=self.probe_basis)) + { + 'title': 'Probe Amplitude', + 'plot_func': lambda self, fig: + p.plot_amplitude(self.probe, fig, basis=self.probe_basis), + }, + { + 'title': 'Probe Phase', + 'plot_func': lambda self, fig: + p.plot_phase(self.probe, fig, basis=self.probe_basis) + }, + { + 'title': 'Object Amplitude', + 'plot_func': lambda self, fig: + p.plot_amplitude(self.obj, fig, basis=self.probe_basis) + }, + { + 'title': 'Object Phase', + 'plot_func': lambda self, fig: + p.plot_phase(self.obj, fig, basis=self.probe_basis) + }, ] def save_results(self, dataset): diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py index 16668149..4ad7931c 100644 --- a/src/cdtools/reconstructors/base.py +++ b/src/cdtools/reconstructors/base.py @@ -16,6 +16,7 @@ import threading import queue import time +from matplotlib import pyplot as plt from typing import List, Union if TYPE_CHECKING: @@ -232,7 +233,7 @@ def closure(): def optimize(self, iterations: int, batch_size: int = 1, - custom_data_loader: torch.utils.data.DataLoader = None, + custom_data_loader: t.utils.data.DataLoader = None, regularization_factor: Union[float, List[float]] = None, thread: bool = True, calculation_width: int = 10, @@ -357,10 +358,12 @@ def target(): try: calc.start() while calc.is_alive(): - if hasattr(self.model, 'figs'): - self.model.figs[0].canvas.start_event_loop(0.01) - else: - 
calc.join() + open_figs = plt.get_fignums() + with plt.rc_context({'figure.raise_window': False}): + for fignum in open_figs: + plt.figure(fignum).canvas.flush_events() + + time.sleep(0.01) except KeyboardInterrupt as e: stop_event.set() diff --git a/src/cdtools/tools/plotting/plotting.py b/src/cdtools/tools/plotting/plotting.py index 0f570ea1..a4847b79 100644 --- a/src/cdtools/tools/plotting/plotting.py +++ b/src/cdtools/tools/plotting/plotting.py @@ -31,7 +31,11 @@ ] -def colorize(z): +def colorize( + z, + use_cmocean=True, + amplitude_scaling=lambda x: x, +): """ Returns RGB values for a complex color plot given a complex array This function returns a set of RGB values that can be used directly in a call to imshow based on an input complex numpy array (not a @@ -41,6 +45,11 @@ def colorize(z): ---------- z : array A complex-valued array + use_cmocean : bool + If true, uses the cmocean_phase colormap instead of hue + amplitude_scaling : callable + A function applied to the normalized amplitude before colorizing. + Default is the identity (no scaling). 
Returns ------- rgb : list(array) @@ -48,17 +57,28 @@ def colorize(z): """ amp = np.abs(z) - rmin = 0 - rmax = np.max(amp) - amp = np.where(amp < rmin, rmin, amp) - amp = np.where(amp > rmax, rmax, amp) - ph = np.angle(z, deg=1) + 90 - # HSV are values in range [0,1] - h = (ph % 360) / 360 - s = 0.85 * np.ones_like(h) - v = (amp - rmin) / (rmax - rmin) - - return hsv_to_rgb(np.dstack((h,s,v))) + scaled_amp = amplitude_scaling(amp / np.max(amp)) + ph = np.angle(z, deg=1) + + # The offsets for both options are chosen to match, and to place red + # at the cut (at -pi and pi) + if use_cmocean: + base_rgb_values = [] + for channel in range(3): + # Flipping the cm_data matches the usual + # order of colors from hsv + base_rgb_values.append(np.interp((ph + 100)%360, + np.linspace(0, 360, cm_data.shape [0]), + cm_data[::-1,channel])) + base_rgb_values = np.dstack(base_rgb_values) + rgb_values = base_rgb_values * scaled_amp[...,None] + return rgb_values + else: + # HSV are values in range [0,1] + h = ((ph + 170) % 360) / 360 + s = 0.85 * np.ones_like(h) + v = scaled_amp + return hsv_to_rgb(np.dstack((h,s,v))) def get_units_factor(units): @@ -106,7 +126,10 @@ def plot_image( vmin=None, vmax=None, interpolation=None, - **kwargs + title=None, + additional_axis_labels=None, + updateable_colorbar=True, + **kwargs, ): """Plots an image with a colorbar and on an appropriate spatial grid @@ -121,6 +144,10 @@ def plot_image( will be called on each slice of data before it is plotted. This is used internally to enable the plot_real, plot_image, plot_phase, etc. functions. + If the image has more than 2 dimensions, a horizontal slider is created + for each extra axis with length > 1. Up/down arrow keys navigate through + all extra axes odometer-style (last axis changes fastest). 
+ Parameters ---------- @@ -146,8 +173,13 @@ def plot_image( Default is max(plot_func(im)), the maximum value for the colormap interpolation : str What interpolation to use for imshow + additional_axis_labels : list of str, optional + Labels for each extra axis (all dimensions except the last two). + If shorter than the number of extra axes, remaining labels default to + "Axis N". If not set, all labels default to "Axis N". \\**kwargs - All other args are passed to fig.add_subplot(111, \\**kwargs) + All other args are passed to plt.figure(\\**kwargs), if no figure is + provided Returns ------- @@ -157,53 +189,140 @@ def plot_image( # convert to numpy if isinstance(im, t.Tensor): - # If final dimension is 2, assume it is a complex array. If not, - # assume it represents a real array - if im.shape[-1] == 2: - im = im.detach().cpu().numpy() - else: - im = im.detach().cpu().numpy() + im = im.detach().cpu().numpy() if fig is None: - fig = plt.figure() - ax = fig.add_subplot(111, **kwargs) - - # This nukes everything and updates either the appropriate image from the - # stack of images, or the only image if only a single image has been - # given - def make_plot(idx): - plt.figure(fig.number) - title = plt.gca().get_title() - fig.clear() - - - # If im only has two dimensions, this reshape will add a leading - # dimension, and update will be called on index 0. If it has 3 or more - # dimensions, then all the leading dimensions will be compressed into - # one long dimension which can be scrolled through. - s = im.shape - reshaped_im = im.reshape(-1,s[-2],s[-1]) - num_images = reshaped_im.shape[0] - fig.plot_idx = idx % num_images + fig = plt.figure(**kwargs) + + # Determine extra (non-image) dimensions and build per-axis slider map + extra_dims = im.shape[:-2] + n_extra = len(extra_dims) + + # I have it say e.g. 
"0th Axis" instead of "Axis 0", because the latter one + # looks kind of confusing based on the layout that a Slider widget gets + def ordinal(n): + suffix = {1: 'st', 2: 'nd', 3: 'rd'} + def get_suffix(n): + if n % 100 not in (11, 12, 13): + return suffix.get(n % 10, 'th') + else: + return 'th' + return f"{n}{get_suffix(n)}" + + if additional_axis_labels is None: + additional_axis_labels = [f'{ordinal(i)} Axis' for i in range(n_extra)] + else: + additional_axis_labels = list(additional_axis_labels) + [ + f'{ordinal(i)} Axis' + for i in range(len(additional_axis_labels), n_extra) + ] + + # Only axes with length > 1 get sliders + slider_axis_map = [ + (i, extra_dims[i], additional_axis_labels[i]) + for i in range(n_extra) if extra_dims[i] > 1 + ] + n_sliders = len(slider_axis_map) + + def make_plot(idx_list): + # Always update fig._make_plot so slider callbacks get the latest closure + fig._make_plot = make_plot + fig.plot_idx = list(idx_list) + + selected_im = im[tuple(fig.plot_idx)] if n_extra > 0 else im + to_plot = plot_func(selected_im) + + # By only updating the data, and not redrawing the fig, we + # don't "reset" the home positions of the toolbar + if hasattr(fig, '_current_im'): + fig._current_im.set_data(to_plot) + if updateable_colorbar: + fig._current_im.autoscale() + # We need to go to the "home" position before updating it + # to include the new data, because otherwise it will store + # other axes (potentially zoomed in) positions as "home", + # which is super annoying, more so than the reset. 
+ if fig.canvas.toolbar is not None: + fig.canvas.toolbar.home() + fig.canvas.toolbar.update() + + # Sync sliders to the new index without triggering callbacks + if hasattr(fig, '_sliders'): + fig._updating = True + for j, (axis_idx, _, _) in enumerate(slider_axis_map): + fig._sliders[j].set_val(fig.plot_idx[axis_idx]) + fig._updating = False + + # Restore image axis as the "current" axis + if hasattr(fig, '_plot_ax'): + plt.sca(fig._current_im.ax) + + return fig - to_plot = plot_func(reshaped_im[fig.plot_idx]) + if title is not None: + ax_title = title + else: + try: + ax_title = fig.axes[0].get_title() + except IndexError: + ax_title = '' + + try: + total_width, total_height = fig.get_size_inches() + except AttributeError: + # Only support one layer of nested subfigures + main_fig = fig.figure # get enclosing figure + fig_w, fig_h = main_fig.get_size_inches() + total_width = fig.bbox.width * fig_w / main_fig.bbox.width + total_height = fig.bbox.height * fig_h / main_fig.bbox.height + except AttributeError: + # Fall back to default figsize + total_width, total_height = (6.4, 4.8) + + pad_left = 0.6 / total_height + # De-adjusts for an ad-hoc offset introduced by matplotlib + pad_right = 0.6 / total_width - 0.05 + + pad_bottom = 0.5 / total_height + pad_top = 0.4 / total_height + + im_ax_bottom = pad_bottom + n_sliders * 0.05 + 0.025 + im_ax_height = 1 - pad_top - im_ax_bottom + ax = fig.add_axes( + [pad_left, im_ax_bottom, 1-pad_left-pad_right, im_ax_height] + ) - mpl_im = plt.imshow( + pad_left_slider = 1.2 / total_width + pad_right_slider = 0.8 / total_width + + if n_sliders > 0: + ax_sliders = [] + for n in range(n_sliders): + ax_sliders.append( + fig.add_axes([pad_left_slider, 0.025 + n * 0.05, + 1-pad_left_slider-pad_right_slider, 0.05])) + + ax_sliders = ax_sliders[::-1] + + mpl_im = ax.imshow( to_plot, - cmap = cmap, - interpolation = interpolation, + cmap=cmap, + interpolation=interpolation, vmin=vmin, vmax=vmax, ) - plt.gca().set_facecolor('k') - + 
fig._current_im = mpl_im + ax.set_facecolor('k') + + # Lots of logic to deal with all sorts of wild non-orthogonal + # image basis options. Leave this alone. if basis is not None: # we've closed over basis, so we can't edit it - if isinstance(basis,t.Tensor): + if isinstance(basis, t.Tensor): np_basis = basis.detach().cpu().numpy() else: np_basis = basis - + np_basis = np_basis * get_units_factor(units) if isinstance(view_basis, str) and view_basis.lower() == 'ortho': @@ -213,103 +332,166 @@ def make_plot(idx): # y-axis lies in the x-y plane of the basis, perpendicular # to the x-axis - basis_norm = np.linalg.norm(np_basis, axis = 0) - + basis_norm = np.linalg.norm(np_basis, axis=0) + normed_basis = np_basis / basis_norm - normed_z = np.cross(normed_basis[:,1], normed_basis[:,0]) + normed_z = np.cross(normed_basis[:, 1], normed_basis[:, 0]) normed_z /= np.linalg.norm(normed_z) - normed_yprime = np.cross(normed_z, normed_basis[:,1]) + normed_yprime = np.cross(normed_z, normed_basis[:, 1]) normed_yprime /= np.linalg.norm(normed_yprime) - + np_view_basis = np.stack( - [normed_yprime, normed_basis[:,1]], axis=1) + [normed_yprime, normed_basis[:, 1]], axis=1) else: # We've also closed over view_basis, so we can't update it - if isinstance(view_basis,t.Tensor): + if isinstance(view_basis, t.Tensor): np_view_basis = view_basis.detach().cpu().numpy() else: np_view_basis = view_basis - + # We always normalize the view basis - view_basis_norm = np.linalg.norm(np_view_basis, axis = 0) + view_basis_norm = np.linalg.norm(np_view_basis, axis=0) np_view_basis = np_view_basis / view_basis_norm # Holy cow, this works! 
transform_matrix = \ - np.linalg.lstsq(np_view_basis[:,::-1], np_basis[:,::-1], + np.linalg.lstsq(np_view_basis[:, ::-1], np_basis[:, ::-1], rcond=None)[0] - [[a,c],[b,d]] = transform_matrix + [[a, c], [b, d]] = transform_matrix - transform = mtransforms.Affine2D.from_values(a,b,c,d,0,0) + transform = mtransforms.Affine2D.from_values(a, b, c, d, 0, 0) - trans_data = transform + plt.gca().transData + trans_data = transform + ax.transData mpl_im.set_transform(trans_data) - corners = np.array([[-0.5,-0.5], - [im.shape[-1]-0.5,-0.5], - [-0.5, im.shape[-2]-0.5], - [im.shape[-1]-0.5, im.shape[-2]-0.5]]) - corners = np.matmul(transform_matrix,corners.transpose()) + corners = np.array([[-0.5, -0.5], + [im.shape[-1] - 0.5, -0.5], + [-0.5, im.shape[-2] - 0.5], + [im.shape[-1] - 0.5, im.shape[-2] - 0.5]]) + corners = np.matmul(transform_matrix, corners.transpose()) mins = np.min(corners, axis=1) maxes = np.max(corners, axis=1) - plt.gca().set_xlim([mins[0], maxes[0]]) - plt.gca().set_ylim([mins[1], maxes[1]]) - plt.gca().invert_yaxis() + ax.set_xlim([mins[0], maxes[0]]) + ax.set_ylim([mins[1], maxes[1]]) + ax.invert_yaxis() if show_cbar: - cbar = plt.colorbar() + cbar = fig.colorbar(mpl_im, ax=ax, + fraction=0.15, + pad=0.05, + location='right') + ax.set_anchor('C') + if not updateable_colorbar: + cbar.ax.set_navigate(False) if cmap_label is not None: cbar.set_label(cmap_label) - + if basis is not None: - plt.xlabel('X (' + units + ')') - plt.ylabel('Y (' + units + ')') + ax.set_xlabel('X (' + units + ')') + ax.set_ylabel('Y (' + units + ')') else: - plt.xlabel('j (pixels)') - plt.ylabel('i (pixels)') - - - plt.title(title) - - if len(im.shape) >= 3: - plt.text(0.03, 0.03, str(fig.plot_idx), fontsize=14, transform=plt.gcf().transFigure) + ax.set_xlabel('j (pixels)') + ax.set_ylabel('i (pixels)') + + if title is not None: + ax.set_title(ax_title) + + # Create sliders for axes with length > 1 + sliders = [] + for j, (axis_idx, axis_len, label) in enumerate(slider_axis_map): + s 
= Slider(ax_sliders[j], label, 0, axis_len - 1, + valstep=1, valfmt='%d', + valinit=fig.plot_idx[axis_idx]) + sliders.append(s) + fig._sliders = sliders + + # Slider callbacks guarded by the _updating flag to prevent re-entry. + # Uses fig._make_plot so further plot_image calls update the closure. + def make_slider_cb(axis_idx): + def cb(val): + if getattr(fig, '_updating', False): + return + new_idx = list(fig.plot_idx) + new_idx[axis_idx] = int(val) + fig._make_plot(new_idx) + plt.draw() + return cb + + for (axis_idx, _, _), slider in zip(slider_axis_map, sliders): + slider.on_changed(make_slider_cb(axis_idx)) + + if fig.canvas.toolbar is not None: + fig.canvas.toolbar.update() + + plt.sca(ax) + return fig - if hasattr(fig, 'plot_idx'): - result = make_plot(fig.plot_idx) + if hasattr(fig, 'plot_idx') and len(fig.plot_idx) == n_extra: + result_fig = make_plot(fig.plot_idx) else: - result = make_plot(0) - - update = make_plot - + result_fig = make_plot([0] * n_extra) def on_action(event): - if not hasattr(event, 'button'): - event.button = None + # Protection for multi-subfigure situation + if event.inaxes not in fig.axes: + return if not hasattr(event, 'key'): event.key = None + if not getattr(fig, '_sliders', []): + return + + direction = None + if event.key == 'up': + direction = -1 + elif event.key == 'down': + direction = 1 + if direction is None: + return - if event.key == 'up' or event.button == 'up': - update(fig.plot_idx - 1) - elif event.key == 'down' or event.button == 'down': - update(fig.plot_idx + 1) + # Odometer-style: last entry in slider_axis_map changes fastest + new_idx = list(fig.plot_idx) + carry = direction + for axis_idx, axis_len, _ in reversed(slider_axis_map): + new_val = new_idx[axis_idx] + carry + if new_val < 0: + new_idx[axis_idx] = axis_len - 1 + carry = -1 + elif new_val >= axis_len: + new_idx[axis_idx] = 0 + carry = 1 + else: + new_idx[axis_idx] = new_val + carry = 0 + break + + make_plot(new_idx) plt.draw() - if len(im.shape) >=3: 
- if not hasattr(fig,'my_callbacks'): + if n_sliders > 0: + if not hasattr(fig, 'my_callbacks'): fig.my_callbacks = [] - for cid in fig.my_callbacks: fig.canvas.mpl_disconnect(cid) fig.my_callbacks = [] - fig.my_callbacks.append(fig.canvas.mpl_connect('key_press_event',on_action)) - fig.my_callbacks.append(fig.canvas.mpl_connect('scroll_event',on_action)) + fig.my_callbacks.append( + fig.canvas.mpl_connect('key_press_event', on_action) + ) - return result + return result_fig -def plot_real(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_label='Real Part (a.u.)', **kwargs): +def plot_real( + im, + fig=None, + basis=None, + units='$\\mu$m', + cmap='viridis', + cmap_label='Real Part (a.u.)', + title=None, + **kwargs, +): """Plots the real part of a complex array with dimensions NxM If a figure is given explicitly, it will clear that existing figure and @@ -332,8 +514,10 @@ def plot_real(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_ Default is 'viridis', the colormap to plot with cmap_label : str What to label the colorbar when plotting + title : str, optional + Title for the axes. 
\\**kwargs - All other args are passed to fig.add_subplot(111, \\**kwargs) + All other args are passed through to plotting.plot_image Returns ------- @@ -343,11 +527,19 @@ def plot_real(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_ plot_func = lambda x: np.real(x) return plot_image(im, plot_func=plot_func, fig=fig, basis=basis, units=units, cmap=cmap, cmap_label=cmap_label, - **kwargs) + title=title, **kwargs) - -def plot_imag(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_label='Imaginary Part (a.u.)', **kwargs): +def plot_imag( + im, + fig=None, + basis=None, + units='$\\mu$m', + cmap='viridis', + cmap_label='Imaginary Part (a.u.)', + title=None, + **kwargs, +): """Plots the imaginary part of a complex array with dimensions NxM If a figure is given explicitly, it will clear that existing figure and @@ -370,8 +562,10 @@ def plot_imag(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_ Default is 'viridis', the colormap to plot with cmap_label : str What to label the colorbar when plotting + title : str, optional + Title for the axes. 
\\**kwargs - All other args are passed to fig.add_subplot(111, \\**kwargs) + All other args are passed through to plotting.plot_image Returns ------- @@ -381,10 +575,19 @@ def plot_imag(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_ plot_func = lambda x: np.imag(x) return plot_image(im, plot_func=plot_func, fig=fig, basis=basis, units=units, cmap=cmap, cmap_label=cmap_label, - **kwargs) + title=title, **kwargs) -def plot_amplitude(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_label='Amplitude (a.u.)', **kwargs): +def plot_amplitude( + im, + fig=None, + basis=None, + units='$\\mu$m', + cmap='viridis', + cmap_label='Amplitude (a.u.)', + title=None, + **kwargs, +): """Plots the amplitude of a complex array with dimensions NxM If a figure is given explicitly, it will clear that existing figure and @@ -407,8 +610,10 @@ def plot_amplitude(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', Default is 'viridis', the colormap to plot with cmap_label : str What to label the colorbar when plotting + title : str, optional + Title for the axes. \\**kwargs - All other args are passed to fig.add_subplot(111, \\**kwargs) + All other args are passed through to plotting.plot_image Returns ------- @@ -418,7 +623,7 @@ def plot_amplitude(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', plot_func = lambda x: np.absolute(x) return plot_image(im, plot_func=plot_func, fig=fig, basis=basis, units=units, cmap=cmap, cmap_label=cmap_label, - **kwargs) + title=title, **kwargs) def plot_phase( @@ -430,7 +635,8 @@ def plot_phase( cmap_label='Phase (rad)', vmin=None, vmax=None, - **kwargs + title=None, + **kwargs, ): """ Plots the phase of a complex array with dimensions NxM @@ -440,8 +646,8 @@ def plot_phase( If a basis is explicitly passed, the image will be plotted in real-space coordinates - If the cmap is entered as 'phase', it will plot the cmocean phase colormap, - and by default set the limits to [-pi,pi]. 
+ If the cmap is entered as 'phase', it will plot the cmocean phase + colormap, and by default set the limits to [-pi,pi]. Parameters ---------- @@ -463,7 +669,7 @@ def plot_phase( Default is max(angle(im)), the maximum value for the colormap \\**kwargs - All other args are passed to fig.add_subplot(111, \\**kwargs) + All other args are passed through to plotting.plot_image Returns ------- @@ -480,14 +686,20 @@ def plot_phase( return plot_image(im, plot_func=plot_func, fig=fig, basis=basis, units=units, cmap=cmap, cmap_label=cmap_label, - vmin=vmin,vmax=vmax, + vmin=vmin, vmax=vmax, title=title, **kwargs) -def plot_amplitude_surfacenorm(): - pass - -def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', **kwargs): +def plot_colorized( + im, + fig=None, + basis=None, + units='$\\mu$m', + title=None, + use_cmocean=True, + amplitude_scaling=lambda x: x, + **kwargs, +): """ Plots the colorized version of a complex array with dimensions NxM The darkness corresponds to the intensity of the image, and the color @@ -509,6 +721,13 @@ def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', **kwargs): Optional, the 3x2 probe basis units : str The length units to mark on the plot, default is um + title : str, optional + Title for the axes. + use_cmocean : bool + If true, uses the cmocean_phase colormap instead of hue + amplitude_scaling : callable + A function applied to the normalized amplitude before colorizing. + Default is the identity (no scaling). \\**kwargs All other args are passed to fig.add_subplot(111, \\**kwargs) @@ -517,12 +736,49 @@ def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', **kwargs): used_fig : matplotlib.figure.Figure The figure object that was actually plotted to. 
""" - plot_func = lambda x: colorize(x) - return plot_image(im, plot_func=plot_func, fig=fig, basis=basis, - units=units, show_cbar=False, **kwargs) + plot_func = lambda x: colorize(x, use_cmocean=use_cmocean, + amplitude_scaling=amplitude_scaling) + plot_fig = plot_image(im, plot_func=plot_func, fig=fig, basis=basis, + cmap=cmocean_phase, vmin=-np.pi, vmax=np.pi, + cmap_label='Phase (rad)', + units=units, show_cbar=True, title=title, + updateable_colorbar=False, **kwargs) + + # Find the colorbar - this is a bit hacky + cbar_ax = [ax for ax in plot_fig.get_axes() + if hasattr(ax, '_colorbar')][0] + + # --- Replace the colorbar image --- + # The internal image is a QuadMesh living on cbar.ax + #qm = cbar.ax.collections + # Build a 2D array to match the colorbar's range, here (0,1) in x + # and (pi, pi) in y + yg = np.linspace(-np.pi, np.pi, 256) # colormap values + xg = np.linspace(0, 1, 64) # second dimension + YY, XX = np.meshgrid(yg, xg, indexing='ij') + dummy_im = XX * np.exp(1j*YY) + cbar_im = plot_func(dummy_im) + for artist in list(cbar_ax.get_children()): + if 'QuadMesh' in type(artist).__name__ or \ + 'AxesImage' in type(artist).__name__: + artist.remove() + + cbar_ax.imshow(cbar_im, origin='lower', aspect='auto', + extent=[xg[0], xg[-1], -np.pi, np.pi]) -def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, invert_xaxis=True, clear_fig=True, label=None, color=None, marker='.', **kwargs): +def plot_translations( + translations, + fig=None, + units='$\\mu$m', + lines=True, + invert_xaxis=True, + clear_fig=True, + label=None, + color=None, + marker='.', + **kwargs, +): """Plots a set of probe translations in a nicely formatted way Parameters @@ -536,13 +792,15 @@ def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, inver lines : bool Whether to plot lines indicating the path taken invert_xaxis : bool - Default is True. 
This flips the x axis to match the convention from .cxi files of viewing the image from the beam's perspective + Default is True. This flips the x axis to match the convention from + .cxi files of viewing the image from the beam's perspective clear_fig : bool Default is True. Whether to clear the figure before plotting. label : str Default is None. A label to give the plotted markers for a legend. color : str - Default is None. The color to plot the markers in. By default, will follow the matplotlib color cycle. + Default is None. The color to plot the markers in. By default, will + follow the matplotlib color cycle. color : str Default is '.'. The marker style to plot with. \\**kwargs @@ -559,12 +817,39 @@ def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, inver if fig is None: fig = plt.figure() - ax = fig.add_subplot(111, **kwargs) + + if clear_fig: + fig.clear() + + if len(fig.axes) >= 1: + ax = fig.axes[0] else: - plt.figure(fig.number) - if clear_fig: - plt.gcf().clear() - + try: + total_width, total_height = fig.get_size_inches() + except AttributeError: + # Only support one layer of nested subfigures + main_fig = fig.figure # get enclosing figure + fig_w, fig_h = main_fig.get_size_inches() + total_width = fig.bbox.width * fig_w / main_fig.bbox.width + total_height = fig.bbox.height * fig_h / main_fig.bbox.height + except AttributeError: + # Fall back to default figsize + total_width, total_height = (6.4, 4.8) + + pad_left = 0.6 / total_height + # De-adjusts for an ad-hoc offset introduced by matplotlib + pad_right = 0.6 / total_width - 0.05 + + pad_bottom = 0.5 / total_height + pad_top = 0.4 / total_height + + im_ax_bottom = pad_bottom + im_ax_height = 1 - pad_top - im_ax_bottom + + ax = fig.add_axes( + [pad_left, im_ax_bottom, 1-pad_left-pad_right, im_ax_height] + ) + if isinstance(translations, t.Tensor): translations = translations.detach().cpu().numpy() @@ -572,25 +857,33 @@ def plot_translations(translations, fig=None, 
units='$\\mu$m', lines=True, inver linestyle = '-' if lines else 'None' linewidth = 1 if lines else 0 - plt.plot(translations[:,0], translations[:,1], - marker=marker, linestyle=linestyle, - label=label, color=color, - linewidth=linewidth) + ax.plot(translations[:,0], translations[:,1], + marker=marker, linestyle=linestyle, + label=label, color=color, + linewidth=linewidth) if invert_xaxis: - ax = plt.gca() x_min, x_max = ax.get_xlim() # Protect against flipping twice if plotting on top of existing graph if x_min <= x_max: ax.invert_xaxis() - plt.xlabel('X (' + units + ')') - plt.ylabel('Y (' + units + ')') + ax.set_xlabel('X (' + units + ')') + ax.set_ylabel('Y (' + units + ')') return fig -def plot_nanomap(translations, values, fig=None, units='$\\mu$m', convention='probe', invert_xaxis=True): +def plot_nanomap( + translations, + values, + fig=None, + cmap='viridis', + cmap_label=None, + units='$\\mu$m', + convention='probe', + invert_xaxis=True, +): """Plots a set of nanomap data in a flexible way Parameters @@ -601,12 +894,18 @@ def plot_nanomap(translations, values, fig=None, units='$\\mu$m', convention='pr A length-N object of values associated with the translations fig : matplotlib.figure.Figure Default is a new figure, a matplotlib figure to use to plot + cmap : str + Default is 'viridis', the colormap to plot with + cmap_label : str + Default is no label, what to label the colorbar when plotting. units : str Default is um, units to report in (assuming input in m) convention : str - Default is 'probe', alternative is 'obj'. Whether the translations refer to the probe or object. + Default is 'probe', alternative is 'obj'. Whether the translations + refer to the probe or object. invert_xaxis : bool - Default is True. This flips the x axis to match the convention from .cxi files of viewing the image from the beam's perspective + Default is True. 
This flips the x axis to match the convention from + .cxi files of viewing the image from the beam's perspective Returns ------- @@ -616,13 +915,11 @@ def plot_nanomap(translations, values, fig=None, units='$\\mu$m', convention='pr if fig is None: fig = plt.figure() - else: - plt.figure(fig.number) - plt.gcf().clear() - + + fig.clear() factor = get_units_factor(units) - bbox = fig.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) + if isinstance(translations, t.Tensor): trans = translations.detach().cpu().numpy() else: @@ -636,22 +933,74 @@ def plot_nanomap(translations, values, fig=None, units='$\\mu$m', convention='pr if convention.lower() != 'probe': trans = trans * -1 - s = bbox.width * bbox.height / trans.shape[0] * 72**2 #72 is points per inch - s /= 4 # A rough value to make the size work out - plt.scatter(factor * trans[:,0],factor * trans[:,1],s=s,c=values) + try: + total_width, total_height = fig.get_size_inches() + except AttributeError: + # Only support one layer of nested subfigures + main_fig = fig.figure # get enclosing figure + fig_w, fig_h = main_fig.get_size_inches() + total_width = fig.bbox.width * fig_w / main_fig.bbox.width + total_height = fig.bbox.height * fig_h / main_fig.bbox.height + except AttributeError: + # Fall back to default figsize + total_width, total_height = (6.4, 4.8) + + pad_left = 0.6 / total_height + # De-adjusts for an ad-hoc offset introduced by matplotlib + pad_right = 0.6 / total_width - 0.05 + + pad_bottom = 0.5 / total_height + pad_top = 0.4 / total_height + + im_ax_bottom = pad_bottom + im_ax_height = 1 - pad_top - im_ax_bottom + + ax = fig.add_axes( + [pad_left, im_ax_bottom, 1-pad_left-pad_right, im_ax_height] + ) + + s = total_width * total_height / trans.shape[0] * 72**2 + s /= 4 # A rough value to make the size work out + + scatter_plot = ax.scatter( + factor * trans[:,0],factor * trans[:,1],s=s,c=values, cmap=cmap) if invert_xaxis: - plt.gca().invert_xaxis() + ax.invert_xaxis() - 
plt.gca().set_facecolor('k') - plt.xlabel('Translation x (' + units + ')') - plt.ylabel('Translation y (' + units + ')') - plt.colorbar() + ax.set_facecolor('k') + ax.set_xlabel('Translation x (' + units + ')') + ax.set_ylabel('Translation y (' + units + ')') + cbar = fig.colorbar( + scatter_plot, + ax=ax, + fraction=0.15, + pad=0.05, + location='right', + ) + ax.set_anchor('C') + if cmap_label is not None: + cbar.set_label(cmap_label) return fig -def plot_nanomap_with_images(translations, get_image_func, values=None, mask=None, basis=None, fig=None, nanomap_units='$\\mu$m', image_units='$\\mu$m', convention='probe', image_title='Image', image_colorbar_title='Image Amplitude', nanomap_colorbar_title='Integrated Intensity', cmap='viridis', **kwargs): +def plot_nanomap_with_images( + translations, + get_image_func, + values=None, + mask=None, + basis=None, + fig=None, + nanomap_units='$\\mu$m', + image_units='$\\mu$m', + convention='probe', + image_title='Image', + image_colorbar_title='Image Amplitude', + nanomap_colorbar_title='Integrated Intensity', + cmap='viridis', + **kwargs, +): """Plots a nanomap, with an image or stack of images for each point In many situations, ptychography data or the output of ptychography @@ -672,30 +1021,35 @@ def plot_nanomap_with_images(translations, get_image_func, values=None, mask=Non if fig is None: fig = plt.figure(figsize=(8,5.3)) else: - plt.figure(fig.number) - plt.gcf().clear() + if plt.fignum_exists(fig.number): + fig = plt.figure(fig.number) + else: + fig = plt.figure(fig.number, + figsize=(8,5.3)) + fig.clear() if hasattr(fig, 'nanomap_cids'): for cid in fig.nanomap_cids: fig.canvas.mpl_disconnect(cid) # Does figsize work with the fig.subplots, or just for plt.subplots? 
- axes = fig.subplots(1,2) + gs = fig.add_gridspec(2, 2, height_ratios=[0.92,0.08], width_ratios=[1,1], + bottom=0.04) - fig.tight_layout(rect=[0.04, 0.09, 0.98, 0.96]) - plt.subplots_adjust(wspace=0.25) #avoids overlap of labels with plots - axslider = plt.axes([0.15,0.06,0.75,0.03]) + axes = [fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1])] + axslider = fig.add_subplot(gs[1, :]) # full width # This gets the set of sizes for the points in the nanomap def calculate_sizes(idx): - bbox = axes[0].get_window_extent().transformed(fig.dpi_scale_trans.inverted()) - s0 = bbox.width * bbox.height / translations.shape[0] * 72**2 #72 is points per inch + bbox = axes[0].get_window_extent().transformed( + fig.dpi_scale_trans.inverted()) + # 72 is points per inch + s0 = bbox.width * bbox.height / translations.shape[0] * 72**2 s0 /= 4 # A rough value to make the size work out s = np.ones(translations.shape[0]) * s0 s[idx] *= 4 return s - def update_colorbar(im): # # This solves the problem of the colorbar being changed @@ -720,7 +1074,11 @@ def update_colorbar(im): # First we set up the left-hand plot, which shows an overview map axes[0].set_title('Relative Displacement Map') - translations = translations.detach().cpu().numpy() + if isinstance(translations, t.Tensor): + translations = translations.detach().cpu().numpy() + + if isinstance(values, t.Tensor): + values = values.detach().cpu().numpy() if convention.lower() != 'probe': translations = translations * -1 @@ -728,30 +1086,37 @@ def update_colorbar(im): s = calculate_sizes(0) nanomap_units_factor = get_units_factor(nanomap_units) + + # Suppresses a warning from ax.scatter() + if values is None: + cmap = None nanomap = axes[0].scatter(nanomap_units_factor * translations[:,0], nanomap_units_factor * translations[:,1], - s=s,c=values, picker=True) + s=s, c=values, picker=True, cmap=cmap) axes[0].invert_xaxis() axes[0].set_facecolor('k') - axes[0].set_xlabel('Translation x ('+nanomap_units+')', labelpad=1) - 
axes[0].set_ylabel('Translation y ('+nanomap_units+')', labelpad=1) + axes[0].set_xlabel('Translation x ('+nanomap_units+')') + axes[0].set_ylabel('Translation y ('+nanomap_units+')') + axes[0].set_aspect('equal') cb1 = plt.colorbar(nanomap, ax=axes[0], orientation='horizontal', format='%.2e', ticks=ticker.LinearLocator(numticks=5), - pad=0.17,fraction=0.1) + pad=0.19,fraction=0.1) cb1.ax.set_title(nanomap_colorbar_title, size="medium", pad=5) cb1.ax.tick_params(labelrotation=20) if values is None: # This seems to do a good job of leaving the appropriate space # where the colorbar should have been to avoid stretching the # nanomap plot, while still not showing the (now useless) colorbar. + pos = axes[0].get_position() cb1.remove() + axes[0].set_position(pos) # Now we set up the second plot, which shows the individual # diffraction patterns axes[1].set_title(image_title) - #Plot in a basis if it exists, otherwise dont + # Plot in a basis if it exists, otherwise dont if basis is not None: axes[1].set_xlabel('X (' + image_units + ')') axes[1].set_ylabel('Y (' + image_units + ')') @@ -794,7 +1159,7 @@ def update_colorbar(im): cb2 = plt.colorbar(meas, ax=axes[1], orientation='horizontal', format='%.2e', ticks=ticker.LinearLocator(numticks=5), - pad=0.17,fraction=0.1) + pad=0.19,fraction=0.1) cb2.ax.tick_params(labelrotation=20) cb2.ax.set_title(image_colorbar_title, size="medium", pad=5) cb2.ax.callbacks.connect('xlim_changed', lambda ax: update_colorbar(meas)) @@ -812,9 +1177,9 @@ def update(idx, im_idx=None): # Get the new data for this index im = get_image_func(idx) if len(im.shape) >= 3: - if im_idx == None and hasattr(axes[1],'image_idx'): + if im_idx is None and hasattr(axes[1],'image_idx'): im_idx = axes[1].image_idx - elif im_idx == None: + elif im_idx is None: im_idx=0 axes[1].image_idx = im_idx axes[1].text_box.set_text(str(im_idx)) @@ -829,8 +1194,7 @@ def update(idx, im_idx=None): ax_im.set_data(im) ax_im.norecurse=False update_colorbar(ax_im) - #plt.draw() 
- + # plt.draw() # # Now we define the functions to handle various kinds of events @@ -839,7 +1203,10 @@ def update(idx, im_idx=None): # We start by creating the slider here, so it can be used # by the update hooks. - slider = Slider(axslider, 'Image #', 0, translations.shape[0]-1, valstep=1, valfmt="%d") + slider = Slider( + axslider,'Image #', 0, translations.shape[0]-1, + valstep=1, valfmt="%d" + ) # This handles scroll wheel and keypress events def on_action(event): @@ -851,13 +1218,13 @@ def on_action(event): im = im.reshape(-1,im.shape[-2],im.shape[-1]) im_idx = axes[1].image_idx - if (event.key == 'up' - or (hasattr(event, 'button') and event.button == 'up') - or event.key == 'left'): + if event.key == "left" or ( + hasattr(event, "button") and event.button == "left" + ): im_idx = (im_idx - 1) % im.shape[0] - if (event.key == 'down' - or (hasattr(event, 'button') and event.button == 'down') - or event.key == 'right'): + if event.key == "right" or ( + hasattr(event, "button") and event.button == "right" + ): im_idx = (im_idx + 1) % im.shape[0] axes[1].image_idx=im_idx @@ -871,9 +1238,13 @@ def on_action(event): if not hasattr(event, 'key'): event.key = None - if event.key == 'up' or event.button == 'up' or event.key == 'left': + if event.key == "left" or ( + hasattr(event, "button") and event.button == "left" + ): idx = slider.val - 1 - elif event.key == 'down' or event.button == 'down' or event.key == 'right': + elif event.key == "right" or ( + hasattr(event, "button") and event.button == "right" + ): idx = slider.val + 1 else: # This prevents errors from being thrown on irrelevant key @@ -891,7 +1262,6 @@ def on_pick(event): if event.mouseevent.button == 1: slider.set_val(event.ind[0]) - # Here we connect the various update functions cid1 = fig.canvas.mpl_connect('pick_event',on_pick) cid2 = fig.canvas.mpl_connect('key_press_event',on_action) @@ -1197,8 +1567,8 @@ def on_pick(event): [ 0.65121289, 0.47406244, 0.05044367], [ 0.65830839, 0.46993917, 
0.04941288]] -rgb = np.array(cm_data) -rgb_with_alpha = np.zeros((rgb.shape[0],4)) -rgb_with_alpha[:,:3] = rgb +cm_data = np.array(cm_data) +rgb_with_alpha = np.zeros((cm_data.shape[0],4)) +rgb_with_alpha[:,:3] = cm_data rgb_with_alpha[:,3] = 1. #set alpha channel to 1 -cmocean_phase = colors.ListedColormap(rgb_with_alpha, N=rgb.shape[0]) +cmocean_phase = colors.ListedColormap(rgb_with_alpha, N=cm_data.shape[0]) diff --git a/tests/models/test_fancy_ptycho.py b/tests/models/test_fancy_ptycho.py index db05a75e..081d85b7 100644 --- a/tests/models/test_fancy_ptycho.py +++ b/tests/models/test_fancy_ptycho.py @@ -1,5 +1,7 @@ import pytest +import time import torch as t +from matplotlib import pyplot as plt import cdtools @@ -67,6 +69,8 @@ def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): units='mm', obj_view_crop=-50, use_qe_mask=True, # test this in the case where no qe mask is defined + panel_plot_mode=True, # test with panel plot mode, + plot_level=4, # test with all plots ) print('Running reconstruction on provided reconstruction_device,', @@ -76,24 +80,26 @@ def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): for loss in model.Adam_optimize(50, dataset, lr=0.02, batch_size=10): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) for loss in model.Adam_optimize(50, dataset, lr=0.005, batch_size=50): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) for loss in model.Adam_optimize(25, dataset, lr=0.001, batch_size=50): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) model.tidy_probes() if show_plot: model.inspect(dataset) model.compare(dataset) + time.sleep(3) + plt.close('all') # If this fails, the reconstruction has gotten worse assert 
model.loss_history[-1] < 0.0013 @@ -110,6 +116,7 @@ def test_near_field_ptycho(near_field_ptycho_cxi, reconstruction_device, show_pl n_modes=1, near_field=True, propagation_distance=3.65e-3, # 3.65 downstream from focus + panel_plot_mode=False, # test without panel plot mode ) print('Running reconstruction on provided reconstruction_device,', @@ -119,19 +126,21 @@ def test_near_field_ptycho(near_field_ptycho_cxi, reconstruction_device, show_pl for loss in model.Adam_optimize(100, dataset, lr=0.04, batch_size=10): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) for loss in model.Adam_optimize(50, dataset, lr=0.005, batch_size=50): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) model.tidy_probes() if show_plot: model.inspect(dataset) model.compare(dataset) + time.sleep(3) + plt.close('all') # If this fails, the reconstruction has gotten worse assert model.loss_history[-1] < 0.005 diff --git a/tests/models/test_simple_ptycho.py b/tests/models/test_simple_ptycho.py index b6b18680..225e463a 100644 --- a/tests/models/test_simple_ptycho.py +++ b/tests/models/test_simple_ptycho.py @@ -1,5 +1,7 @@ import pytest +import time import torch as t +from matplotlib import pyplot as plt import cdtools @@ -18,12 +20,14 @@ def test_simple_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): for loss in model.Adam_optimize(100, dataset, batch_size=10): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) if show_plot: model.inspect(dataset) model.compare(dataset) + time.sleep(3) + plt.close('all') # If this fails, the reconstruction got worse assert model.loss_history[-1] < 0.013 diff --git a/tests/test_reconstructors.py b/tests/test_reconstructors.py index 132a8c27..3f9245b6 100644 --- 
a/tests/test_reconstructors.py +++ b/tests/test_reconstructors.py @@ -1,4 +1,5 @@ import pytest +import time import cdtools import torch as t import numpy as np @@ -36,7 +37,8 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): probe_support_radius=50, propagation_distance=2e-6, units='um', - probe_fourier_crop=pad + probe_fourier_crop=pad, + panel_plot_mode=False, # At least one check without panel plot mode ) model.translation_offsets.data += 0.7 * \ @@ -67,8 +69,8 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): lr=lr_tup[i], batch_size=batch_size_tup[i]): print(model_recon.report()) - if show_plot and model_recon.epoch % 10 == 0: - model_recon.inspect(dataset) + if show_plot: + model_recon.inspect(dataset, min_interval=10) # Check hyperparameter update assert recon.optimizer.param_groups[0]['lr'] == lr_tup[i] @@ -86,6 +88,8 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): if show_plot: model_recon.inspect(dataset) model_recon.compare(dataset) + time.sleep(3) + plt.close('all') # ******* Reconstructions with CDIModel.Adam_optimize ******* print('Running reconstruction using CDIModel.Adam_optimize on provided' + @@ -99,14 +103,16 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): lr=lr_tup[i], batch_size=batch_size_tup[i]): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) model.tidy_probes() if show_plot: model.inspect(dataset) model.compare(dataset) + time.sleep(3) + plt.close('all') # Ensure equivalency between the model reconstructions during the first # pass, where they should be identical @@ -170,8 +176,8 @@ def test_LBFGS_RPI(optical_data_ss_cxi, for loss in recon.optimize(iterations, lr=0.4, regularization_factor=reg_factor_tup[i]): - if show_plot and i == 0: - model_recon.inspect(dataset) + if show_plot: + model_recon.inspect(dataset, 
min_interval=10) print(model_recon.report()) # Check hyperparameter update (or lack thereof) @@ -180,6 +186,8 @@ def test_LBFGS_RPI(optical_data_ss_cxi, if show_plot: model_recon.inspect(dataset) model_recon.compare(dataset) + time.sleep(3) + plt.close('all') # Check model pointing assert id(model_recon) == id(recon.model) @@ -193,13 +201,15 @@ def test_LBFGS_RPI(optical_data_ss_cxi, dataset, lr=0.4, regularization_factor=reg_factor_tup[i]): # noqa - if show_plot and i == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) print(model.report()) if show_plot: model.inspect(dataset) model.compare(dataset) + time.sleep(3) + plt.close('all') # Check loss equivalency between the two reconstructions assert np.allclose(model.loss_history[:epoch_tup[0]], model_recon.loss_history[:epoch_tup[0]]) @@ -271,8 +281,8 @@ def test_SGD_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): lr=lr, batch_size=batch_size): print(model_recon.report()) - if show_plot and model_recon.epoch % 10 == 0: - model_recon.inspect(dataset) + if show_plot: + model_recon.inspect(dataset, min_interval=10) # Check hyperparameter update assert recon.optimizer.param_groups[0]['lr'] == lr @@ -290,6 +300,8 @@ def test_SGD_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): if show_plot: model_recon.inspect(dataset) model_recon.compare(dataset) + time.sleep(3) + plt.close('all') # ******* Reconstructions with cdtools.CDIModel.SGD_optimize ******* print('Running reconstruction using CDIModel.SGD_optimize on provided' + @@ -301,14 +313,16 @@ def test_SGD_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): lr=lr, batch_size=batch_size): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) model.tidy_probes() if show_plot: model.inspect(dataset) model.compare(dataset) + time.sleep(3) + plt.close('all') # Ensure equivalency between the model reconstructions 
assert np.allclose(model_recon.loss_history[-1], model.loss_history[-1]) diff --git a/tests/tools/test_plotting.py b/tests/tools/test_plotting.py index baca33dd..6b5b7519 100644 --- a/tests/tools/test_plotting.py +++ b/tests/tools/test_plotting.py @@ -11,28 +11,32 @@ def test_plot_amplitude(show_plot): # Test with tensor im = t.as_tensor(scipy.datasets.ascent(), dtype=t.complex128) plotting.plot_amplitude(im, basis=np.array([[0, -1], [-1, 0], [0, 0]]), title='Test Amplitude') - if show_plot: - plt.show() - - # Test with numpy array - im = scipy.datasets.ascent().astype(np.complex128) + + # Test with numpy array and an extra dimension + im = np.stack([scipy.datasets.ascent().astype(np.complex128)]*3, axis=0) plotting.plot_amplitude(im, title='Test Amplitude') + + # Test with pytorch tensor and two extra dimensions + im = t.as_tensor(np.stack([im]*5, axis=0)) + plotting.plot_amplitude(im, title='Test Amplitude', + additional_axis_labels=['Hi','There']) + if show_plot: plt.show() - + plt.close('all') def test_plot_phase(show_plot): # Test with tensor im = initializers.gaussian([512, 512], [200, 200], amplitude=100, curvature=[.1, .1]) plotting.plot_phase(im, title='Test Phase') - if show_plot: - plt.show() - + # Test with numpy array im = initializers.gaussian([512, 512], [200, 200], amplitude=100, curvature=[.1, .1]).numpy() plotting.plot_phase(im, title='Test Phase', basis=np.array([[0, -1], [-1, 0], [0, 0]])) + if show_plot: plt.show() + plt.close('all') def test_plot_colorized(show_plot): @@ -40,11 +44,74 @@ def test_plot_colorized(show_plot): gaussian = initializers.gaussian([512, 512], [200, 200], amplitude=100, curvature=[.1, .1]) im = gaussian * t.as_tensor(scipy.datasets.ascent(), dtype=t.complex64) plotting.plot_colorized(im, title='Test Colorize', basis=np.array([[0, -1], [-1, 0], [0, 0]])) + + # Test with numpy array and hsv + im = im.numpy() + plotting.plot_colorized(im, title='Test Colorize', use_cmocean=False) + if show_plot: plt.show() + 
plt.close('all') - # Test with numpy array - im = im.numpy() - plotting.plot_colorized(im, title='Test Colorize') + +def test_plot_translations(show_plot): + rng = np.random.default_rng(0) + trans_np = rng.uniform(-5e-6, 5e-6, (20, 2)) + trans_t = t.as_tensor(trans_np) + + # numpy, defaults + plotting.plot_translations(trans_np) + + # torch tensor and reuse figure + fig = plotting.plot_translations(trans_t) + plotting.plot_translations(trans_np, lines=False, color='red', label='scan', fig=fig, clear_fig=False) + + if show_plot: + plt.show() + plt.close('all') + +def test_plot_nanomap(show_plot): + rng = np.random.default_rng(0) + trans_np = rng.uniform(-5e-6, 5e-6, (20, 2)) + values_np = np.random.default_rng(1).uniform(0, 1, 20) + trans_t = t.as_tensor(trans_np) + values_t = t.as_tensor(values_np) + + # numpy, defaults + plotting.plot_nanomap(trans_np, values_np) + + # torch tensors + plotting.plot_nanomap(trans_t, values_t, units='nm', cmap_label='Intensity', convention='sample') + + if show_plot: + plt.show() + plt.close('all') + + +def test_plot_nanomap_with_images(show_plot): + rng = np.random.default_rng(0) + trans_np = rng.uniform(-5e-6, 5e-6, (20, 2)) + values_np = np.random.default_rng(1).uniform(0, 1, 20) + # plot_nanomap_with_images requires tensor translations + trans_t = t.as_tensor(trans_np) + values_t = t.as_tensor(values_np) + + def get_image_2d(i): + return np.random.default_rng(i).uniform(0, 1, (32, 32)) + + def get_image_3d(i): + return np.random.default_rng(i).uniform(0, 1, (4, 32, 32)) + + # basic call, no values + plotting.plot_nanomap_with_images(trans_np, get_image_2d) + + # with explicit values + plotting.plot_nanomap_with_images(trans_t, get_image_2d, values=values_np) + + # 3D image stack + fig = plt.figure(figsize=(11,7)) + plotting.plot_nanomap_with_images(trans_np, get_image_3d, values=values_t, fig=fig) + if show_plot: plt.show() + plt.close('all')