From f3581b88dbf9f1284bb33f88882750f0bfa3f332 Mon Sep 17 00:00:00 2001 From: allevitan Date: Fri, 20 Mar 2026 09:47:38 +0100 Subject: [PATCH 01/23] A first prototype from claude to test on Ra --- src/cdtools/models/base.py | 283 +++++++++++++++++++------ src/cdtools/models/simple_ptycho.py | 76 +++++-- src/cdtools/reconstructors/base.py | 15 +- src/cdtools/tools/plotting/plotting.py | 51 +++-- 4 files changed, 328 insertions(+), 97 deletions(-) diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index df347b63..f2823b7c 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -58,12 +58,14 @@ class CDIModel(t.nn.Module): functions. """ - def __init__(self): + def __init__(self, panel_plot_mode=False, plot_level=np.inf): super(CDIModel, self).__init__() self.loss_history = [] self.training_history = '' self.epoch = 0 + self.panel_plot_mode = panel_plot_mode + self.plot_level = plot_level def from_dataset(self, dataset): raise NotImplementedError() @@ -547,96 +549,238 @@ def report(self): return msg - # By default, the plot_list is empty + # By default, the plot lists are empty + plot_panel_list = [] plot_list = [] - def inspect(self, dataset=None, update=True): - """Plots all the plots defined in the model's plot_list attribute + def inspect(self, dataset=None, replot_all=False): + """Plots all the plots defined in the model's plot_panel_list and plot_list attributes - If update is set to True, it will update any previously plotted set - of plots, if one exists, and then redraw them. Otherwise, it will - plot a new set, and any subsequent updates will update the new set + Updates any previously plotted figures that are still open. Figures + that have been closed are left closed unless replot_all=True. Optionally, a dataset can be passed, which will allow plotting of any registered plots which need to incorporate some information from the dataset (such as geometry or a comparison with measured data). - Plots can be registered in any subclass by defining the plot_list - attribute. This should be a list of tuples in the following format: - ( 'Plot Title', function_to_generate_plot(self), - function_to_determine_whether_to_plot(self)) + Plots can be registered in any subclass by defining plot_panel_list + and/or plot_list class attributes. See the CDIModel documentation for + the expected dict-based format of each. - Where the third element in the tuple (a function that returns - True if the plot is relevant) is not required. + When panel_plot_mode=True (set in __init__), plot_panel_list entries + are rendered as multi-subplot figures. When False (the default), + each subplot in plot_panel_list is rendered as its own figure, + prepended to any standalone plot_list entries. + + The plot_level attribute (set in __init__, default np.inf) controls + which plots are shown: a panel or standalone plot is only shown when + its plot_level <= self.plot_level. Parameters ---------- dataset : CDataset Optional, a dataset matched to the model type - update : bool, default: True - Whether to update existing plots or plot new ones + replot_all : bool, default: False + If True, recreate figures that were previously closed by the user. """ - # We find or create all the figures - first_update = False - if update and hasattr(self, 'figs') and self.figs: - figs = self.figs - elif update: - figs = None - self.figs = [] - first_update = True + plot_panel_list = getattr(self, 'plot_panel_list', None) or [] + plot_list = getattr(self, 'plot_list', None) or [] + + if self.panel_plot_mode and plot_panel_list: + self._inspect_panel(dataset=dataset, replot_all=replot_all) else: - figs = None - self.figs = [] + # Flatten plot_panel_list, assigning each subplot the panel's plot_level, + # then prepend to plot_list + flat = [] + for panel in plot_panel_list: + panel_level = panel.get('plot_level', 0) + for plot in panel['plots']: + flat.append({**plot, 'plot_level': panel_level}) + all_plots = flat + list(plot_list) + + if not hasattr(self, '_flat_fig_map'): + self._flat_fig_map = {} + + self.figs = self._do_inspect(all_plots, self._flat_fig_map, + dataset=dataset, + replot_all=replot_all) + + plt.pause(0.05) + + + def _do_inspect(self, plot_list, fig_map, dataset=None, replot_all=False): + """Core one-figure-per-plot rendering logic. + + fig_map is a dict {title: figure} owned by the caller and updated + in-place. It tracks which figures are open across calls. + + Behaviour: + replot_all=False — closed figures are skipped (left closed). + replot_all=True — closed figures are recreated. + + Returns the list of figures that were rendered this call. + """ + if not plot_list: + return [] + + rendered = [] + + for plot in plot_list: + # Level filter + if plot.get('plot_level', 0) > self.plot_level: + continue + + # Condition check + condition = plot.get('condition', None) + if condition is not None: + try: + if not condition(self): + continue + except TypeError: + if not condition(self, dataset): + continue + + title = plot['title'] + fig = fig_map.get(title) + + if fig is not None and not plt.fignum_exists(fig.number): + # Figure was closed by the user + if replot_all: + fig = None + del fig_map[title] + else: + continue # leave it closed + + if fig is None: + fig = plt.figure(num=title) + fig._panel_label = title + fig_map[title] = fig - idx = 0 - for plots in self.plot_list: - # If a conditional is included in the plot, we check whether - # it is True try: - if len(plots) >=3 and not plots[2](self): - continue - except TypeError as e: - if len(plots) >= 3 and not plots[2](self, dataset): - continue - - name = plots[0] - plotter = plots[1] - - if figs is None: - fig = plt.figure() - self.figs.append(fig) - else: - fig = figs[idx] - - - try: # We try just plotting using the simplest allowed signature - plotter(self,fig) - plt.title(name) - except TypeError as e: - # TypeError implies it wanted another argument, i.e. a dataset + plot['plot_func'](self, fig) + plt.title(title) + except TypeError: if dataset is not None: try: - plotter(self, fig, dataset) - plt.title(name) - except Exception as e: # Don't raise errors: it's just plots + plot['plot_func'](self, fig, dataset) + plt.title(title) + except Exception: pass + except Exception: + pass - except Exception as e: # Don't raise errors, it's just a plot + rendered.append(fig) + try: + fig.canvas.draw_idle() + except Exception: pass - idx += 1 + return rendered - if update: - # This seems to update the figure without blocking. - plt.draw() - fig.canvas.start_event_loop(0.001) - if first_update: - # But this is needed the first time the figures update, or - # they won't get drawn at all - plt.pause(0.05 * len(self.figs)) + def _inspect_panel(self, dataset=None, replot_all=False): + """Multi-subplot panel rendering. + + Creates one figure per plot_panel_list entry, placing each subplot's + plot_func output into the appropriate axes. Closed panels stay closed + on subsequent calls unless replot_all=True. Standalone plot_list + entries are then rendered via _do_inspect and appended to self.figs. + """ + plot_panel_list = getattr(self, 'plot_panel_list', None) or [] + plot_list = getattr(self, 'plot_list', None) or [] + n_panels = len(plot_panel_list) + + # _panel_figs: list of figures (or None if never created / closed). + # _panel_axes: dict keyed by (panel_idx, row, col) → Axes. + # _standalone_fig_map: dict {title: figure} for standalone plot_list. + first_call = not hasattr(self, '_panel_figs') + if first_call: + self._panel_figs = [None] * n_panels + self._panel_axes = {} + self._standalone_fig_map = {} + + if not hasattr(self, '_standalone_fig_map'): + self._standalone_fig_map = {} + + for panel_idx, panel_def in enumerate(plot_panel_list): + panel_level = panel_def.get('plot_level', 0) + if panel_level > self.plot_level: + continue # skip entire panel + + nrows, ncols = panel_def['grid'] + figsize = panel_def.get('figure_size', None) + title = panel_def.get('title', '') + + fig = self._panel_figs[panel_idx] + + # Detect if a previously open figure was closed by the user. + if fig is not None and not plt.fignum_exists(fig.number): + self._panel_figs[panel_idx] = None + for k in [k for k in self._panel_axes if k[0] == panel_idx]: + del self._panel_axes[k] + fig = None + + if fig is None: + if not first_call and not replot_all: + continue # was closed; leave it closed + fig = plt.figure(num=title, figsize=figsize) + fig._panel_label = title + self._panel_figs[panel_idx] = fig + else: + # Remove all axes and recreate them fresh each update. + # plt.colorbar() shrinks the parent axes to make room for + # itself, so clearing and recreating is simpler than trying + # to undo that resizing. + for ax in list(fig.axes): + ax.remove() + for k in [k for k in self._panel_axes if k[0] == panel_idx]: + del self._panel_axes[k] + + for plot in panel_def['plots']: + condition = plot.get('condition', None) + if condition is not None: + try: + if not condition(self): + continue + except TypeError: + if not condition(self, dataset): + continue + + row, col = plot['subplot'] + position = row * ncols + col + 1 # 1-indexed for matplotlib + + ax_key = (panel_idx, row, col) + ax = fig.add_subplot(nrows, ncols, position) + self._panel_axes[ax_key] = ax + + try: + plot['plot_func'](self, ax) + ax.set_title(plot['title']) + except TypeError: + if dataset is not None: + try: + plot['plot_func'](self, ax, dataset) + ax.set_title(plot['title']) + except Exception: + pass + except Exception: + pass + + try: + fig.canvas.draw_idle() + except Exception: + pass + + # Rebuild self.figs from open panel figures + rendered standalone figures. + panel_figs = [f for f in self._panel_figs if f is not None] + standalone_rendered = self._do_inspect( + list(plot_list), self._standalone_fig_map, + dataset=dataset, replot_all=replot_all, + ) + self.figs = panel_figs + standalone_rendered + def save_figures(self, prefix='', extension='.pdf'): @@ -661,14 +805,15 @@ def save_figures(self, prefix='', extension='.pdf'): Default is .eps, the file extension to save with. """ - if hasattr(self, 'figs') and self.figs: - figs = self.figs - else: - return # No figures to save + if not (hasattr(self, 'figs') and self.figs): + return # No figures to save for fig in self.figs: - fig.savefig(prefix + fig.axes[0].get_title() + extension, - bbox_inches = 'tight') + if hasattr(fig, '_panel_label') and fig._panel_label: + label = fig._panel_label + else: + label = fig.axes[0].get_title() if fig.axes else 'figure' + fig.savefig(prefix + label + extension, bbox_inches='tight') def compare(self, dataset, logarithmic=False): diff --git a/src/cdtools/models/simple_ptycho.py b/src/cdtools/models/simple_ptycho.py index 5435c662..bcc02d28 100644 --- a/src/cdtools/models/simple_ptycho.py +++ b/src/cdtools/models/simple_ptycho.py @@ -15,10 +15,15 @@ def __init__( probe_guess, obj_guess, min_translation = [0,0], + panel_plot_mode=False, + plot_level=0, ): # We initialize the superclass - super().__init__() + super().__init__( + panel_plot_mode=panel_plot_mode, + plot_level=plot_level, + ) # We register all the constants, like wavelength, as buffers. This # lets the model hook into some nice pytorch features, like using @@ -43,7 +48,8 @@ def __init__( @classmethod - def from_dataset(cls, dataset): + def from_dataset(cls, dataset,panel_plot_mode=False, + plot_level=0, ): # We get the key geometry information from the dataset wavelength = dataset.wavelength @@ -76,7 +82,9 @@ def from_dataset(cls, dataset): probe_basis, probe, obj, - min_translation=min_translation + min_translation=min_translation, + panel_plot_mode=panel_plot_mode, + plot_level=plot_level, ) @@ -107,15 +115,59 @@ def loss(self, real_data, sim_data): # This lists all the plots to display on a call to model.inspect() - plot_list = [ - ('Probe Amplitude', - lambda self, fig: p.plot_amplitude(self.probe, fig=fig, basis=self.probe_basis)), - ('Probe Phase', - lambda self, fig: p.plot_phase(self.probe, fig=fig, basis=self.probe_basis)), - ('Object Amplitude', - lambda self, fig: p.plot_amplitude(self.obj, fig=fig, basis=self.probe_basis)), - ('Object Phase', - lambda self, fig: p.plot_phase(self.obj, fig=fig, basis=self.probe_basis)) + plot_panel_list = [ + { + # Title for window + 'title' : 'Probe Results', + # (width, height) in inches + 'figure_size': (7, 7), + # (nrows, ncols) for subplot grid + 'grid': (2, 2), + # A setting to control how many plots are produced + 'plot_level': 1, + # The list of plots to include + 'plots' : [ + { + 'title': 'Probe Amplitude', + 'subplot' : (0, 0), + 'plot_func': lambda self, fig: + p.plot_amplitude(self.probe, fig, + basis=self.probe_basis) + },{ + 'title': 'Probe Phase', + 'subplot' : (0, 1), + 'plot_func': lambda self, fig: + p.plot_phase(self.probe, fig, + basis=self.probe_basis) + } + ] + }, + { + # Title for window + 'title' : 'Object Results', + # (width, height) in inches + 'figure_size': (7, 7), + # (nrows, ncols) for subplot grid + 'grid': (2, 2), + # A setting to control how many plots are produced + 'plot_level': 1, + # The list of plots to include + 'plots' : [ + { + 'title': 'Object Amplitude', + 'subplot' : (1, 0), + 'plot_func': lambda self, fig: + p.plot_amplitude(self.obj, fig, + basis=self.probe_basis) + }, { + 'title': 'Object Phase', + 'subplot' : (1, 1), + 'plot_func': lambda self, fig: + p.plot_phase(self.obj, fig, + basis=self.probe_basis) + }, + ] + }, ] def save_results(self, dataset): diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py index 16668149..4beab9f3 100644 --- a/src/cdtools/reconstructors/base.py +++ b/src/cdtools/reconstructors/base.py @@ -16,6 +16,7 @@ import threading import queue import time +from matplotlib import pyplot as plt from typing import List, Union if TYPE_CHECKING: @@ -357,10 +358,18 @@ def target(): try: calc.start() while calc.is_alive(): - if hasattr(self.model, 'figs'): - self.model.figs[0].canvas.start_event_loop(0.01) + figs = getattr(self.model, 'figs', []) + open_fig = next( + (f for f in figs if plt.fignum_exists(f.number)), + None, + ) + if open_fig is not None: + try: + open_fig.canvas.start_event_loop(0.01) + except Exception: + time.sleep(0.01) else: - calc.join() + time.sleep(0.01) except KeyboardInterrupt as e: stop_event.set() diff --git a/src/cdtools/tools/plotting/plotting.py b/src/cdtools/tools/plotting/plotting.py index 0f570ea1..e8207640 100644 --- a/src/cdtools/tools/plotting/plotting.py +++ b/src/cdtools/tools/plotting/plotting.py @@ -164,7 +164,12 @@ def plot_image( else: im = im.detach().cpu().numpy() - if fig is None: + # Support passing an Axes object instead of a Figure + ax_mode = isinstance(fig, plt.Axes) + if ax_mode: + ax = fig + fig = ax.get_figure() + elif fig is None: fig = plt.figure() ax = fig.add_subplot(111, **kwargs) @@ -173,8 +178,13 @@ def plot_image( # given def make_plot(idx): plt.figure(fig.number) - title = plt.gca().get_title() - fig.clear() + if ax_mode: + title = ax.get_title() + ax.cla() + plt.sca(ax) + else: + title = plt.gca().get_title() + fig.clear() # If im only has two dimensions, this reshape will add a leading @@ -184,9 +194,10 @@ def make_plot(idx): s = im.shape reshaped_im = im.reshape(-1,s[-2],s[-1]) num_images = reshaped_im.shape[0] - fig.plot_idx = idx % num_images + plot_holder = ax if ax_mode else fig + plot_holder.plot_idx = idx % num_images - to_plot = plot_func(reshaped_im[fig.plot_idx]) + to_plot = plot_func(reshaped_im[plot_holder.plot_idx]) mpl_im = plt.imshow( to_plot, @@ -273,11 +284,13 @@ def make_plot(idx): plt.title(title) if len(im.shape) >= 3: - plt.text(0.03, 0.03, str(fig.plot_idx), fontsize=14, transform=plt.gcf().transFigure) + text_transform = ax.transAxes if ax_mode else plt.gcf().transFigure + plt.text(0.03, 0.03, str(plot_holder.plot_idx), fontsize=14, transform=text_transform) return fig - if hasattr(fig, 'plot_idx'): - result = make_plot(fig.plot_idx) + plot_holder = ax if ax_mode else fig + if hasattr(plot_holder, 'plot_idx'): + result = make_plot(plot_holder.plot_idx) else: result = make_plot(0) @@ -285,15 +298,16 @@ def make_plot(idx): def on_action(event): + plot_holder = ax if ax_mode else fig if not hasattr(event, 'button'): event.button = None if not hasattr(event, 'key'): event.key = None if event.key == 'up' or event.button == 'up': - update(fig.plot_idx - 1) + update(plot_holder.plot_idx - 1) elif event.key == 'down' or event.button == 'down': - update(fig.plot_idx + 1) + update(plot_holder.plot_idx + 1) plt.draw() if len(im.shape) >=3: @@ -557,7 +571,11 @@ def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, inver factor = get_units_factor(units) - if fig is None: + if isinstance(fig, plt.Axes): + ax = fig + fig = ax.get_figure() + plt.sca(ax) + elif fig is None: fig = plt.figure() ax = fig.add_subplot(111, **kwargs) else: @@ -614,7 +632,13 @@ def plot_nanomap(translations, values, fig=None, units='$\\mu$m', convention='pr The figure object that was actually plotted to. """ - if fig is None: + ax_mode = isinstance(fig, plt.Axes) + if ax_mode: + ax = fig + fig = ax.get_figure() + ax.cla() + plt.sca(ax) + elif fig is None: fig = plt.figure() else: plt.figure(fig.number) @@ -622,7 +646,8 @@ def plot_nanomap(translations, values, fig=None, units='$\\mu$m', convention='pr factor = get_units_factor(units) - bbox = fig.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) + plot_area = ax if ax_mode else fig + bbox = plot_area.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) if isinstance(translations, t.Tensor): trans = translations.detach().cpu().numpy() else: From 44fe36d24fe6547f0be70de66d0e78d658f7d051 Mon Sep 17 00:00:00 2001 From: Abe Levitan Date: Fri, 20 Mar 2026 11:09:13 +0100 Subject: [PATCH 02/23] Kill a bunch of bugs and make the flow more sensible by hand --- examples/simple_ptycho.py | 12 ++- src/cdtools/models/base.py | 154 +++++++++++++--------------- src/cdtools/models/simple_ptycho.py | 6 +- src/cdtools/reconstructors/base.py | 9 +- 4 files changed, 87 insertions(+), 94 deletions(-) diff --git a/examples/simple_ptycho.py b/examples/simple_ptycho.py index 217d96b0..571f24f2 100644 --- a/examples/simple_ptycho.py +++ b/examples/simple_ptycho.py @@ -15,21 +15,25 @@ dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename) # We create a ptychography model from the dataset -model = cdtools.models.SimplePtycho.from_dataset(dataset) +model = cdtools.models.SimplePtycho.from_dataset(dataset, panel_plot_mode=True) # We move the model to the GPU device = 'cuda' model.to(device=device) dataset.get_as(device=device) +model.inspect(dataset) +print('hi') # We run the reconstruction -for loss in model.Adam_optimize(100, dataset, batch_size=10): +for loss in model.Adam_optimize(30, dataset, batch_size=10): # We print a quick report of the optimization status print(model.report()) # And liveplot the updates to the model as they happen - model.inspect(dataset) + if model.epoch % 10 == 0: + model.inspect(dataset) + # We study the results -model.inspect(dataset) +model.inspect(dataset, replot_all=True) model.compare(dataset) plt.show() diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index f2823b7c..2dba377e 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -30,6 +30,7 @@ import torch as t from torch.utils import data as torchdata +import matplotlib from matplotlib import pyplot as plt from matplotlib.widgets import Slider from matplotlib import ticker @@ -66,6 +67,7 @@ def __init__(self, panel_plot_mode=False, plot_level=np.inf): self.epoch = 0 self.panel_plot_mode = panel_plot_mode self.plot_level = plot_level + self.has_inspect_been_called = False def from_dataset(self, dataset): raise NotImplementedError() @@ -588,29 +590,56 @@ def inspect(self, dataset=None, replot_all=False): plot_panel_list = getattr(self, 'plot_panel_list', None) or [] plot_list = getattr(self, 'plot_list', None) or [] - if self.panel_plot_mode and plot_panel_list: - self._inspect_panel(dataset=dataset, replot_all=replot_all) + if self.panel_plot_mode: + # First we plot all the panels + panel_figs = self._inspect_panel( + plot_panel_list, dataset=dataset, replot_all=replot_all) + # And then we plot all the individual figures + individual_figs = self._inspect_individual_figures( + plot_list, dataset=dataset, replot_all=replot_all + ) + self.figs = panel_figs + individual_figs else: - # Flatten plot_panel_list, assigning each subplot the panel's plot_level, - # then prepend to plot_list + # If not in panel plot mode, we first flatten the figures + # from the panels flat = [] for panel in plot_panel_list: panel_level = panel.get('plot_level', 0) for plot in panel['plots']: + # We add the plot level from the larger panel flat.append({**plot, 'plot_level': panel_level}) + all_plots = flat + list(plot_list) - if not hasattr(self, '_flat_fig_map'): - self._flat_fig_map = {} - - self.figs = self._do_inspect(all_plots, self._flat_fig_map, - dataset=dataset, - replot_all=replot_all) + # We make sure to keep a reference to the open figs around + self.figs = self._inspect_individual_figures( + all_plots, dataset=dataset, replot_all=replot_all) - plt.pause(0.05) + if not self.has_inspect_been_called or replot_all: + # Somehow, this is needed for new figures to appear + if self._is_backend_interactive(): + plt.pause(0.05 * len(self.figs)) + for fig in self.figs: + fig.canvas.flush_events() + self.has_inspect_been_called = True + + def _is_backend_interactive( + self + ): + backend = matplotlib.get_backend().lower() + interactive_bk = matplotlib.backends.backend_registry.list_builtin( + matplotlib.backends.BackendFilter.INTERACTIVE + ) + return backend in [b.lower() for b in interactive_bk] + - def _do_inspect(self, plot_list, fig_map, dataset=None, replot_all=False): + def _inspect_individual_figures( + self, + plot_list, + dataset=None, + replot_all=False + ): """Core one-figure-per-plot rendering logic. fig_map is a dict {title: figure} owned by the caller and updated @@ -622,8 +651,6 @@ def _do_inspect(self, plot_list, fig_map, dataset=None, replot_all=False): Returns the list of figures that were rendered this call. """ - if not plot_list: - return [] rendered = [] @@ -642,45 +669,38 @@ def _do_inspect(self, plot_list, fig_map, dataset=None, replot_all=False): if not condition(self, dataset): continue - title = plot['title'] - fig = fig_map.get(title) - - if fig is not None and not plt.fignum_exists(fig.number): - # Figure was closed by the user - if replot_all: - fig = None - del fig_map[title] - else: - continue # leave it closed + if self.has_inspect_been_called and \ + replot_all == False and \ + not plt.fignum_exists(plot['title']): + continue - if fig is None: - fig = plt.figure(num=title) - fig._panel_label = title - fig_map[title] = fig + if not self.has_inspect_been_called: + fig = plt.figure(plot['title']) + else: + with plt.rc_context({'figure.raise_window': False}): + fig = plt.figure(plot['title']) try: plot['plot_func'](self, fig) - plt.title(title) + plt.title(plot['title']) except TypeError: if dataset is not None: try: plot['plot_func'](self, fig, dataset) - plt.title(title) + plt.title(plot['title']) except Exception: pass except Exception: pass rendered.append(fig) - try: - fig.canvas.draw_idle() - except Exception: - pass + if self._is_backend_interactive(): + plt.draw() return rendered - def _inspect_panel(self, dataset=None, replot_all=False): + def _inspect_panel(self, plot_panel_list, dataset=None, replot_all=False): """Multi-subplot panel rendering. Creates one figure per plot_panel_list entry, placing each subplot's @@ -688,22 +708,9 @@ def _inspect_panel(self, dataset=None, replot_all=False): on subsequent calls unless replot_all=True. Standalone plot_list entries are then rendered via _do_inspect and appended to self.figs. """ - plot_panel_list = getattr(self, 'plot_panel_list', None) or [] - plot_list = getattr(self, 'plot_list', None) or [] - n_panels = len(plot_panel_list) - - # _panel_figs: list of figures (or None if never created / closed). - # _panel_axes: dict keyed by (panel_idx, row, col) → Axes. - # _standalone_fig_map: dict {title: figure} for standalone plot_list. - first_call = not hasattr(self, '_panel_figs') - if first_call: - self._panel_figs = [None] * n_panels - self._panel_axes = {} - self._standalone_fig_map = {} - - if not hasattr(self, '_standalone_fig_map'): - self._standalone_fig_map = {} + rendered = [] + for panel_idx, panel_def in enumerate(plot_panel_list): panel_level = panel_def.get('plot_level', 0) if panel_level > self.plot_level: @@ -713,30 +720,24 @@ def _inspect_panel(self, dataset=None, replot_all=False): figsize = panel_def.get('figure_size', None) title = panel_def.get('title', '') - fig = self._panel_figs[panel_idx] - - # Detect if a previously open figure was closed by the user. - if fig is not None and not plt.fignum_exists(fig.number): - self._panel_figs[panel_idx] = None - for k in [k for k in self._panel_axes if k[0] == panel_idx]: - del self._panel_axes[k] - fig = None - - if fig is None: - if not first_call and not replot_all: - continue # was closed; leave it closed - fig = plt.figure(num=title, figsize=figsize) - fig._panel_label = title - self._panel_figs[panel_idx] = fig + + if self.has_inspect_been_called and \ + replot_all == False and \ + not plt.fignum_exists(panel_def['title']): + continue + + if not self.has_inspect_been_called: + fig = plt.figure(panel_def['title']) else: + with plt.rc_context({'figure.raise_window': False}): + fig = plt.figure(panel_def['title']) + # Remove all axes and recreate them fresh each update. # plt.colorbar() shrinks the parent axes to make room for # itself, so clearing and recreating is simpler than trying # to undo that resizing. for ax in list(fig.axes): ax.remove() - for k in [k for k in self._panel_axes if k[0] == panel_idx]: - del self._panel_axes[k] for plot in panel_def['plots']: condition = plot.get('condition', None) @@ -753,7 +754,6 @@ def _inspect_panel(self, dataset=None, replot_all=False): ax_key = (panel_idx, row, col) ax = fig.add_subplot(nrows, ncols, position) - self._panel_axes[ax_key] = ax try: plot['plot_func'](self, ax) @@ -767,20 +767,12 @@ def _inspect_panel(self, dataset=None, replot_all=False): pass except Exception: pass + rendered.append(fig) + + if self._is_backend_interactive(): + plt.draw() - try: - fig.canvas.draw_idle() - except Exception: - pass - - # Rebuild self.figs from open panel figures + rendered standalone figures. - panel_figs = [f for f in self._panel_figs if f is not None] - standalone_rendered = self._do_inspect( - list(plot_list), self._standalone_fig_map, - dataset=dataset, replot_all=replot_all, - ) - self.figs = panel_figs + standalone_rendered - + return rendered def save_figures(self, prefix='', extension='.pdf'): diff --git a/src/cdtools/models/simple_ptycho.py b/src/cdtools/models/simple_ptycho.py index bcc02d28..0460ecb5 100644 --- a/src/cdtools/models/simple_ptycho.py +++ b/src/cdtools/models/simple_ptycho.py @@ -16,7 +16,7 @@ def __init__( obj_guess, min_translation = [0,0], panel_plot_mode=False, - plot_level=0, + plot_level=1, ): # We initialize the superclass @@ -49,7 +49,7 @@ def __init__( @classmethod def from_dataset(cls, dataset,panel_plot_mode=False, - plot_level=0, ): + plot_level=1, ): # We get the key geometry information from the dataset wavelength = dataset.wavelength @@ -142,7 +142,7 @@ def loss(self, real_data, sim_data): } ] }, - { + { # Title for window 'title' : 'Object Results', # (width, height) in inches diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py index 4beab9f3..a9698ee2 100644 --- a/src/cdtools/reconstructors/base.py +++ b/src/cdtools/reconstructors/base.py @@ -364,12 +364,9 @@ def target(): None, ) if open_fig is not None: - try: - open_fig.canvas.start_event_loop(0.01) - except Exception: - time.sleep(0.01) - else: - time.sleep(0.01) + open_fig.canvas.flush_events() + # We need a low value for smooth figure responses + time.sleep(0.001) except KeyboardInterrupt as e: stop_event.set() From dd744a00a4e585700b75232d43460285c35f770b Mon Sep 17 00:00:00 2001 From: allevitan Date: Fri, 20 Mar 2026 15:33:49 +0100 Subject: [PATCH 03/23] Improve plotting infrastructure and fix various bugs - Switch panel plots to use subfigures with constrained_layout for better layout management - Add plot_loss_history method to CDIModel base class - Fix matplotlib compatibility for older versions (interactive backend detection) - Make CUDA usage conditional in examples - Add panel_plot_mode and plot_level params to FancyPtycho constructor - Default plot_level filtering to 1 instead of 0 Co-Authored-By: Claude Sonnet 4.6 --- examples/fancy_ptycho.py | 15 +- examples/simple_ptycho.py | 12 +- src/cdtools/models/base.py | 114 ++++++++---- src/cdtools/models/fancy_ptycho.py | 241 +++++++++++++++++-------- src/cdtools/models/simple_ptycho.py | 89 +++------ src/cdtools/tools/plotting/plotting.py | 182 ++++++++++--------- 6 files changed, 387 insertions(+), 266 deletions(-) diff --git a/examples/fancy_ptycho.py b/examples/fancy_ptycho.py index ebee68f0..b7e6fbfc 100644 --- a/examples/fancy_ptycho.py +++ b/examples/fancy_ptycho.py @@ -1,4 +1,5 @@ import cdtools +import torch as t from matplotlib import pyplot as plt filename = 'example_data/lab_ptycho_data.cxi' @@ -12,12 +13,15 @@ probe_support_radius=120, # Force the probe to 0 outside a radius of 120 pix propagation_distance=5e-3, # Propagate the initial probe guess by 5 mm units='mm', # Set the units for the live plots - obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix + obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix, + exponentiate_obj=False, + panel_plot_mode=True, + plot_level=2, ) -device = 'cuda' -model.to(device=device) -dataset.get_as(device=device) +if t.cuda.is_available(): + model.to(device='cuda') + dataset.get_as(device='cuda') # For this script, we use a slightly different pattern where we explicitly # create a `Reconstructor` class to orchestrate the reconstruction. The @@ -26,13 +30,14 @@ # e.g. estimates of the moments of individual parameters recon = cdtools.reconstructors.AdamReconstructor(model, dataset) + # The learning rate parameter sets the alpha for Adam. # The beta parameters are (0.9, 0.999) by default # The batch size sets the minibatch size for loss in recon.optimize(50, lr=0.02, batch_size=10): print(model.report()) # Plotting is expensive, so we only do it every tenth epoch - if model.epoch % 10 == 0: + if model.epoch % 2 == 0: model.inspect(dataset) # It's common to chain several different reconstruction loops. Here, we diff --git a/examples/simple_ptycho.py b/examples/simple_ptycho.py index 571f24f2..467b9afd 100644 --- a/examples/simple_ptycho.py +++ b/examples/simple_ptycho.py @@ -8,6 +8,7 @@ correct for common sources of error. """ import cdtools +import torch as t from matplotlib import pyplot as plt # We load an example dataset from a .cxi file @@ -15,12 +16,12 @@ dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename) # We create a ptychography model from the dataset -model = cdtools.models.SimplePtycho.from_dataset(dataset, panel_plot_mode=True) +model = cdtools.models.SimplePtycho.from_dataset(dataset) -# We move the model to the GPU -device = 'cuda' -model.to(device=device) -dataset.get_as(device=device) +# We move the model to the GPU, if possible +if t.cuda.is_available(): + model.to(device='cuda') + dataset.get_as(device='cuda') model.inspect(dataset) print('hi') @@ -32,7 +33,6 @@ if model.epoch % 10 == 0: model.inspect(dataset) - # We study the results model.inspect(dataset, replot_all=True) model.compare(dataset) diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index 2dba377e..8decc4d0 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -28,6 +28,7 @@ """ +from sympy import Q import torch as t from torch.utils import data as torchdata import matplotlib @@ -604,7 +605,7 @@ def inspect(self, dataset=None, replot_all=False): # from the panels flat = [] for panel in plot_panel_list: - panel_level = panel.get('plot_level', 0) + panel_level = panel.get('plot_level', 1) for plot in panel['plots']: # We add the plot level from the larger panel flat.append({**plot, 'plot_level': panel_level}) @@ -628,9 +629,14 @@ def _is_backend_interactive( self ): backend = matplotlib.get_backend().lower() - interactive_bk = matplotlib.backends.backend_registry.list_builtin( - matplotlib.backends.BackendFilter.INTERACTIVE - ) + try: + # matplotlib >= 3.9 + interactive_bk = matplotlib.backends.backend_registry.list_builtin( + matplotlib.backends.BackendFilter.INTERACTIVE + ) + except AttributeError: + # older matplotlib + interactive_bk = matplotlib.rcsetup.interactive_bk return backend in [b.lower() for b in interactive_bk] @@ -656,7 +662,7 @@ def _inspect_individual_figures( for plot in plot_list: # Level filter - if plot.get('plot_level', 0) > self.plot_level: + if plot.get('plot_level', 1) > self.plot_level: continue # Condition check @@ -670,15 +676,17 @@ def _inspect_individual_figures( continue if self.has_inspect_been_called and \ - replot_all == False and \ + not replot_all and \ not plt.fignum_exists(plot['title']): continue if not self.has_inspect_been_called: - fig = plt.figure(plot['title']) + fig = plt.figure(plot['title'], + constrained_layout=True) else: with plt.rc_context({'figure.raise_window': False}): - fig = plt.figure(plot['title']) + fig = plt.figure(plot['title'], + constrained_layout=True) try: plot['plot_func'](self, fig) @@ -711,34 +719,40 @@ def _inspect_panel(self, plot_panel_list, dataset=None, replot_all=False): rendered = [] - for panel_idx, panel_def in enumerate(plot_panel_list): - panel_level = panel_def.get('plot_level', 0) + for panel_def in plot_panel_list[::-1]: # Flip so first ones show on top + panel_level = panel_def.get('plot_level', 1) if panel_level > self.plot_level: continue # skip entire panel nrows, ncols = panel_def['grid'] figsize = panel_def.get('figure_size', None) - title = panel_def.get('title', '') - + if self.has_inspect_been_called and \ - replot_all == False and \ + not replot_all and \ not plt.fignum_exists(panel_def['title']): continue if not self.has_inspect_been_called: - fig = plt.figure(panel_def['title']) + fig = plt.figure(panel_def['title'], figsize=figsize, + constrained_layout=True) else: with plt.rc_context({'figure.raise_window': False}): - fig = plt.figure(panel_def['title']) + fig = plt.figure(panel_def['title'], + constrained_layout=True) - # Remove all axes and recreate them fresh each update. - # plt.colorbar() shrinks the parent axes to make room for - # itself, so clearing and recreating is simpler than trying - # to undo that resizing. - for ax in list(fig.axes): - ax.remove() + fig.clear() + + fig.get_layout_engine().set( + rect=(0.02, 0.02, 0.96, 0.96), + ) + gs = fig.add_gridspec( + nrows, ncols, + width_ratios=[1]*ncols, + height_ratios=[1]*nrows, + ) + for plot in panel_def['plots']: condition = plot.get('condition', None) if condition is not None: @@ -748,25 +762,22 @@ def _inspect_panel(self, plot_panel_list, dataset=None, replot_all=False): except TypeError: if not condition(self, dataset): continue - - row, col = plot['subplot'] - position = row * ncols + col + 1 # 1-indexed for matplotlib - - ax_key = (panel_idx, row, col) - ax = fig.add_subplot(nrows, ncols, position) + subfig = fig.add_subfigure(gs[plot['subplot'][0], + plot['subplot'][1]]) try: - plot['plot_func'](self, ax) - ax.set_title(plot['title']) + plot['plot_func'](self, subfig) + plt.gca().set_title(plot['title']) except TypeError: if dataset is not None: try: - plot['plot_func'](self, ax, dataset) - ax.set_title(plot['title']) - except Exception: + plot['plot_func'](self, subfig, dataset) + plt.gca().set_title(plot['title']) + except TypeError:#Exception: pass - except Exception: - pass + #except Exception: + # pass + rendered.append(fig) if self._is_backend_interactive(): @@ -775,6 +786,41 @@ def _inspect_panel(self, plot_panel_list, dataset=None, replot_all=False): return rendered + def plot_loss_history(self, fig=None, clear_fig=True): + """Plots the loss history on a semilogy axis + + Parameters + ---------- + fig : matplotlib.figure.Figure + Default is a new figure, a matplotlib figure to use to plot + clear_fig : bool + Default is True. Whether to clear the figure before plotting. + + Returns + ------- + used_fig : matplotlib.figure.Figure + The figure object that was actually plotted to. + """ + + if fig is None: + fig = plt.figure() + + if clear_fig: + fig.clear() + + if len(fig.axes) >= 1: + ax = fig.axes[0] + else: + ax = fig.add_subplot(111) + + ax.semilogy(self.loss_history) + plt.title('Loss History') + + ax.set_xlabel('Epoch') + ax.set_ylabel('Loss Metric') + + return fig + def save_figures(self, prefix='', extension='.pdf'): """Saves all currently open inspection figures. diff --git a/src/cdtools/models/fancy_ptycho.py b/src/cdtools/models/fancy_ptycho.py index 0b3d4997..fb9d4daa 100644 --- a/src/cdtools/models/fancy_ptycho.py +++ b/src/cdtools/models/fancy_ptycho.py @@ -45,9 +45,12 @@ def __init__(self, near_field=False, angular_spectrum_propagator=None, inv_angular_spectrum_propagator=None, + panel_plot_mode=False, + plot_level=2, ): - super(FancyPtycho, self).__init__() + super(FancyPtycho, self).__init__(panel_plot_mode=panel_plot_mode, + plot_level=plot_level) self.register_buffer('wavelength', t.as_tensor(wavelength, dtype=dtype)) self.store_detector_geometry(detector_geometry, @@ -252,6 +255,8 @@ def from_dataset(cls, obj_view_crop=None, obj_padding=200, near_field=False, + panel_plot_mode=False, + plot_level=2, ): wavelength = dataset.wavelength @@ -517,6 +522,8 @@ def from_dataset(cls, near_field=near_field, angular_spectrum_propagator=angular_spectrum_propagator, inv_angular_spectrum_propagator=inv_angular_spectrum_propagator, + panel_plot_mode=panel_plot_mode, + plot_level=plot_level, ) @@ -887,7 +894,22 @@ def get_probes(idx): cmap=cmap, **kwargs), + + def plot_illumination_intensity(self, fig, dataset): + if not hasattr(self, 'weights') or self.weights.ndim != 1: + raise NotImplementedError('Not yet implemented for OPRP') + p.plot_nanomap( + self.corrected_translations(dataset), + self.weights**2, + fig=fig, + cmap='magma', + cmap_label='Intensity (a.u.)', + units=self.units, + convention='probe', + invert_xaxis=True + ) + def plot_translations_and_originals(self, fig, dataset): """Only used to make a plot for the plot list.""" p.plot_translations( @@ -910,62 +932,160 @@ def plot_translations_and_originals(self, fig, dataset): plt.legend() + plot_panel_list = [ + { + 'title': 'Main Results', + 'plot_level': 1, + 'grid': (2,2), + 'figure_size': (9,7), + 'plots': [ + { + 'title': 'Object Phase', + 'subplot': (0,0), + 'plot_func': lambda self, fig: p.plot_phase( + self.obj[self.obj_view_slice], + fig=fig, + basis=self.obj_basis, + units=self.units), + 'condition': lambda self: not self.exponentiate_obj, + }, + { + 'title': 'Object Amplitude', + 'subplot': (1,0), + 'plot_func': lambda self, fig: p.plot_amplitude( + self.obj[self.obj_view_slice], + fig=fig, + basis=self.obj_basis, + units=self.units), + 'condition': lambda self: not self.exponentiate_obj, + }, + { + 'title': 'Real Part of T', + 'subplot': (0,0), + 'plot_func': lambda self, fig: p.plot_real( + self.obj[self.obj_view_slice], + fig=fig, + basis=self.obj_basis, + units=self.units, + cmap='cividis'), + 'condition': lambda self: self.exponentiate_obj, + }, + { + 'title': 'Imaginary Part of T', + 'subplot': (1,0), + 'plot_func': lambda self, fig: p.plot_imag( + self.obj[self.obj_view_slice], + fig=fig, + basis=self.obj_basis, + units=self.units), + 'condition': lambda self: self.exponentiate_obj, + }, + { + 'title': 'Basis Probes, Colorized', + 'subplot': (0,1), + 'plot_func': lambda self, fig: p.plot_colorized( + (self.probe if not self.fourier_probe + else tools.propagators.inverse_far_field(self.probe)), + fig=fig, + basis=self.probe_basis, + units=self.units), + }, + { + 'title': 'Basis Probes, Amplitude', + 'subplot': (1,1), + 'plot_func': lambda self, fig: p.plot_amplitude( + (self.probe if not self.fourier_probe + else tools.propagators.inverse_far_field(self.probe)), + fig=fig, + basis=self.probe_basis, + units=self.units), + }, + ], + }, + { + 'title': 'Advanced Monitoring', + 'plot_level': 2, + 'figure_size': (12,7), + 'grid': (2,3), + 'plots': [ + { + 'title': 'Basis Probes, Fourier Colorized', + 'subplot': (0,0), + 'plot_func': lambda self, fig: p.plot_colorized( + (self.probe if self.fourier_probe + else tools.propagators.far_field(self.probe)), + fig=fig), + }, + { + 'title': 'Basis Probes, Fourier Amplitude', + 'subplot': (1,0), + 'plot_func': lambda self, fig: p.plot_amplitude( + (self.probe if self.fourier_probe + else tools.propagators.far_field(self.probe)), + fig=fig), + }, + { + 'title': 'Illumination Intensity', + 'subplot': (0,1), + 'plot_func': lambda self, fig, dataset: self.plot_illumination_intensity(fig, dataset), + 'condition': lambda self: hasattr(self, 'weights') and self.weights.ndim == 1 + }, + { + 'title': 'Detector Background', + 'subplot': (1,1), + 'plot_func': lambda self, fig: p.plot_amplitude(self.background**2, fig=fig, cmap='magma', cmap_label='Intensity (detector units)'), + }, + { + 'title': 'Corrected Translations', + 'subplot': (0,2), + 'plot_func': lambda self, fig, dataset: self.plot_translations_and_originals(fig, dataset), + }, + { + 'title': 'Loss History', + 'subplot': (1,2), + 'plot_func': lambda self, fig: self.plot_loss_history(fig), + }, + ], + }, + ] + plot_list = [ - ('', - lambda self, fig, dataset: self.plot_wavefront_variation( + {'title': 'Per-Exposure Probe Intensity', + 'plot_level': 3, + 'plot_func': lambda self, fig, dataset: self.plot_wavefront_variation( dataset, fig=fig, mode='root_sum_intensity', image_title='Root Summed Probe Intensities', image_colorbar_title='Square Root of Intensity'), - lambda self: len(self.weights.shape) >= 2), - ('', - lambda self, fig, dataset: self.plot_wavefront_variation( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': 'Per-Exposure Probe Amplitudes', + 'plot_level': 3, + 'plot_func': lambda self, fig, dataset: self.plot_wavefront_variation( dataset, fig=fig, mode='amplitude', image_title='Probe Amplitudes (scroll to view modes)', image_colorbar_title='Probe Amplitude'), - lambda self: len(self.weights.shape) >= 2), - ('', - lambda self, fig, dataset: self.plot_wavefront_variation( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': 'Per-Exposure Probe Phases', + 'plot_level': 3, + 'plot_func': lambda self, fig, dataset: self.plot_wavefront_variation( dataset, fig=fig, mode='phase', image_title='Probe Phases (scroll to view modes)', image_colorbar_title='Probe Phase'), - lambda self: len(self.weights.shape) >= 2), - ('Basis Probe Fourier Space Amplitudes', - lambda self, fig: p.plot_amplitude( - (self.probe if self.fourier_probe - else tools.propagators.far_field(self.probe)), - fig=fig)), - ('Basis Probe Fourier Space Colorized', - lambda self, fig: p.plot_colorized( - (self.probe if self.fourier_probe - else tools.propagators.far_field(self.probe)) - , fig=fig)), - ('Basis Probe Real Space Amplitudes', - lambda self, fig: p.plot_amplitude( - (self.probe if not self.fourier_probe - else tools.propagators.inverse_far_field(self.probe)), - fig=fig, - basis=self.probe_basis, - units=self.units)), - ('Basis Probe Real Space Colorized', - lambda self, fig: p.plot_colorized( - (self.probe if not self.fourier_probe - else tools.propagators.inverse_far_field(self.probe)), - fig=fig, - basis=self.probe_basis, - units=self.units)), - ('Average Weight Matrix Amplitudes', - lambda self, fig: p.plot_amplitude( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': 'Average Weight Matrix Amplitudes', + 'plot_level': 1, + 'plot_func': lambda self, fig: p.plot_amplitude( np.nanmean(np.abs(self.weights.data.cpu().numpy()), axis=0), fig=fig), - lambda self: len(self.weights.shape) >= 2), - ('% of Power in Top Mode', - lambda self, fig, dataset: p.plot_nanomap( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': '% of Power in Top Mode', + 'plot_level': 3, + 'plot_func': lambda self, fig, dataset: p.plot_nanomap( self.corrected_translations(dataset), 100 * t.stack([ analysis.calc_mode_power_fractions( @@ -975,44 +1095,11 @@ def plot_translations_and_originals(self, fig, dataset): ], dim=0), fig=fig, units=self.units), - lambda self: len(self.weights.shape) >= 2), - ('Object Amplitude', - lambda self, fig: p.plot_amplitude( - self.obj[self.obj_view_slice], - fig=fig, - basis=self.obj_basis, - units=self.units), - lambda self: not self.exponentiate_obj), - ('Object Phase', - lambda self, fig: p.plot_phase( - self.obj[self.obj_view_slice], - fig=fig, - basis=self.obj_basis, - units=self.units), - lambda self: not self.exponentiate_obj), - ('Real Part of T', - lambda self, fig: p.plot_real( - self.obj[self.obj_view_slice], - fig=fig, - basis=self.obj_basis, - units=self.units, - cmap='cividis'), - lambda self: self.exponentiate_obj), - ('Imaginary Part of T', - lambda self, fig: p.plot_imag( - self.obj[self.obj_view_slice], - fig=fig, - basis=self.obj_basis, - units=self.units), - lambda self: self.exponentiate_obj), - - ('Corrected Translations', - lambda self, fig, dataset: self.plot_translations_and_originals(fig, dataset)), - ('Background', - lambda self, fig: p.plot_amplitude(self.background**2, fig=fig)), - ('Quantum Efficiency Mask', - lambda self, fig: p.plot_amplitude(self.qe_mask, fig=fig), - lambda self: (hasattr(self, 'qe_mask') and self.qe_mask is not None)) + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': 'Quantum Efficiency Mask', + 'plot_level': 3, + 'plot_func': lambda self, fig: p.plot_amplitude(self.qe_mask, fig=fig), + 'condition': lambda self: (hasattr(self, 'qe_mask') and self.qe_mask is not None)}, ] diff --git a/src/cdtools/models/simple_ptycho.py b/src/cdtools/models/simple_ptycho.py index 0460ecb5..234ed5d7 100644 --- a/src/cdtools/models/simple_ptycho.py +++ b/src/cdtools/models/simple_ptycho.py @@ -15,15 +15,10 @@ def __init__( probe_guess, obj_guess, min_translation = [0,0], - panel_plot_mode=False, - plot_level=1, ): # We initialize the superclass - super().__init__( - panel_plot_mode=panel_plot_mode, - plot_level=plot_level, - ) + super().__init__() # We register all the constants, like wavelength, as buffers. This # lets the model hook into some nice pytorch features, like using @@ -48,8 +43,7 @@ def __init__( @classmethod - def from_dataset(cls, dataset,panel_plot_mode=False, - plot_level=1, ): + def from_dataset(cls, dataset): # We get the key geometry information from the dataset wavelength = dataset.wavelength @@ -83,8 +77,6 @@ def from_dataset(cls, dataset,panel_plot_mode=False, probe, obj, min_translation=min_translation, - panel_plot_mode=panel_plot_mode, - plot_level=plot_level, ) @@ -115,58 +107,31 @@ def loss(self, real_data, sim_data): # This lists all the plots to display on a call to model.inspect() - plot_panel_list = [ - { - # Title for window - 'title' : 'Probe Results', - # (width, height) in inches - 'figure_size': (7, 7), - # (nrows, ncols) for subplot grid - 'grid': (2, 2), - # A setting to control how many plots are produced - 'plot_level': 1, - # The list of plots to include - 'plots' : [ - { - 'title': 'Probe Amplitude', - 'subplot' : (0, 0), - 'plot_func': lambda self, fig: - p.plot_amplitude(self.probe, fig, - basis=self.probe_basis) - },{ - 'title': 'Probe Phase', - 'subplot' : (0, 1), - 'plot_func': lambda self, fig: - p.plot_phase(self.probe, fig, - basis=self.probe_basis) - } - ] - }, - { - # Title for window - 'title' : 'Object Results', - # (width, height) in inches - 'figure_size': (7, 7), - # (nrows, ncols) for subplot grid - 'grid': (2, 2), - # A setting to control how many plots are produced - 'plot_level': 1, - # The list of plots to include - 'plots' : [ - { - 'title': 'Object Amplitude', - 'subplot' : (1, 0), - 'plot_func': lambda self, fig: - p.plot_amplitude(self.obj, fig, - basis=self.probe_basis) - }, { - 'title': 'Object Phase', - 'subplot' : (1, 1), - 'plot_func': lambda self, fig: - p.plot_phase(self.obj, fig, - basis=self.probe_basis) - }, - ] + plot_list = [ + { + 'title': 'Probe Amplitude', + 'subplot' : (0, 0), + 'plot_func': lambda self, fig: + p.plot_amplitude(self.probe, fig, + basis=self.probe_basis), + }, { + 'title': 'Probe Phase', + 'subplot' : (0, 1), + 'plot_func': lambda self, fig: + p.plot_phase(self.probe, fig, + basis=self.probe_basis) + }, { + 'title': 'Object Amplitude', + 'subplot' : (1, 0), + 'plot_func': lambda self, fig: + p.plot_amplitude(self.obj, fig, + basis=self.probe_basis) + }, { + 'title': 'Object Phase', + 'subplot' : (1, 1), + 'plot_func': lambda self, fig: + p.plot_phase(self.obj, fig, + basis=self.probe_basis) }, ] diff --git a/src/cdtools/tools/plotting/plotting.py b/src/cdtools/tools/plotting/plotting.py index e8207640..367c9062 100644 --- a/src/cdtools/tools/plotting/plotting.py +++ b/src/cdtools/tools/plotting/plotting.py @@ -92,7 +92,6 @@ def get_units_factor(units): factor=1e12 return factor - def plot_image( im, plot_func=lambda x: x, @@ -164,28 +163,18 @@ def plot_image( else: im = im.detach().cpu().numpy() - # Support passing an Axes object instead of a Figure - ax_mode = isinstance(fig, plt.Axes) - if ax_mode: - ax = fig - fig = ax.get_figure() - elif fig is None: + if fig is None: fig = plt.figure() - ax = fig.add_subplot(111, **kwargs) - # This nukes everything and updates either the appropriate image from the # stack of images, or the only image if only a single image has been # given def make_plot(idx): - plt.figure(fig.number) - if ax_mode: - title = ax.get_title() - ax.cla() - plt.sca(ax) - else: - title = plt.gca().get_title() - fig.clear() - + #plt.figure(fig.number) + #title = plt.gca().get_title() + try: + title = fig.axes[0].get_title() + except IndexError: + title = '' # If im only has two dimensions, this reshape will add a leading # dimension, and update will be called on index 0. If it has 3 or more @@ -194,19 +183,41 @@ def make_plot(idx): s = im.shape reshaped_im = im.reshape(-1,s[-2],s[-1]) num_images = reshaped_im.shape[0] - plot_holder = ax if ax_mode else fig - plot_holder.plot_idx = idx % num_images - - to_plot = plot_func(reshaped_im[plot_holder.plot_idx]) + fig.plot_idx = idx % num_images + + to_plot = plot_func(reshaped_im[fig.plot_idx]) + + # By only updating the data, and not redrawing the fig, we + # don't "reset" the home positions of the other + if hasattr(fig, '_current_im'): + print('Just changing data') + fig._current_im.set_data(to_plot) + fig._current_im.autoscale() + # We need to go to the "home" position before updating it + # to include the new data, because otherwise it will store + # other axes (potentially zoomed in) positions as "home", + # which is super annoying, more so than the reset. + if fig.canvas.toolbar is not None: + fig.canvas.toolbar.home() + fig.canvas.toolbar.update() + # Replace existing mode number + for artist in fig.texts: + artist.set_text(f'Mode {fig.plot_idx}') + + return fig + + fig.clear() + ax = fig.add_subplot(111, **kwargs) - mpl_im = plt.imshow( + mpl_im = ax.imshow( to_plot, cmap = cmap, interpolation = interpolation, vmin=vmin, vmax=vmax, ) - plt.gca().set_facecolor('k') + fig._current_im = mpl_im + ax.set_facecolor('k') if basis is not None: # we've closed over basis, so we can't edit it @@ -264,50 +275,51 @@ def make_plot(idx): corners = np.matmul(transform_matrix,corners.transpose()) mins = np.min(corners, axis=1) maxes = np.max(corners, axis=1) - plt.gca().set_xlim([mins[0], maxes[0]]) - plt.gca().set_ylim([mins[1], maxes[1]]) - plt.gca().invert_yaxis() + ax.set_xlim([mins[0], maxes[0]]) + ax.set_ylim([mins[1], maxes[1]]) + ax.invert_yaxis() if show_cbar: - cbar = plt.colorbar() + cbar = fig.colorbar(mpl_im, ax=ax, fraction=0.05, pad=0.05) if cmap_label is not None: cbar.set_label(cmap_label) if basis is not None: - plt.xlabel('X (' + units + ')') - plt.ylabel('Y (' + units + ')') + ax.set_xlabel('X (' + units + ')') + ax.set_ylabel('Y (' + units + ')') else: - plt.xlabel('j (pixels)') - plt.ylabel('i (pixels)') - + ax.set_xlabel('j (pixels)') + ax.set_ylabel('i (pixels)') - plt.title(title) + ax.set_title(title) if len(im.shape) >= 3: - text_transform = ax.transAxes if ax_mode else plt.gcf().transFigure - plt.text(0.03, 0.03, str(plot_holder.plot_idx), fontsize=14, transform=text_transform) + fig.text(0.03, 0.03, f'Mode {fig.plot_idx}', fontsize=14) + + if fig.canvas.toolbar is not None: + fig.canvas.toolbar.update() return fig - plot_holder = ax if ax_mode else fig - if hasattr(plot_holder, 'plot_idx'): - result = make_plot(plot_holder.plot_idx) + if hasattr(fig, 'plot_idx'): + result_fig = make_plot(fig.plot_idx) else: - result = make_plot(0) - + result_fig = make_plot(0) + update = make_plot - def on_action(event): - plot_holder = ax if ax_mode else fig + # Protection for multi-subfigure situation + if event.inaxes not in fig.axes: + return if not hasattr(event, 'button'): event.button = None if not hasattr(event, 'key'): event.key = None if event.key == 'up' or event.button == 'up': - update(plot_holder.plot_idx - 1) + update(fig.plot_idx - 1) elif event.key == 'down' or event.button == 'down': - update(plot_holder.plot_idx + 1) + update(fig.plot_idx + 1) plt.draw() if len(im.shape) >=3: @@ -320,7 +332,7 @@ def on_action(event): fig.my_callbacks.append(fig.canvas.mpl_connect('key_press_event',on_action)) fig.my_callbacks.append(fig.canvas.mpl_connect('scroll_event',on_action)) - return result + return result_fig def plot_real(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_label='Real Part (a.u.)', **kwargs): @@ -571,17 +583,16 @@ def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, inver factor = get_units_factor(units) - if isinstance(fig, plt.Axes): - ax = fig - fig = ax.get_figure() - plt.sca(ax) - elif fig is None: + if fig is None: fig = plt.figure() - ax = fig.add_subplot(111, **kwargs) + + if clear_fig: + fig.clear() + + if len(fig.axes) >= 1: + ax = fig.axes[0] else: - plt.figure(fig.number) - if clear_fig: - plt.gcf().clear() + ax = fig.add_subplot(111, **kwargs) if isinstance(translations, t.Tensor): translations = translations.detach().cpu().numpy() @@ -590,25 +601,33 @@ def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, inver linestyle = '-' if lines else 'None' linewidth = 1 if lines else 0 - plt.plot(translations[:,0], translations[:,1], - marker=marker, linestyle=linestyle, - label=label, color=color, - linewidth=linewidth) + ax.plot(translations[:,0], translations[:,1], + marker=marker, linestyle=linestyle, + label=label, color=color, + linewidth=linewidth) if invert_xaxis: - ax = plt.gca() x_min, x_max = ax.get_xlim() # Protect against flipping twice if plotting on top of existing graph if x_min <= x_max: ax.invert_xaxis() - plt.xlabel('X (' + units + ')') - plt.ylabel('Y (' + units + ')') + ax.set_xlabel('X (' + units + ')') + ax.set_ylabel('Y (' + units + ')') return fig -def plot_nanomap(translations, values, fig=None, units='$\\mu$m', convention='probe', invert_xaxis=True): +def plot_nanomap( + translations, + values, + fig=None, + cmap='viridis', + cmap_label=None, + units='$\\mu$m', + convention='probe', + invert_xaxis=True +): """Plots a set of nanomap data in a flexible way Parameters @@ -619,6 +638,10 @@ def plot_nanomap(translations, values, fig=None, units='$\\mu$m', convention='pr A length-N object of values associated with the translations fig : matplotlib.figure.Figure Default is a new figure, a matplotlib figure to use to plot + cmap : str + Default is 'viridis', the colormap to plot with + cmap_label : str + Default is no label, what to label the colorbar when plotting. units : str Default is um, units to report in (assuming input in m) convention : str @@ -632,22 +655,14 @@ def plot_nanomap(translations, values, fig=None, units='$\\mu$m', convention='pr The figure object that was actually plotted to. """ - ax_mode = isinstance(fig, plt.Axes) - if ax_mode: - ax = fig - fig = ax.get_figure() - ax.cla() - plt.sca(ax) - elif fig is None: + if fig is None: fig = plt.figure() - else: - plt.figure(fig.number) - plt.gcf().clear() - + + fig.clear() + ax = fig.add_subplot(111) factor = get_units_factor(units) - plot_area = ax if ax_mode else fig - bbox = plot_area.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) + bbox = fig.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) if isinstance(translations, t.Tensor): trans = translations.detach().cpu().numpy() else: @@ -664,14 +679,17 @@ def plot_nanomap(translations, values, fig=None, units='$\\mu$m', convention='pr s = bbox.width * bbox.height / trans.shape[0] * 72**2 #72 is points per inch s /= 4 # A rough value to make the size work out - plt.scatter(factor * trans[:,0],factor * trans[:,1],s=s,c=values) + scatter_plot = ax.scatter( + factor * trans[:,0],factor * trans[:,1],s=s,c=values, cmap=cmap) if invert_xaxis: - plt.gca().invert_xaxis() + ax.invert_xaxis() - plt.gca().set_facecolor('k') - plt.xlabel('Translation x (' + units + ')') - plt.ylabel('Translation y (' + units + ')') - plt.colorbar() + ax.set_facecolor('k') + ax.set_xlabel('Translation x (' + units + ')') + ax.set_ylabel('Translation y (' + units + ')') + cbar = fig.colorbar(scatter_plot, ax=ax, fraction=0.05, pad=0.05) + if cmap_label is not None: + cbar.set_label(cmap_label) return fig From f4260837bc459d601a48e563e4f0731899b4bd2a Mon Sep 17 00:00:00 2001 From: allevitan Date: Fri, 20 Mar 2026 17:20:52 +0100 Subject: [PATCH 04/23] Add title option to plot_image and wrappers; misc fixes - Add title parameter to plot_image, plot_real, plot_imag, plot_amplitude, plot_phase, and plot_colorized - Various fixes to base.py and fancy_ptycho.py Co-Authored-By: Claude Sonnet 4.6 --- examples/fancy_ptycho.py | 5 +- src/cdtools/models/base.py | 38 +++++++-- src/cdtools/models/fancy_ptycho.py | 105 ++++++++++++++++++------- src/cdtools/tools/plotting/plotting.py | 73 ++++++++--------- 4 files changed, 145 insertions(+), 76 deletions(-) diff --git a/examples/fancy_ptycho.py b/examples/fancy_ptycho.py index b7e6fbfc..8e3a9d14 100644 --- a/examples/fancy_ptycho.py +++ b/examples/fancy_ptycho.py @@ -14,9 +14,7 @@ propagation_distance=5e-3, # Propagate the initial probe guess by 5 mm units='mm', # Set the units for the live plots obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix, - exponentiate_obj=False, - panel_plot_mode=True, - plot_level=2, + panel_plot_mode=True, # Organizes the live plots into panels ) if t.cuda.is_available(): @@ -30,7 +28,6 @@ # e.g. estimates of the moments of individual parameters recon = cdtools.reconstructors.AdamReconstructor(model, dataset) - # The learning rate parameter sets the alpha for Adam. # The beta parameters are (0.9, 0.999) by default # The batch size sets the minibatch size diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index 8decc4d0..99a87f8c 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -675,6 +675,7 @@ def _inspect_individual_figures( if not condition(self, dataset): continue + figsize = plot.get('figure_size', None) if self.has_inspect_been_called and \ not replot_all and \ not plt.fignum_exists(plot['title']): @@ -682,22 +683,30 @@ def _inspect_individual_figures( if not self.has_inspect_been_called: fig = plt.figure(plot['title'], + figsize=figsize, constrained_layout=True) else: with plt.rc_context({'figure.raise_window': False}): fig = plt.figure(plot['title'], + figsize = panel_def.get('figure_size', None) constrained_layout=True) try: plot['plot_func'](self, fig) - plt.title(plot['title']) + if plt.gca().get_title().strip() == '': + plt.title(plot['title']) except TypeError: if dataset is not None: try: plot['plot_func'](self, fig, dataset) - plt.title(plot['title']) + if plt.gca().get_title().strip() == '': + plt.title(plot['title']) + except KeyboardInterrupt: + raise except Exception: pass + except KeyboardInterrupt: + raise except Exception: pass @@ -723,6 +732,15 @@ def _inspect_panel(self, plot_panel_list, dataset=None, replot_all=False): panel_level = panel_def.get('plot_level', 1) if panel_level > self.plot_level: continue # skip entire panel + + panel_condition = panel_def.get('condition', None) + if panel_condition is not None: + try: + if not panel_condition(self): + continue + except TypeError: + if not panel_condition(self, dataset): + continue nrows, ncols = panel_def['grid'] figsize = panel_def.get('figure_size', None) @@ -767,16 +785,22 @@ def _inspect_panel(self, plot_panel_list, dataset=None, replot_all=False): try: plot['plot_func'](self, subfig) - plt.gca().set_title(plot['title']) + if plt.gca().get_title().strip() == '': + plt.title(plot['title']) except TypeError: if dataset is not None: try: plot['plot_func'](self, subfig, dataset) - plt.gca().set_title(plot['title']) - except TypeError:#Exception: + if plt.gca().get_title().strip() == '': + plt.title(plot['title']) + except KeyboardInterrupt: + raise + except Exception: pass - #except Exception: - # pass + except KeyboardInterrupt: + raise + except Exception: + pass rendered.append(fig) diff --git a/src/cdtools/models/fancy_ptycho.py b/src/cdtools/models/fancy_ptycho.py index fb9d4daa..9ec668bc 100644 --- a/src/cdtools/models/fancy_ptycho.py +++ b/src/cdtools/models/fancy_ptycho.py @@ -896,11 +896,35 @@ def get_probes(idx): def plot_illumination_intensity(self, fig, dataset): - if not hasattr(self, 'weights') or self.weights.ndim != 1: - raise NotImplementedError('Not yet implemented for OPRP') + if not hasattr(self, 'weights'): + raise NotImplementedError("I don't know how to handle having no weights") + elif self.weights.ndim == 1: + probe_intensities = self.weights.detach().cpu().numpy()**2 + else: + # The big case, with OPRP + probe_matrix = np.zeros([self.probe.shape[0]]*2, + dtype=np.complex64) + np_probes = self.probe.detach().cpu().numpy() + for i in range(probe_matrix.shape[0]): + for j in range(probe_matrix.shape[0]): + probe_matrix[i,j] = np.sum(np_probes[i]*np_probes[j].conj()) + + weights = self.weights.detach().cpu().numpy() + + # The outer one is a sum, because the tensordot is what broadcasts the + # probe matrix along the shot dimension - the second one doesn't have to. + weighted_probe_matrices = np.sum(np.tensordot(weights, probe_matrix, axes=1)[...,None] + * weights.conj().transpose((0,2,1))[...,None,:,:], axis=-2) + + basis_probe_intensities = np.trace(probe_matrix, axis1=-2, axis2=-1) + probe_intensities = np.trace(weighted_probe_matrices, axis1=-2, axis2=-1) + + # Imaginary part is already essentially zero up to rounding error + probe_intensities = np.real(probe_intensities / basis_probe_intensities) + p.plot_nanomap( self.corrected_translations(dataset), - self.weights**2, + probe_intensities, fig=fig, cmap='magma', cmap_label='Intensity (a.u.)', @@ -908,7 +932,7 @@ def plot_illumination_intensity(self, fig, dataset): convention='probe', invert_xaxis=True ) - + def plot_translations_and_originals(self, fig, dataset): """Only used to make a plot for the plot list.""" @@ -987,6 +1011,7 @@ def plot_translations_and_originals(self, fig, dataset): (self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe)), fig=fig, + title='Basis Probe', basis=self.probe_basis, units=self.units), }, @@ -997,6 +1022,7 @@ def plot_translations_and_originals(self, fig, dataset): (self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe)), fig=fig, + title='Basis Probe', basis=self.probe_basis, units=self.units), }, @@ -1014,7 +1040,9 @@ def plot_translations_and_originals(self, fig, dataset): 'plot_func': lambda self, fig: p.plot_colorized( (self.probe if self.fourier_probe else tools.propagators.far_field(self.probe)), - fig=fig), + fig=fig, + title='Basis Probe, Fourier', + ), }, { 'title': 'Basis Probes, Fourier Amplitude', @@ -1022,13 +1050,14 @@ def plot_translations_and_originals(self, fig, dataset): 'plot_func': lambda self, fig: p.plot_amplitude( (self.probe if self.fourier_probe else tools.propagators.far_field(self.probe)), - fig=fig), + fig=fig, + title='Basis Probe, Fourier', + ), }, { 'title': 'Illumination Intensity', 'subplot': (0,1), 'plot_func': lambda self, fig, dataset: self.plot_illumination_intensity(fig, dataset), - 'condition': lambda self: hasattr(self, 'weights') and self.weights.ndim == 1 }, { 'title': 'Detector Background', @@ -1047,11 +1076,48 @@ def plot_translations_and_originals(self, fig, dataset): }, ], }, + { + 'title': 'Unstable Probe Refinement Details', + 'plot_level': 2, + 'figure_size': (9,3.5), + 'grid': (1,2), + 'condition': lambda self: len(self.weights.shape) >= 2, + 'plots': [ + { + 'title': '% of Power in Top Mode', + 'subplot': (0,0), + 'plot_func': lambda self, fig, dataset: p.plot_nanomap( + self.corrected_translations(dataset), + 100 * t.stack([ + analysis.calc_mode_power_fractions( + self.probe.data, + weight_matrix=self.weights.data[i])[0] + for i in range(self.weights.shape[0]) + ], dim=0), + fig=fig, + units=self.units), + 'condition': lambda self: len(self.weights.shape) >= 2 + }, + { + 'title': 'Average Weight Matrix Amplitudes', + 'subplot': (0,1), + 'plot_func': lambda self, fig: p.plot_amplitude( + np.nanmean(np.abs(self.weights.data.cpu().numpy()), axis=0), + fig=fig), + 'condition': lambda self: len(self.weights.shape) >= 2 + }, + ] + } ] plot_list = [ + {'title': 'Quantum Efficiency Mask', + 'plot_level': 2, + 'plot_func': lambda self, fig: p.plot_amplitude(self.qe_mask, fig=fig), + 'condition': lambda self: (hasattr(self, 'qe_mask') and self.qe_mask is not None)}, {'title': 'Per-Exposure Probe Intensity', 'plot_level': 3, + 'figure_size': (8,5.3), 'plot_func': lambda self, fig, dataset: self.plot_wavefront_variation( dataset, fig=fig, @@ -1061,6 +1127,7 @@ def plot_translations_and_originals(self, fig, dataset): 'condition': lambda self: len(self.weights.shape) >= 2}, {'title': 'Per-Exposure Probe Amplitudes', 'plot_level': 3, + 'figure_size': (8,5.3), 'plot_func': lambda self, fig, dataset: self.plot_wavefront_variation( dataset, fig=fig, @@ -1070,6 +1137,7 @@ def plot_translations_and_originals(self, fig, dataset): 'condition': lambda self: len(self.weights.shape) >= 2}, {'title': 'Per-Exposure Probe Phases', 'plot_level': 3, + 'figure_size': (8,5.3), 'plot_func': lambda self, fig, dataset: self.plot_wavefront_variation( dataset, fig=fig, @@ -1077,29 +1145,6 @@ def plot_translations_and_originals(self, fig, dataset): image_title='Probe Phases (scroll to view modes)', image_colorbar_title='Probe Phase'), 'condition': lambda self: len(self.weights.shape) >= 2}, - {'title': 'Average Weight Matrix Amplitudes', - 'plot_level': 1, - 'plot_func': lambda self, fig: p.plot_amplitude( - np.nanmean(np.abs(self.weights.data.cpu().numpy()), axis=0), - fig=fig), - 'condition': lambda self: len(self.weights.shape) >= 2}, - {'title': '% of Power in Top Mode', - 'plot_level': 3, - 'plot_func': lambda self, fig, dataset: p.plot_nanomap( - self.corrected_translations(dataset), - 100 * t.stack([ - analysis.calc_mode_power_fractions( - self.probe.data, - weight_matrix=self.weights.data[i])[0] - for i in range(self.weights.shape[0]) - ], dim=0), - fig=fig, - units=self.units), - 'condition': lambda self: len(self.weights.shape) >= 2}, - {'title': 'Quantum Efficiency Mask', - 'plot_level': 3, - 'plot_func': lambda self, fig: p.plot_amplitude(self.qe_mask, fig=fig), - 'condition': lambda self: (hasattr(self, 'qe_mask') and self.qe_mask is not None)}, ] diff --git a/src/cdtools/tools/plotting/plotting.py b/src/cdtools/tools/plotting/plotting.py index 367c9062..7069130b 100644 --- a/src/cdtools/tools/plotting/plotting.py +++ b/src/cdtools/tools/plotting/plotting.py @@ -105,6 +105,7 @@ def plot_image( vmin=None, vmax=None, interpolation=None, + title=None, **kwargs ): """Plots an image with a colorbar and on an appropriate spatial grid @@ -169,12 +170,13 @@ def plot_image( # stack of images, or the only image if only a single image has been # given def make_plot(idx): - #plt.figure(fig.number) - #title = plt.gca().get_title() - try: - title = fig.axes[0].get_title() - except IndexError: - title = '' + if title is not None: + ax_title = title + else: + try: + ax_title = fig.axes[0].get_title() + except IndexError: + ax_title = '' # If im only has two dimensions, this reshape will add a leading # dimension, and update will be called on index 0. If it has 3 or more @@ -190,7 +192,6 @@ def make_plot(idx): # By only updating the data, and not redrawing the fig, we # don't "reset" the home positions of the other if hasattr(fig, '_current_im'): - print('Just changing data') fig._current_im.set_data(to_plot) fig._current_im.autoscale() # We need to go to the "home" position before updating it @@ -200,10 +201,11 @@ def make_plot(idx): if fig.canvas.toolbar is not None: fig.canvas.toolbar.home() fig.canvas.toolbar.update() - # Replace existing mode number - for artist in fig.texts: - artist.set_text(f'Mode {fig.plot_idx}') - + + if len(im.shape) >= 3: + base = title if title is not None else '('.join(ax_title.split('(')[:-1])[:-1] + fig.axes[0].set_title(base + f' ({fig.plot_idx+1} of {num_images})') + return fig fig.clear() @@ -291,10 +293,10 @@ def make_plot(idx): ax.set_xlabel('j (pixels)') ax.set_ylabel('i (pixels)') - ax.set_title(title) - + if title is not None: + ax.set_title(ax_title) if len(im.shape) >= 3: - fig.text(0.03, 0.03, f'Mode {fig.plot_idx}', fontsize=14) + ax.set_title(ax_title + f' ({fig.plot_idx+1} of {num_images})') if fig.canvas.toolbar is not None: fig.canvas.toolbar.update() @@ -335,7 +337,7 @@ def on_action(event): return result_fig -def plot_real(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_label='Real Part (a.u.)', **kwargs): +def plot_real(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_label='Real Part (a.u.)', title=None, **kwargs): """Plots the real part of a complex array with dimensions NxM If a figure is given explicitly, it will clear that existing figure and @@ -369,11 +371,11 @@ def plot_real(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_ plot_func = lambda x: np.real(x) return plot_image(im, plot_func=plot_func, fig=fig, basis=basis, units=units, cmap=cmap, cmap_label=cmap_label, - **kwargs) + title=title, **kwargs) -def plot_imag(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_label='Imaginary Part (a.u.)', **kwargs): +def plot_imag(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_label='Imaginary Part (a.u.)', title=None, **kwargs): """Plots the imaginary part of a complex array with dimensions NxM If a figure is given explicitly, it will clear that existing figure and @@ -407,10 +409,10 @@ def plot_imag(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_ plot_func = lambda x: np.imag(x) return plot_image(im, plot_func=plot_func, fig=fig, basis=basis, units=units, cmap=cmap, cmap_label=cmap_label, - **kwargs) + title=title, **kwargs) -def plot_amplitude(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_label='Amplitude (a.u.)', **kwargs): +def plot_amplitude(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_label='Amplitude (a.u.)', title=None, **kwargs): """Plots the amplitude of a complex array with dimensions NxM If a figure is given explicitly, it will clear that existing figure and @@ -444,7 +446,7 @@ def plot_amplitude(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', plot_func = lambda x: np.absolute(x) return plot_image(im, plot_func=plot_func, fig=fig, basis=basis, units=units, cmap=cmap, cmap_label=cmap_label, - **kwargs) + title=title, **kwargs) def plot_phase( @@ -456,6 +458,7 @@ def plot_phase( cmap_label='Phase (rad)', vmin=None, vmax=None, + title=None, **kwargs ): """ Plots the phase of a complex array with dimensions NxM @@ -506,14 +509,14 @@ def plot_phase( return plot_image(im, plot_func=plot_func, fig=fig, basis=basis, units=units, cmap=cmap, cmap_label=cmap_label, - vmin=vmin,vmax=vmax, + vmin=vmin, vmax=vmax, title=title, **kwargs) def plot_amplitude_surfacenorm(): pass -def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', **kwargs): +def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', title=None, **kwargs): """ Plots the colorized version of a complex array with dimensions NxM The darkness corresponds to the intensity of the image, and the color @@ -545,7 +548,7 @@ def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', **kwargs): """ plot_func = lambda x: colorize(x) return plot_image(im, plot_func=plot_func, fig=fig, basis=basis, - units=units, show_cbar=False, **kwargs) + units=units, show_cbar=False, title=title, **kwargs) def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, invert_xaxis=True, clear_fig=True, label=None, color=None, marker='.', **kwargs): @@ -713,20 +716,19 @@ def plot_nanomap_with_images(translations, get_image_func, values=None, mask=Non # mode, i.e. on a figure that already has this thing showing. if fig is None: - fig = plt.figure(figsize=(8,5.3)) + fig = plt.figure(figsize=(20,4.5), constrained_layout=True) else: - plt.figure(fig.number) - plt.gcf().clear() + fig = plt.figure(fig.number, figsize=(20,4.5), constrained_layout=True) + fig.clear() if hasattr(fig, 'nanomap_cids'): for cid in fig.nanomap_cids: fig.canvas.mpl_disconnect(cid) # Does figsize work with the fig.subplots, or just for plt.subplots? - axes = fig.subplots(1,2) + gs = fig.add_gridspec(2, 2, height_ratios=[0.9,0.1], width_ratios=[1,1]) - fig.tight_layout(rect=[0.04, 0.09, 0.98, 0.96]) - plt.subplots_adjust(wspace=0.25) #avoids overlap of labels with plots - axslider = plt.axes([0.15,0.06,0.75,0.03]) + axes = [fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1])] + axslider = fig.add_subplot(gs[1, :]) # full width # This gets the set of sizes for the points in the nanomap def calculate_sizes(idx): @@ -779,11 +781,12 @@ def update_colorbar(im): axes[0].set_facecolor('k') axes[0].set_xlabel('Translation x ('+nanomap_units+')', labelpad=1) axes[0].set_ylabel('Translation y ('+nanomap_units+')', labelpad=1) + axes[0].set_aspect('equal') cb1 = plt.colorbar(nanomap, ax=axes[0], orientation='horizontal', format='%.2e', - ticks=ticker.LinearLocator(numticks=5), - pad=0.17,fraction=0.1) - cb1.ax.set_title(nanomap_colorbar_title, size="medium", pad=5) + ticks=ticker.LinearLocator(numticks=5))#, + #pad=0.17,fraction=0.1) + cb1.ax.set_title(nanomap_colorbar_title, size="medium")#, pad=5) cb1.ax.tick_params(labelrotation=20) if values is None: # This seems to do a good job of leaving the appropriate space @@ -836,8 +839,8 @@ def update_colorbar(im): cb2 = plt.colorbar(meas, ax=axes[1], orientation='horizontal', format='%.2e', - ticks=ticker.LinearLocator(numticks=5), - pad=0.17,fraction=0.1) + ticks=ticker.LinearLocator(numticks=5))#, + #pad=-0.17)#,fraction=0.1) cb2.ax.tick_params(labelrotation=20) cb2.ax.set_title(image_colorbar_title, size="medium", pad=5) cb2.ax.callbacks.connect('xlim_changed', lambda ax: update_colorbar(meas)) From e7f943e254a5d8c5f1284df142dbcdcc474dac15 Mon Sep 17 00:00:00 2001 From: allevitan Date: Fri, 20 Mar 2026 17:35:25 +0100 Subject: [PATCH 05/23] Made a few more updates to some example scripts, to show the panel plot mode --- examples/gold_ball_ptycho.py | 11 ++++++----- examples/near_field_ptycho.py | 9 +++++---- src/cdtools/models/base.py | 2 +- src/cdtools/tools/plotting/plotting.py | 4 ++-- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/examples/gold_ball_ptycho.py b/examples/gold_ball_ptycho.py index 49719751..fd7ba32a 100644 --- a/examples/gold_ball_ptycho.py +++ b/examples/gold_ball_ptycho.py @@ -1,6 +1,6 @@ import cdtools -from matplotlib import pyplot as plt import torch as t +from matplotlib import pyplot as plt filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi' dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename) @@ -26,7 +26,8 @@ probe_support_radius=50, propagation_distance=2e-6, units='um', - probe_fourier_crop=pad + probe_fourier_crop=pad, + panel_plot_mode=True, ) @@ -39,9 +40,9 @@ # Not much probe intensity instability in this dataset, no need for this model.weights.requires_grad = False -device = 'cuda' -model.to(device=device) -dataset.get_as(device=device) +if t.cuda.is_available(): + model.to(device='cuda') + dataset.get_as(device='cuda') # Create the reconstructor recon = cdtools.reconstructors.AdamReconstructor(model, dataset) diff --git a/examples/near_field_ptycho.py b/examples/near_field_ptycho.py index c91076f8..7b86bd3e 100644 --- a/examples/near_field_ptycho.py +++ b/examples/near_field_ptycho.py @@ -1,11 +1,11 @@ import cdtools +import torch as t from matplotlib import pyplot as plt filename = 'example_data/PETRAIII_P25_Near_Field_Ptycho.cxi' dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename) dataset.inspect() -plt.show() # Setting near_field equal to True uses an angular spectrum propagator in # lieu of the default Fourier-transform propagator for far-field ptychography. @@ -27,11 +27,12 @@ propagation_distance=3.65e-3, # 3.65 downstream from focus units='um', # Set the units for the live plots obj_view_crop=-35, + panel_plot_mode=True, ) -device = 'cuda' -model.to(device=device) -dataset.get_as(device=device) +if t.cuda.is_available(): + model.to(device='cuda') + dataset.get_as(device='cuda') model.inspect(dataset) diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index 99a87f8c..f6473dc2 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -688,7 +688,7 @@ def _inspect_individual_figures( else: with plt.rc_context({'figure.raise_window': False}): fig = plt.figure(plot['title'], - figsize = panel_def.get('figure_size', None) + figsize = figsize, constrained_layout=True) try: diff --git a/src/cdtools/tools/plotting/plotting.py b/src/cdtools/tools/plotting/plotting.py index 7069130b..53808717 100644 --- a/src/cdtools/tools/plotting/plotting.py +++ b/src/cdtools/tools/plotting/plotting.py @@ -716,9 +716,9 @@ def plot_nanomap_with_images(translations, get_image_func, values=None, mask=Non # mode, i.e. on a figure that already has this thing showing. if fig is None: - fig = plt.figure(figsize=(20,4.5), constrained_layout=True) + fig = plt.figure(figsize=(8,5.3), constrained_layout=True) else: - fig = plt.figure(fig.number, figsize=(20,4.5), constrained_layout=True) + fig = plt.figure(fig.number, figsize=(8,5.3), constrained_layout=True) fig.clear() if hasattr(fig, 'nanomap_cids'): for cid in fig.nanomap_cids: From 91db5b3c6497cddb403900f780cd60fcb90d7d3d Mon Sep 17 00:00:00 2001 From: allevitan Date: Fri, 20 Mar 2026 19:01:22 +0100 Subject: [PATCH 06/23] Update remaining models to new plot registration system --- examples/transmission_RPI.py | 10 +- src/cdtools/models/bragg_2d_ptycho.py | 91 ++++++++------- src/cdtools/models/multislice_2d_ptycho.py | 99 ++++++++-------- src/cdtools/models/multislice_ptycho.py | 123 ++++++++++---------- src/cdtools/models/rpi.py | 127 ++++++++++++++------- 5 files changed, 257 insertions(+), 193 deletions(-) diff --git a/examples/transmission_RPI.py b/examples/transmission_RPI.py index feff19e8..7b8d65a8 100644 --- a/examples/transmission_RPI.py +++ b/examples/transmission_RPI.py @@ -1,5 +1,6 @@ import cdtools import pickle +import torch as t from matplotlib import pyplot as plt # First, we load an example dataset from a .cxi file @@ -18,17 +19,18 @@ # Note that we explicitly as for two incoherent probe modes model = cdtools.models.RPI.from_dataset(dataset, probe, [500,500], background=background, n_modes=2, - initialization='random') + initialization='random',panel_plot_mode=True) # Let's do this reconstruction on the GPU, shall we? -model.to(device='cuda') -dataset.get_as(device='cuda') +if t.cuda.is_available(): + model.to(device='cuda') + dataset.get_as(device='cuda') # Note that the inspect step takes the vast majority of the time # The regularization is an L2 regularizer that empirically helps accelerate # convergence -for loss in model.LBFGS_optimize(30, dataset, lr=0.4, regularization_factor=[0.05,0.05]): +for loss in model.Adam_optimize(30, dataset, lr=0.4, regularization_factor=[0.05,0.05]): model.inspect(dataset) print(model.report()) diff --git a/src/cdtools/models/bragg_2d_ptycho.py b/src/cdtools/models/bragg_2d_ptycho.py index 507323d0..bd8b5803 100644 --- a/src/cdtools/models/bragg_2d_ptycho.py +++ b/src/cdtools/models/bragg_2d_ptycho.py @@ -80,6 +80,8 @@ def __init__( units='um', dtype=t.float32, obj_view_crop=0, + panel_plot_mode=False, + plot_level=1, ): # We need the detector geometry @@ -91,7 +93,8 @@ def __init__( # translation_offsets can stay 2D for now # propagate_probe and correct_tilt are important! - super(Bragg2DPtycho, self).__init__() + super(Bragg2DPtycho, self).__init__(panel_plot_mode=panel_plot_mode, + plot_level=plot_level) self.register_buffer('wavelength', t.as_tensor(wavelength, dtype=dtype)) self.store_detector_geometry(detector_geometry, @@ -258,7 +261,9 @@ def from_dataset( obj_padding=200, obj_view_crop=None, units='um', - surface_normal=None + surface_normal=None, + panel_plot_mode=False, + plot_level=1, ): wavelength = dataset.wavelength det_basis = dataset.detector_geometry['basis'] @@ -436,8 +441,8 @@ def from_dataset( return cls(wavelength, det_geo, obj_basis, probe, obj, min_translation=min_translation, probe_basis=probe_basis, - median_propagation =median_propagation, - translation_offsets = translation_offsets, + median_propagation=median_propagation, + translation_offsets=translation_offsets, weights=weights, mask=mask, background=background, translation_scale=translation_scale, saturation=saturation, @@ -448,6 +453,8 @@ def from_dataset( lens=lens, obj_view_crop=obj_view_crop, units=units, + panel_plot_mode=panel_plot_mode, + plot_level=plot_level, ) @@ -578,90 +585,90 @@ def corrected_translations(self,dataset): plot_list = [ - ('Basis Probe Fourier Space Amplitudes', - lambda self, fig: p.plot_amplitude(tools.propagators.inverse_far_field(self.probe), fig=fig)), - ('Basis Probe Fourier Space Phases', - lambda self, fig: p.plot_phase(tools.propagators.inverse_far_field(self.probe), fig=fig)), - ('Basis Probe Real Space Amplitudes, Surface Normal View', - lambda self, fig: p.plot_amplitude( + {'title': 'Basis Probe Fourier Space Amplitudes', + 'plot_func': lambda self, fig: p.plot_amplitude(tools.propagators.inverse_far_field(self.probe), fig=fig)}, + {'title': 'Basis Probe Fourier Space Phases', + 'plot_func': lambda self, fig: p.plot_phase(tools.propagators.inverse_far_field(self.probe), fig=fig)}, + {'title': 'Basis Probe Real Space Amplitudes, Surface Normal View', + 'plot_func': lambda self, fig: p.plot_amplitude( self.probe, fig=fig, basis=self.probe_basis, units=self.units, - )), - ('Basis Probe Real Space Phases, Surface Normal View', - lambda self, fig: p.plot_phase( + )}, + {'title': 'Basis Probe Real Space Phases, Surface Normal View', + 'plot_func': lambda self, fig: p.plot_phase( self.probe, fig=fig, basis=self.probe_basis, units=self.units, - )), - ('Basis Probe Real Space Amplitudes, Beam View', - lambda self, fig: p.plot_amplitude( + )}, + {'title': 'Basis Probe Real Space Amplitudes, Beam View', + 'plot_func': lambda self, fig: p.plot_amplitude( self.probe, fig=fig, basis=self.probe_basis, view_basis=beam_basis, units=self.units, - )), - ('Basis Probe Real Space Phases, Beam View', - lambda self, fig: p.plot_phase( + )}, + {'title': 'Basis Probe Real Space Phases, Beam View', + 'plot_func': lambda self, fig: p.plot_phase( self.probe, fig=fig, basis=self.probe_basis, view_basis=beam_basis, units=self.units, - )), - ('Object Amplitude, Surface Normal View', - lambda self, fig: p.plot_amplitude( + )}, + {'title': 'Object Amplitude, Surface Normal View', + 'plot_func': lambda self, fig: p.plot_amplitude( self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units, - )), - ('Object Phase, Surface Normal View', - lambda self, fig: p.plot_phase( + )}, + {'title': 'Object Phase, Surface Normal View', + 'plot_func': lambda self, fig: p.plot_phase( self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units, - )), - ('Object Amplitude, Beam View', - lambda self, fig: p.plot_amplitude( + )}, + {'title': 'Object Amplitude, Beam View', + 'plot_func': lambda self, fig: p.plot_amplitude( self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, view_basis=beam_basis, units=self.units, - )), - ('Object Phase, Beam View', - lambda self, fig: p.plot_phase( + )}, + {'title': 'Object Phase, Beam View', + 'plot_func': lambda self, fig: p.plot_phase( self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, view_basis=beam_basis, units=self.units, - )), - ('Object Amplitude, Detector View', - lambda self, fig: p.plot_amplitude( + )}, + {'title': 'Object Amplitude, Detector View', + 'plot_func': lambda self, fig: p.plot_amplitude( self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, view_basis=self.det_basis, units=self.units, - )), - ('Object Phase, Detector View', - lambda self, fig: p.plot_phase( + )}, + {'title': 'Object Phase, Detector View', + 'plot_func': lambda self, fig: p.plot_phase( self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, view_basis=self.det_basis, units=self.units, - )), - ('Corrected Translations', - lambda self, fig, dataset: p.plot_translations(self.corrected_translations(dataset), fig=fig, units=self.units)), - ('Background', - lambda self, fig: plt.figure(fig.number) and plt.imshow(self.background.detach().cpu().numpy()**2)) + )}, + {'title': 'Corrected Translations', + 'plot_func': lambda self, fig, dataset: p.plot_translations(self.corrected_translations(dataset), fig=fig, units=self.units)}, + {'title': 'Background', + 'plot_func': lambda self, fig: plt.figure(fig.number) and plt.imshow(self.background.detach().cpu().numpy()**2)}, ] diff --git a/src/cdtools/models/multislice_2d_ptycho.py b/src/cdtools/models/multislice_2d_ptycho.py index 1d6b3cd2..d6663be0 100644 --- a/src/cdtools/models/multislice_2d_ptycho.py +++ b/src/cdtools/models/multislice_2d_ptycho.py @@ -47,9 +47,12 @@ def __init__(self, prevent_aliasing=True, phase_only=False, units='um', + panel_plot_mode=False, + plot_level=1, ): - super(Multislice2DPtycho, self).__init__() + super(Multislice2DPtycho, self).__init__(panel_plot_mode=panel_plot_mode, + plot_level=plot_level) self.wavelength = t.tensor(wavelength) self.detector_geometry = copy(detector_geometry) self.dz = dz @@ -154,7 +157,7 @@ def __init__(self, @classmethod - def from_dataset(cls, dataset, dz, nz, probe_convergence_semiangle, padding=0, n_modes=1, dm_rank=None, translation_scale=1, saturation=None, propagation_distance=None, scattering_mode=None, oversampling=1, auto_center=True, bandlimit=None, replicate_slice=False, subpixel=True, exponentiate_obj=True, units='um', fourier_probe=False, phase_only=False, prevent_aliasing=True, probe_support_radius=None): + def from_dataset(cls, dataset, dz, nz, probe_convergence_semiangle, padding=0, n_modes=1, dm_rank=None, translation_scale=1, saturation=None, propagation_distance=None, scattering_mode=None, oversampling=1, auto_center=True, bandlimit=None, replicate_slice=False, subpixel=True, exponentiate_obj=True, units='um', fourier_probe=False, phase_only=False, prevent_aliasing=True, probe_support_radius=None, panel_plot_mode=False, plot_level=1): wavelength = dataset.wavelength det_basis = dataset.detector_geometry['basis'] @@ -286,7 +289,7 @@ def from_dataset(cls, dataset, dz, nz, probe_convergence_semiangle, padding=0, n surface_normal=surface_normal, probe_support=probe_support, min_translation=min_translation, - translation_offsets = translation_offsets, + translation_offsets=translation_offsets, weights=Ws, mask=mask, background=background, translation_scale=translation_scale, saturation=saturation, @@ -296,7 +299,9 @@ def from_dataset(cls, dataset, dz, nz, probe_convergence_semiangle, padding=0, n exponentiate_obj=exponentiate_obj, units=units, fourier_probe=fourier_probe, phase_only=phase_only, - prevent_aliasing=prevent_aliasing) + prevent_aliasing=prevent_aliasing, + panel_plot_mode=panel_plot_mode, + plot_level=plot_level) def interaction(self, index, translations): @@ -581,22 +586,22 @@ def tidy_probes(self): # Needs to be updated to allow for plotting to an existing figure - plot_list = [ - ('Probe Fourier Space Amplitude', - lambda self, fig: p.plot_amplitude(self.probe if self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig)), - ('Probe Fourier Space Phase', - lambda self, fig: p.plot_phase(self.probe if self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig)), - ('Probe Real Space Amplitude', - lambda self, fig: p.plot_amplitude(self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig, basis=self.probe_basis, units=self.units)), - ('Probe Real Space Phase', - lambda self, fig: p.plot_phase(self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig, basis=self.probe_basis, units=self.units)), - ('Average Weight Matrix Amplitudes', - lambda self, fig: p.plot_amplitude( + plot_list = [ + {'title': 'Probe Fourier Space Amplitude', + 'plot_func': lambda self, fig: p.plot_amplitude(self.probe if self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig)}, + {'title': 'Probe Fourier Space Phase', + 'plot_func': lambda self, fig: p.plot_phase(self.probe if self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig)}, + {'title': 'Probe Real Space Amplitude', + 'plot_func': lambda self, fig: p.plot_amplitude(self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig, basis=self.probe_basis, units=self.units)}, + {'title': 'Probe Real Space Phase', + 'plot_func': lambda self, fig: p.plot_phase(self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe), fig=fig, basis=self.probe_basis, units=self.units)}, + {'title': 'Average Weight Matrix Amplitudes', + 'plot_func': lambda self, fig: p.plot_amplitude( np.nanmean(np.abs(self.weights.data.cpu().numpy()), axis=0), fig=fig), - lambda self: len(self.weights.shape) >= 2), - ('% of Power in Top Mode', - lambda self, fig, dataset: p.plot_nanomap( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': '% of Power in Top Mode', + 'plot_func': lambda self, fig, dataset: p.plot_nanomap( self.corrected_translations(dataset), 100 * t.stack([ analysis.calc_mode_power_fractions( @@ -606,35 +611,35 @@ def tidy_probes(self): ], dim=0), fig=fig, units=self.units), - lambda self: len(self.weights.shape) >= 2), - ('Slice by Slice Real Part of T', - lambda self, fig: p.plot_real(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units, cmap='cividis'), - lambda self: self.exponentiate_obj), - ('Slice by Slice Imaginary Part of T', - lambda self, fig: p.plot_imag(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units), - lambda self: self.exponentiate_obj), - ('Integrated Real Part of T', - lambda self, fig: p.plot_real(t.sum(self.obj.detach().cpu(),dim=0), fig=fig, basis=self.probe_basis, units=self.units, cmap='cividis'), - lambda self: (self.exponentiate_obj) and self.obj.dim() >= 3), - ('Integrated Imaginary Part of T', - lambda self, fig: p.plot_imag(t.sum(self.obj.detach().cpu(),dim=0), fig=fig, basis=self.probe_basis, units=self.units), - lambda self: (self.exponentiate_obj) and self.obj.dim() >= 3), - ('Slice by Slice Amplitude of Object Function', - lambda self, fig: p.plot_amplitude(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units), - lambda self: not self.exponentiate_obj), - ('Slice by Slice Phase of Object Function', - lambda self, fig: p.plot_phase(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units,cmap='cividis'), - lambda self: not self.exponentiate_obj), - ('Amplitude of Stacked Object Function', - lambda self, fig: p.plot_amplitude(reduce(t.mul, self.obj.detach().cpu()), fig=fig, basis=self.probe_basis, units=self.units), - lambda self: (not self.exponentiate_obj) and self.obj.dim() >=3), - ('Phase of Stacked Object Function', - lambda self, fig: p.plot_phase(reduce(t.mul, self.obj.detach().cpu()), fig=fig, basis=self.probe_basis, units=self.units, cmap='cividis'), - lambda self: (not self.exponentiate_obj) and self.obj.dim() >= 3), - ('Corrected Translations', - lambda self, fig, dataset: p.plot_translations(self.corrected_translations(dataset), fig=fig, units=self.units)), - ('Background', - lambda self, fig: plt.figure(fig.number) and plt.imshow(self.background.detach().cpu().numpy()**2)) + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': 'Slice by Slice Real Part of T', + 'plot_func': lambda self, fig: p.plot_real(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units, cmap='cividis'), + 'condition': lambda self: self.exponentiate_obj}, + {'title': 'Slice by Slice Imaginary Part of T', + 'plot_func': lambda self, fig: p.plot_imag(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units), + 'condition': lambda self: self.exponentiate_obj}, + {'title': 'Integrated Real Part of T', + 'plot_func': lambda self, fig: p.plot_real(t.sum(self.obj.detach().cpu(),dim=0), fig=fig, basis=self.probe_basis, units=self.units, cmap='cividis'), + 'condition': lambda self: self.exponentiate_obj and self.obj.dim() >= 3}, + {'title': 'Integrated Imaginary Part of T', + 'plot_func': lambda self, fig: p.plot_imag(t.sum(self.obj.detach().cpu(),dim=0), fig=fig, basis=self.probe_basis, units=self.units), + 'condition': lambda self: self.exponentiate_obj and self.obj.dim() >= 3}, + {'title': 'Slice by Slice Amplitude of Object Function', + 'plot_func': lambda self, fig: p.plot_amplitude(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units), + 'condition': lambda self: not self.exponentiate_obj}, + {'title': 'Slice by Slice Phase of Object Function', + 'plot_func': lambda self, fig: p.plot_phase(self.obj.detach().cpu(), fig=fig, basis=self.probe_basis, units=self.units, cmap='cividis'), + 'condition': lambda self: not self.exponentiate_obj}, + {'title': 'Amplitude of Stacked Object Function', + 'plot_func': lambda self, fig: p.plot_amplitude(reduce(t.mul, self.obj.detach().cpu()), fig=fig, basis=self.probe_basis, units=self.units), + 'condition': lambda self: (not self.exponentiate_obj) and self.obj.dim() >= 3}, + {'title': 'Phase of Stacked Object Function', + 'plot_func': lambda self, fig: p.plot_phase(reduce(t.mul, self.obj.detach().cpu()), fig=fig, basis=self.probe_basis, units=self.units, cmap='cividis'), + 'condition': lambda self: (not self.exponentiate_obj) and self.obj.dim() >= 3}, + {'title': 'Corrected Translations', + 'plot_func': lambda self, fig, dataset: p.plot_translations(self.corrected_translations(dataset), fig=fig, units=self.units)}, + {'title': 'Background', + 'plot_func': lambda self, fig: plt.figure(fig.number) and plt.imshow(self.background.detach().cpu().numpy()**2)}, ] diff --git a/src/cdtools/models/multislice_ptycho.py b/src/cdtools/models/multislice_ptycho.py index afcd83e2..cec773cc 100644 --- a/src/cdtools/models/multislice_ptycho.py +++ b/src/cdtools/models/multislice_ptycho.py @@ -39,10 +39,13 @@ def __init__(self, simulate_finite_pixels=False, dtype=t.float32, exponentiate_obj=False, - obj_view_crop=0 + obj_view_crop=0, + panel_plot_mode=False, + plot_level=1, ): - super(MultislicePtycho, self).__init__() + super(MultislicePtycho, self).__init__(panel_plot_mode=panel_plot_mode, + plot_level=plot_level) self.register_buffer('wavelength', t.as_tensor(wavelength, dtype=dtype)) self.store_detector_geometry(detector_geometry, @@ -202,6 +205,8 @@ def from_dataset(cls, obj_view_crop=None, obj_padding=200, exponentiate_obj=False, + panel_plot_mode=False, + plot_level=1, ): wavelength = dataset.wavelength @@ -409,6 +414,8 @@ def from_dataset(cls, simulate_finite_pixels=simulate_finite_pixels, exponentiate_obj=exponentiate_obj, obj_view_crop=obj_view_crop, + panel_plot_mode=panel_plot_mode, + plot_level=plot_level, ) @@ -750,61 +757,61 @@ def get_probes(idx): plot_list = [ - ('', - lambda self, fig, dataset: self.plot_wavefront_variation( + {'title': '', + 'plot_func': lambda self, fig, dataset: self.plot_wavefront_variation( dataset, fig=fig, mode='root_sum_intensity', image_title='Root Summed Probe Intensities', image_colorbar_title='Square Root of Intensity'), - lambda self: len(self.weights.shape) >= 2), - ('', - lambda self, fig, dataset: self.plot_wavefront_variation( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': '', + 'plot_func': lambda self, fig, dataset: self.plot_wavefront_variation( dataset, fig=fig, mode='amplitude', image_title='Probe Amplitudes (scroll to view modes)', image_colorbar_title='Probe Amplitude'), - lambda self: len(self.weights.shape) >= 2), - ('', - lambda self, fig, dataset: self.plot_wavefront_variation( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': '', + 'plot_func': lambda self, fig, dataset: self.plot_wavefront_variation( dataset, fig=fig, mode='phase', image_title='Probe Phases (scroll to view modes)', image_colorbar_title='Probe Phase'), - lambda self: len(self.weights.shape) >= 2), - ('Basis Probe Fourier Space Amplitudes', - lambda self, fig: p.plot_amplitude( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': 'Basis Probe Fourier Space Amplitudes', + 'plot_func': lambda self, fig: p.plot_amplitude( (self.probe if self.fourier_probe else tools.propagators.inverse_far_field(self.probe)), - fig=fig)), - ('Basis Probe Fourier Space Phases', - lambda self, fig: p.plot_phase( + fig=fig)}, + {'title': 'Basis Probe Fourier Space Phases', + 'plot_func': lambda self, fig: p.plot_phase( (self.probe if self.fourier_probe - else tools.propagators.inverse_far_field(self.probe)) - , fig=fig)), - ('Basis Probe Real Space Amplitudes', - lambda self, fig: p.plot_amplitude( + else tools.propagators.inverse_far_field(self.probe)), + fig=fig)}, + {'title': 'Basis Probe Real Space Amplitudes', + 'plot_func': lambda self, fig: p.plot_amplitude( (self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe)), fig=fig, basis=self.probe_basis, - units=self.units)), - ('Basis Probe Real Space Phases', - lambda self, fig: p.plot_phase( + units=self.units)}, + {'title': 'Basis Probe Real Space Phases', + 'plot_func': lambda self, fig: p.plot_phase( (self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe)), fig=fig, basis=self.probe_basis, - units=self.units)), - ('Average Weight Matrix Amplitudes', - lambda self, fig: p.plot_amplitude( + units=self.units)}, + {'title': 'Average Weight Matrix Amplitudes', + 'plot_func': lambda self, fig: p.plot_amplitude( np.nanmean(np.abs(self.weights.data.cpu().numpy()), axis=0), fig=fig), - lambda self: len(self.weights.shape) >= 2), - ('% of Power in Top Mode', - lambda self, fig, dataset: p.plot_nanomap( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': '% of Power in Top Mode', + 'plot_func': lambda self, fig, dataset: p.plot_nanomap( self.corrected_translations(dataset), 100 * t.stack([ analysis.calc_mode_power_fractions( @@ -814,69 +821,69 @@ def get_probes(idx): ], dim=0), fig=fig, units=self.units), - lambda self: len(self.weights.shape) >= 2), - ('Object Amplitude', - lambda self, fig: p.plot_amplitude( + 'condition': lambda self: len(self.weights.shape) >= 2}, + {'title': 'Object Amplitude', + 'plot_func': lambda self, fig: p.plot_amplitude( self.obj[(np.s_[:],) + self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units), - lambda self: not self.exponentiate_obj), - ('Object (T) Imaginary Part', - lambda self, fig: p.plot_imag( + 'condition': lambda self: not self.exponentiate_obj}, + {'title': 'Object (T) Imaginary Part', + 'plot_func': lambda self, fig: p.plot_imag( self.obj[(np.s_[:],) + self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units), - lambda self: self.exponentiate_obj), - ('Object Phase', - lambda self, fig: p.plot_phase( + 'condition': lambda self: self.exponentiate_obj}, + {'title': 'Object Phase', + 'plot_func': lambda self, fig: p.plot_phase( self.obj[(np.s_[:],) + self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units), - lambda self: not self.exponentiate_obj), - ('Object (T) Real Part', - lambda self, fig: p.plot_real( + 'condition': lambda self: not self.exponentiate_obj}, + {'title': 'Object (T) Real Part', + 'plot_func': lambda self, fig: p.plot_real( self.obj[(np.s_[:],) + self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units, cmap='cividis'), - lambda self: self.exponentiate_obj), - ('Object Product Amplitude', - lambda self, fig: p.plot_amplitude( + 'condition': lambda self: self.exponentiate_obj}, + {'title': 'Object Product Amplitude', + 'plot_func': lambda self, fig: p.plot_amplitude( t.prod(self.obj, dim=0)[self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units), - lambda self: not self.exponentiate_obj), - ('Object (T) Sum Imaginary Part', - lambda self, fig: p.plot_imag( + 'condition': lambda self: not self.exponentiate_obj}, + {'title': 'Object (T) Sum Imaginary Part', + 'plot_func': lambda self, fig: p.plot_imag( t.sum(self.obj, dim=0)[self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units), - lambda self: self.exponentiate_obj), - ('Object Product Phase', - lambda self, fig: p.plot_phase( + 'condition': lambda self: self.exponentiate_obj}, + {'title': 'Object Product Phase', + 'plot_func': lambda self, fig: p.plot_phase( t.prod(self.obj, dim=0)[self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units), - lambda self: not self.exponentiate_obj), - ('Object (T) Sum Real Part', - lambda self, fig: p.plot_real( + 'condition': lambda self: not self.exponentiate_obj}, + {'title': 'Object (T) Sum Real Part', + 'plot_func': lambda self, fig: p.plot_real( t.sum(self.obj, dim=0)[self.obj_view_slice], fig=fig, basis=self.obj_basis, units=self.units, cmap='cividis'), - lambda self: self.exponentiate_obj), - ('Corrected Translations', - lambda self, fig, dataset: p.plot_translations(self.corrected_translations(dataset), fig=fig, units=self.units)), - ('Background', - lambda self, fig: p.plot_amplitude(self.background**2, fig=fig)) + 'condition': lambda self: self.exponentiate_obj}, + {'title': 'Corrected Translations', + 'plot_func': lambda self, fig, dataset: p.plot_translations(self.corrected_translations(dataset), fig=fig, units=self.units)}, + {'title': 'Background', + 'plot_func': lambda self, fig: p.plot_amplitude(self.background**2, fig=fig)}, ] diff --git a/src/cdtools/models/rpi.py b/src/cdtools/models/rpi.py index 339b6236..4dadb166 100644 --- a/src/cdtools/models/rpi.py +++ b/src/cdtools/models/rpi.py @@ -58,9 +58,12 @@ def __init__( propagation_distance=0, units='um', dtype=t.float32, + panel_plot_mode=False, + plot_level=1, ): - - super(RPI, self).__init__() + + super(RPI, self).__init__(panel_plot_mode=panel_plot_mode, + plot_level=plot_level) complex_dtype = (t.ones([1], dtype=dtype) + 1j * t.ones([1], dtype=dtype)).dtype @@ -164,6 +167,8 @@ def from_dataset( phase_only=False, probe_threshold=0, dtype=t.float32, + panel_plot_mode=False, + plot_level=1, ): complex_dtype = (t.ones([1], dtype=dtype) + 1j * t.ones([1], dtype=dtype)).dtype @@ -227,7 +232,7 @@ def from_dataset( # This will be superceded later by a call to init_obj, but it sets # the shape if obj_size is None: - obj_size = (np.array(self.probe.shape[-2:]) // 2).astype(int) + obj_size = (np.array(probe.shape[-2:]) // 2).astype(int) dummy_init_obj = t.ones([n_modes, obj_size[0], obj_size[1]], dtype=complex_dtype) @@ -247,14 +252,16 @@ def from_dataset( obj_support = t.as_tensor(binary_dilation(obj_support)) rpi_object = cls(wavelength, det_geo, ew_basis, - probe, dummy_init_obj, + probe, dummy_init_obj, background=background, mask=mask, saturation=saturation, obj_support=obj_support, oversampling=oversampling, exponentiate_obj=exponentiate_obj, phase_only=phase_only, - weight_matrix=weight_matrix) + weight_matrix=weight_matrix, + panel_plot_mode=panel_plot_mode, + plot_level=plot_level) # I don't love this pattern, where I do the "real" obj initialization # after creating the rpi object. But, I chose this so that I could @@ -283,7 +290,9 @@ def from_calibration( exponentiate_obj=False, phase_only=False, initialization='random', - dtype=t.float32 + dtype=t.float32, + panel_plot_mode=False, + plot_level=1, ): complex_dtype = (t.ones([1], dtype=dtype) + @@ -308,7 +317,7 @@ def from_calibration( # This will be superceded later by a call to init_obj, but it sets # the shape if obj_size is None: - obj_size = (np.array(self.probe.shape[-2:]) // 2).astype(int) + obj_size = (np.array(probe.shape[-2:]) // 2).astype(int) dummy_init_obj = t.ones([n_modes, obj_size[0], obj_size[1]], dtype=complex_dtype) @@ -327,6 +336,8 @@ def from_calibration( mask=mask, exponentiate_obj=exponentiate_obj, phase_only=phase_only, + panel_plot_mode=panel_plot_mode, + plot_level=plot_level, ) rpi_object.init_obj(initialization) @@ -365,7 +376,7 @@ def init_obj( obj_shape=obj_shape, n_modes=n_modes) else: - raise KeyError('Initialization "' + str(initialization) + \ + raise KeyError('Initialization "' + str(initialization_type) + \ '" invalid - use "spectral", "uniform", or "random"') @@ -531,40 +542,72 @@ def regularizer(self, factors): def sim_to_dataset(self, args_list): raise NotImplementedError('No sim to dataset yet, sorry!') - plot_list = [ - ('Root Sum Squared Amplitude of all Probes', - lambda self, fig: p.plot_amplitude( - np.sqrt(np.sum((t.abs(t.sum(self.weights[..., None, None].detach() * self.probe, axis=-3))**2).cpu().numpy(),axis=0)), - fig=fig, basis=self.probe_basis)), - ('Object Amplitude', - lambda self, fig: p.plot_amplitude( - self.obj, - fig=fig, - basis=self.obj_basis, - units=self.units), - lambda self: not self.exponentiate_obj), - ('Object Phase', - lambda self, fig: p.plot_phase( - self.obj, - fig=fig, - basis=self.obj_basis, - units=self.units), - lambda self: not self.exponentiate_obj), - ('Real Part of T', - lambda self, fig: p.plot_real( - self.obj, - fig=fig, - basis=self.obj_basis, - units=self.units, - cmap='cividis'), - lambda self: self.exponentiate_obj), - ('Imaginary Part of T', - lambda self, fig: p.plot_imag( - self.obj, - fig=fig, - basis=self.obj_basis, - units=self.units), - lambda self: self.exponentiate_obj), + plot_panel_list = [ + { + 'title': 'RPI Results', + 'plot_level': 1, + 'figure_size': (12, 3.5), + 'grid': (1, 3), + 'plots': [ + { + 'title': 'Object Phase', + 'subplot': (0, 0), + 'plot_func': lambda self, fig: p.plot_phase( + self.obj, + fig=fig, + title='Object Phase', + basis=self.obj_basis, + units=self.units), + 'condition': lambda self: not self.exponentiate_obj, + }, + { + 'title': 'Real Part of T', + 'subplot': (0, 0), + 'plot_func': lambda self, fig: p.plot_real( + self.obj, + fig=fig, + title='Real Part of T', + basis=self.obj_basis, + units=self.units, + cmap='cividis'), + 'condition': lambda self: self.exponentiate_obj, + }, + { + 'title': 'Object Amplitude', + 'subplot': (0, 1), + 'plot_func': lambda self, fig: p.plot_amplitude( + self.obj, + fig=fig, + title='Object Amplitude', + basis=self.obj_basis, + units=self.units), + 'condition': lambda self: not self.exponentiate_obj, + }, + { + 'title': 'Imaginary Part of T', + 'subplot': (0, 1), + 'plot_func': lambda self, fig: p.plot_imag( + self.obj, + fig=fig, + title='Imaginary Part of T', + basis=self.obj_basis, + units=self.units), + 'condition': lambda self: self.exponentiate_obj, + }, + { + 'title': 'Root Sum Squared Amplitude of all Probes', + 'subplot': (0, 2), + 'plot_func': lambda self, fig: p.plot_amplitude( + np.sqrt(np.sum( + (t.abs(t.sum(self.weights[..., None, None].detach() + * self.probe, axis=-3))**2 + ).cpu().numpy(), axis=0)), + fig=fig, + basis=self.probe_basis, + units=self.units), + }, + ], + }, ] From f35141dcc839e7f0e24003ab45ab88f2f8abee67 Mon Sep 17 00:00:00 2001 From: Abe Levitan Date: Fri, 20 Mar 2026 19:39:47 +0100 Subject: [PATCH 07/23] Check that all the models work (all but Multislice2DPtycho, which was already broken, and make necessary tweaks --- examples/fancy_ptycho.py | 3 +- examples/transmission_RPI.py | 2 +- src/cdtools/models/fancy_ptycho.py | 96 +++++++++++++------------- src/cdtools/models/rpi.py | 6 +- src/cdtools/tools/plotting/plotting.py | 4 +- 5 files changed, 54 insertions(+), 57 deletions(-) diff --git a/examples/fancy_ptycho.py b/examples/fancy_ptycho.py index 8e3a9d14..3ac89748 100644 --- a/examples/fancy_ptycho.py +++ b/examples/fancy_ptycho.py @@ -14,7 +14,6 @@ propagation_distance=5e-3, # Propagate the initial probe guess by 5 mm units='mm', # Set the units for the live plots obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix, - panel_plot_mode=True, # Organizes the live plots into panels ) if t.cuda.is_available(): @@ -34,7 +33,7 @@ for loss in recon.optimize(50, lr=0.02, batch_size=10): print(model.report()) # Plotting is expensive, so we only do it every tenth epoch - if model.epoch % 2 == 0: + if model.epoch % 10 == 0: model.inspect(dataset) # It's common to chain several different reconstruction loops. Here, we diff --git a/examples/transmission_RPI.py b/examples/transmission_RPI.py index 7b8d65a8..c8fd424e 100644 --- a/examples/transmission_RPI.py +++ b/examples/transmission_RPI.py @@ -19,7 +19,7 @@ # Note that we explicitly as for two incoherent probe modes model = cdtools.models.RPI.from_dataset(dataset, probe, [500,500], background=background, n_modes=2, - initialization='random',panel_plot_mode=True) + initialization='random') # Let's do this reconstruction on the GPU, shall we? diff --git a/src/cdtools/models/fancy_ptycho.py b/src/cdtools/models/fancy_ptycho.py index 9ec668bc..44448b36 100644 --- a/src/cdtools/models/fancy_ptycho.py +++ b/src/cdtools/models/fancy_ptycho.py @@ -45,7 +45,7 @@ def __init__(self, near_field=False, angular_spectrum_propagator=None, inv_angular_spectrum_propagator=None, - panel_plot_mode=False, + panel_plot_mode=True, plot_level=2, ): @@ -255,7 +255,7 @@ def from_dataset(cls, obj_view_crop=None, obj_padding=200, near_field=False, - panel_plot_mode=False, + panel_plot_mode=True, plot_level=2, ): @@ -846,6 +846,44 @@ def tidy_probes(self): # We discard the U matrix and re-multiply S & Vh self.weights.data = S[:,:,None] * (Vh / probe_sqrt_intensities) + + def get_probe_intensities(self): + if not hasattr(self, 'weights'): + raise NotImplementedError( + "I don't know how to handle having no weights") + elif self.weights.ndim == 1: + probe_intensities = self.weights.detach().cpu().numpy()**2 + else: + # The big case, with OPRP + probe_matrix = np.zeros([self.probe.shape[0]]*2, + dtype=np.complex64) + np_probes = self.probe.detach().cpu().numpy() + for i in range(probe_matrix.shape[0]): + for j in range(probe_matrix.shape[0]): + probe_matrix[i,j] = np.sum(np_probes[i]*np_probes[j].conj()) + + weights = self.weights.detach().cpu().numpy() + + # The outer one is a sum, because the tensordot is what broadcasts + # the probe matrix along the shot dimension - the second one + # doesn't have to. + weighted_probe_matrices = np.sum( + np.tensordot(weights, probe_matrix, axes=1)[...,None] + * weights.conj().transpose((0,2,1))[...,None,:,:], + axis=-2 + ) + + basis_probe_intensities = np.trace( + probe_matrix, axis1=-2, axis2=-1) + probe_intensities = np.trace( + weighted_probe_matrices, axis1=-2, axis2=-1) + + # Imaginary part is already essentially zero up to rounding error + probe_intensities = np.real( + probe_intensities / basis_probe_intensities) + + return probe_intensities + def plot_wavefront_variation(self, dataset, fig=None, mode='amplitude', **kwargs): def get_probes(idx): @@ -862,22 +900,8 @@ def get_probes(idx): if mode.lower() == 'phase': return np.angle(ortho_probes.detach().cpu().numpy()) - probe_matrix = np.zeros([self.probe.shape[0]]*2, - dtype=np.complex64) - np_probes = self.probe.detach().cpu().numpy() - for i in range(probe_matrix.shape[0]): - for j in range(probe_matrix.shape[0]): - probe_matrix[i,j] = np.sum(np_probes[i]*np_probes[j].conj()) - - weights = self.weights.detach().cpu().numpy() - - probe_intensities = np.sum(np.tensordot(weights, probe_matrix, axes=1) - * weights.conj(), axis=2) - - # Imaginary part is already essentially zero up to rounding error - probe_intensities = np.real(probe_intensities) - - values = np.sum(probe_intensities, axis=1) + values = self.get_probe_intensities() + if mode.lower() == 'amplitude' or mode.lower() == 'root_sum_intensity': cmap = 'viridis' else: @@ -896,35 +920,9 @@ def get_probes(idx): def plot_illumination_intensity(self, fig, dataset): - if not hasattr(self, 'weights'): - raise NotImplementedError("I don't know how to handle having no weights") - elif self.weights.ndim == 1: - probe_intensities = self.weights.detach().cpu().numpy()**2 - else: - # The big case, with OPRP - probe_matrix = np.zeros([self.probe.shape[0]]*2, - dtype=np.complex64) - np_probes = self.probe.detach().cpu().numpy() - for i in range(probe_matrix.shape[0]): - for j in range(probe_matrix.shape[0]): - probe_matrix[i,j] = np.sum(np_probes[i]*np_probes[j].conj()) - - weights = self.weights.detach().cpu().numpy() - - # The outer one is a sum, because the tensordot is what broadcasts the - # probe matrix along the shot dimension - the second one doesn't have to. - weighted_probe_matrices = np.sum(np.tensordot(weights, probe_matrix, axes=1)[...,None] - * weights.conj().transpose((0,2,1))[...,None,:,:], axis=-2) - - basis_probe_intensities = np.trace(probe_matrix, axis1=-2, axis2=-1) - probe_intensities = np.trace(weighted_probe_matrices, axis1=-2, axis2=-1) - - # Imaginary part is already essentially zero up to rounding error - probe_intensities = np.real(probe_intensities / basis_probe_intensities) - p.plot_nanomap( self.corrected_translations(dataset), - probe_intensities, + self.get_probe_intensities(), fig=fig, cmap='magma', cmap_label='Intensity (a.u.)', @@ -1011,7 +1009,7 @@ def plot_translations_and_originals(self, fig, dataset): (self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe)), fig=fig, - title='Basis Probe', + title='Basis Probes', basis=self.probe_basis, units=self.units), }, @@ -1022,7 +1020,7 @@ def plot_translations_and_originals(self, fig, dataset): (self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe)), fig=fig, - title='Basis Probe', + title='Basis Probes', basis=self.probe_basis, units=self.units), }, @@ -1041,7 +1039,7 @@ def plot_translations_and_originals(self, fig, dataset): (self.probe if self.fourier_probe else tools.propagators.far_field(self.probe)), fig=fig, - title='Basis Probe, Fourier', + title='Basis Probes, Fourier', ), }, { @@ -1051,7 +1049,7 @@ def plot_translations_and_originals(self, fig, dataset): (self.probe if self.fourier_probe else tools.propagators.far_field(self.probe)), fig=fig, - title='Basis Probe, Fourier', + title='Basis Probes, Fourier', ), }, { diff --git a/src/cdtools/models/rpi.py b/src/cdtools/models/rpi.py index 4dadb166..0cf0b688 100644 --- a/src/cdtools/models/rpi.py +++ b/src/cdtools/models/rpi.py @@ -58,7 +58,7 @@ def __init__( propagation_distance=0, units='um', dtype=t.float32, - panel_plot_mode=False, + panel_plot_mode=True, plot_level=1, ): @@ -167,7 +167,7 @@ def from_dataset( phase_only=False, probe_threshold=0, dtype=t.float32, - panel_plot_mode=False, + panel_plot_mode=True, plot_level=1, ): complex_dtype = (t.ones([1], dtype=dtype) + @@ -291,7 +291,7 @@ def from_calibration( phase_only=False, initialization='random', dtype=t.float32, - panel_plot_mode=False, + panel_plot_mode=True, plot_level=1, ): diff --git a/src/cdtools/tools/plotting/plotting.py b/src/cdtools/tools/plotting/plotting.py index 53808717..362489c2 100644 --- a/src/cdtools/tools/plotting/plotting.py +++ b/src/cdtools/tools/plotting/plotting.py @@ -202,7 +202,7 @@ def make_plot(idx): fig.canvas.toolbar.home() fig.canvas.toolbar.update() - if len(im.shape) >= 3: + if num_images > 1: base = title if title is not None else '('.join(ax_title.split('(')[:-1])[:-1] fig.axes[0].set_title(base + f' ({fig.plot_idx+1} of {num_images})') @@ -295,7 +295,7 @@ def make_plot(idx): if title is not None: ax.set_title(ax_title) - if len(im.shape) >= 3: + if num_images >= 3: ax.set_title(ax_title + f' ({fig.plot_idx+1} of {num_images})') if fig.canvas.toolbar is not None: From bef907c0322475f17eb454410d0acea00b3ed0bc Mon Sep 17 00:00:00 2001 From: allevitan Date: Fri, 20 Mar 2026 20:01:40 +0100 Subject: [PATCH 08/23] Added a minimum plotting interval to model.inspect(dataset) to make it easier to not plot so much without manual intervention --- examples/fancy_ptycho.py | 12 ++++++------ examples/gold_ball_ptycho.py | 11 ++++------- examples/near_field_ptycho.py | 9 +++------ examples/simple_ptycho.py | 6 ++---- examples/transmission_RPI.py | 12 ++++++------ src/cdtools/models/base.py | 14 +++++++++++++- 6 files changed, 34 insertions(+), 30 deletions(-) diff --git a/examples/fancy_ptycho.py b/examples/fancy_ptycho.py index 3ac89748..eb9de73d 100644 --- a/examples/fancy_ptycho.py +++ b/examples/fancy_ptycho.py @@ -32,9 +32,9 @@ # The batch size sets the minibatch size for loss in recon.optimize(50, lr=0.02, batch_size=10): print(model.report()) - # Plotting is expensive, so we only do it every tenth epoch - if model.epoch % 10 == 0: - model.inspect(dataset) + # Because plotting can be expensive, setting a minimum plotting interval + # (in seconds) can avoid excessive replots. + model.inspect(dataset, min_interval=5) # It's common to chain several different reconstruction loops. Here, we # started with an aggressive refinement to find the probe in the previous @@ -42,12 +42,12 @@ # and larger minibatch for loss in recon.optimize(50, lr=0.005, batch_size=50): print(model.report()) - if model.epoch % 10 == 0: - model.inspect(dataset) + model.inspect(dataset, min_interval=5) # This orthogonalizes the recovered probe modes model.tidy_probes() -model.inspect(dataset) +# Setting replot_all will reopen any windows which were closed earlier +model.inspect(dataset, replot_all=True) model.compare(dataset) plt.show() diff --git a/examples/gold_ball_ptycho.py b/examples/gold_ball_ptycho.py index fd7ba32a..944c760b 100644 --- a/examples/gold_ball_ptycho.py +++ b/examples/gold_ball_ptycho.py @@ -54,13 +54,11 @@ for loss in recon.optimize(20, lr=0.005, batch_size=50): print(model.report()) - if model.epoch % 10 == 0: - model.inspect(dataset) + model.inspect(dataset, min_interval=5) for loss in recon.optimize(50, lr=0.002, batch_size=100): print(model.report()) - if model.epoch % 10 == 0: - model.inspect(dataset) + model.inspect(dataset, min_interval=5) # We can often reset our guess of the probe positions once we have a # good guess of probe and object, but in this case it causes the @@ -71,8 +69,7 @@ # the loss fails to improve after 10 epochs for loss in recon.optimize(100, lr=0.001, batch_size=100, schedule=True): print(model.report()) - if model.epoch % 10 == 0: - model.inspect(dataset) + model.inspect(dataset, min_interval=5) model.tidy_probes() @@ -80,6 +77,6 @@ # This saves the final result model.save_to_h5('example_reconstructions/gold_balls.h5', dataset) -model.inspect(dataset) +model.inspect(dataset, replot_all=True) model.compare(dataset) plt.show() diff --git a/examples/near_field_ptycho.py b/examples/near_field_ptycho.py index 7b86bd3e..4534ee58 100644 --- a/examples/near_field_ptycho.py +++ b/examples/near_field_ptycho.py @@ -40,18 +40,15 @@ for loss in recon.optimize(100, lr=0.04, batch_size=10): print(model.report()) - # Plotting is expensive, so we only do it every tenth epoch - if model.epoch % 10 == 0: - model.inspect(dataset) + model.inspect(dataset, min_interval=5) for loss in recon.optimize(50, lr=0.005, batch_size=50): print(model.report()) - if model.epoch % 10 == 0: - model.inspect(dataset) + model.inspect(dataset, min_interval=5) # This orthogonalizes the recovered probe modes model.tidy_probes() -model.inspect(dataset) +model.inspect(dataset, replot_all=True) model.compare(dataset) plt.show() diff --git a/examples/simple_ptycho.py b/examples/simple_ptycho.py index 467b9afd..af8aeb11 100644 --- a/examples/simple_ptycho.py +++ b/examples/simple_ptycho.py @@ -30,10 +30,8 @@ # We print a quick report of the optimization status print(model.report()) # And liveplot the updates to the model as they happen - if model.epoch % 10 == 0: - model.inspect(dataset) + model.inspect(dataset) -# We study the results -model.inspect(dataset, replot_all=True) +# We open a comparison of the simulated and measured data model.compare(dataset) plt.show() diff --git a/examples/transmission_RPI.py b/examples/transmission_RPI.py index c8fd424e..32c02d55 100644 --- a/examples/transmission_RPI.py +++ b/examples/transmission_RPI.py @@ -30,20 +30,20 @@ # Note that the inspect step takes the vast majority of the time # The regularization is an L2 regularizer that empirically helps accelerate # convergence -for loss in model.Adam_optimize(30, dataset, lr=0.4, regularization_factor=[0.05,0.05]): - model.inspect(dataset) +for loss in model.LBFGS_optimize(30, dataset, lr=0.4, regularization_factor=[0.05,0.05]): + model.inspect(dataset, min_interval=5) print(model.report()) # Now we use the regularizer to damp all but the top modes for loss in model.LBFGS_optimize(50, dataset, lr=0.4, regularization_factor=[0.001,0.1]): - #model.inspect(dataset) + model.inspect(dataset, min_interval=5) print(model.report()) -# Save results to a python dictionary -results = model.save_results() +# Save results to an h5 file +model.save_to_h5('example_reconstructions/transmission_RPI.h5', dataset) # Finally, we plot the results -model.inspect(dataset) +model.inspect(dataset, replot_all=True) model.compare(dataset) plt.show() diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index f6473dc2..de9cc723 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -69,6 +69,7 @@ def __init__(self, panel_plot_mode=False, plot_level=np.inf): self.panel_plot_mode = panel_plot_mode self.plot_level = plot_level self.has_inspect_been_called = False + self.last_inspected_time = None def from_dataset(self, dataset): raise NotImplementedError() @@ -557,7 +558,7 @@ def report(self): plot_list = [] - def inspect(self, dataset=None, replot_all=False): + def inspect(self, dataset=None, replot_all=False, min_interval=None): """Plots all the plots defined in the model's plot_panel_list and plot_list attributes Updates any previously plotted figures that are still open. Figures @@ -586,8 +587,17 @@ def inspect(self, dataset=None, replot_all=False): Optional, a dataset matched to the model type replot_all : bool, default: False If True, recreate figures that were previously closed by the user. + min_interval : float, optional + If set, skip updating plots if fewer than this many seconds have + elapsed since the last call to inspect(). The time of the last + update is stored in self.last_inspected_time. """ + if (min_interval is not None + and self.last_inspected_time is not None + and time.time() - self.last_inspected_time < min_interval): + return + plot_panel_list = getattr(self, 'plot_panel_list', None) or [] plot_list = getattr(self, 'plot_list', None) or [] @@ -624,6 +634,8 @@ def inspect(self, dataset=None, replot_all=False): fig.canvas.flush_events() self.has_inspect_been_called = True + self.last_inspected_time = time.time() + def _is_backend_interactive( self From dafb204aa05c82f549d0cf11789d59cb83d35b63 Mon Sep 17 00:00:00 2001 From: Abe Levitan Date: Fri, 20 Mar 2026 20:18:32 +0100 Subject: [PATCH 09/23] Fix a bug where closed windows wouldn't reopen at the original size --- src/cdtools/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index de9cc723..36d91565 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -768,7 +768,7 @@ def _inspect_panel(self, plot_panel_list, dataset=None, replot_all=False): constrained_layout=True) else: with plt.rc_context({'figure.raise_window': False}): - fig = plt.figure(panel_def['title'], + fig = plt.figure(panel_def['title'], figsize=figsize, constrained_layout=True) fig.clear() From f8ff525722bf2cb3d86d02d8ec50cf9e3e845102 Mon Sep 17 00:00:00 2001 From: allevitan Date: Sat, 21 Mar 2026 08:41:18 +0100 Subject: [PATCH 10/23] Add example patterns for jupyter notebooks --- .gitignore | 3 +- examples/fancy_ptycho_inline.ipynb | 151 ++++++++++++++++++++++++ examples/fancy_ptycho_interactive.ipynb | 144 ++++++++++++++++++++++ 3 files changed, 297 insertions(+), 1 deletion(-) create mode 100644 examples/fancy_ptycho_inline.ipynb create mode 100644 examples/fancy_ptycho_interactive.ipynb diff --git a/.gitignore b/.gitignore index ea4c46c9..ec5e6feb 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ build/* dist */example_data/* *.h5 -.DS_Store \ No newline at end of file +.DS_Store +.ipynb_checkpoints \ No newline at end of file diff --git a/examples/fancy_ptycho_inline.ipynb b/examples/fancy_ptycho_inline.ipynb new file mode 100644 index 00000000..f5cab0fe --- /dev/null +++ b/examples/fancy_ptycho_inline.ipynb @@ -0,0 +1,151 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "286054ce", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "import cdtools\n", + "import torch as t\n", + "from matplotlib import pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "955bb242-e2ed-47c3-919c-1ea690681445", + "metadata": {}, + "outputs": [], + "source": [ + "# Load and inspect a dataset\n", + "\n", + "filename = 'example_data/lab_ptycho_data.cxi'\n", + "dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)\n", + "\n", + "dataset.inspect();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82f71fb4-f013-46bb-817b-1973ba23336a", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize a model from the dataset and move it to the GPU.\n", + "\n", + "model = cdtools.models.FancyPtycho.from_dataset(\n", + " dataset,\n", + " n_modes=3, # Use 3 incoherently mixing probe modes\n", + " oversampling=2, # Simulate the probe on a 2xlarger real-space array\n", + " probe_support_radius=120, # Force the probe to 0 outside a radius of 120 pix\n", + " propagation_distance=5e-3, # Propagate the initial probe guess by 5 mm\n", + " units='mm', # Set the units for the live plots\n", + " obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix,\n", + ")\n", + "\n", + "if t.cuda.is_available():\n", + " model.to(device='cuda')\n", + " dataset.get_as(device='cuda')\n", + "\n", + "# Then, create a reconstructor object and view the initialized model\n", + "\n", + "recon = cdtools.reconstructors.AdamReconstructor(model, dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e22a65e-280b-428d-a6f5-f3206d22110b", + "metadata": {}, + "outputs": [], + "source": [ + "# Workaround reconstruction pattern for interactive plotting in jupyter:\n", + "# First, a standalone cell to plot the current model state\n", + "\n", + "model.inspect(dataset, replot_all=True);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81a768b5", + "metadata": {}, + "outputs": [], + "source": [ + "# Second, a cell for running the reconstruction. With this pattern, it is safe\n", + "# to interrupt the kernel. Then, the cell above can be re-run to refresh the plots.\n", + "while model.epoch < 50:\n", + " for loss in recon.optimize(1, lr=0.02, batch_size=10):\n", + " print(model.report())\n", + "\n", + "while model.epoch < 100:\n", + " for loss in recon.optimize(1, lr=0.005, batch_size=10):\n", + " print(model.report())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "237e9286-b6cf-41dc-aeb0-89abfe59b37c", + "metadata": {}, + "outputs": [], + "source": [ + "# Save out the results\n", + "\n", + "model.save_to_h5('lab_ptycho_reconstruction.h5', dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f1565b6", + "metadata": {}, + "outputs": [], + "source": [ + "# Finalize the plotting and create the comparison plot\n", + "\n", + "# This orthogonalizes the recovered probe modes. It is best to do so\n", + "# after saving the results, if you intend to initialize any further\n", + "# reconstructions with the probe.\n", + "model.tidy_probes()\n", + "\n", + "# Final plotting\n", + "model.inspect(dataset)\n", + "model.compare(dataset);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63ec3aa4-0d1c-4775-9fb2-3002d404faa4", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/fancy_ptycho_interactive.ipynb b/examples/fancy_ptycho_interactive.ipynb new file mode 100644 index 00000000..0e9adad6 --- /dev/null +++ b/examples/fancy_ptycho_interactive.ipynb @@ -0,0 +1,144 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "286054ce", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib widget\n", + "import cdtools\n", + "import torch as t\n", + "from matplotlib import pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "955bb242-e2ed-47c3-919c-1ea690681445", + "metadata": {}, + "outputs": [], + "source": [ + "# Load and inspect a dataset\n", + "\n", + "filename = 'example_data/lab_ptycho_data.cxi'\n", + "dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)\n", + "\n", + "dataset.inspect();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82f71fb4-f013-46bb-817b-1973ba23336a", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize a model from the dataset and move it to the GPU.\n", + "\n", + "model = cdtools.models.FancyPtycho.from_dataset(\n", + " dataset,\n", + " n_modes=3, # Use 3 incoherently mixing probe modes\n", + " oversampling=2, # Simulate the probe on a 2xlarger real-space array\n", + " probe_support_radius=120, # Force the probe to 0 outside a radius of 120 pix\n", + " propagation_distance=5e-3, # Propagate the initial probe guess by 5 mm\n", + " units='mm', # Set the units for the live plots\n", + " obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix,\n", + ")\n", + "\n", + "if t.cuda.is_available():\n", + " model.to(device='cuda')\n", + " dataset.get_as(device='cuda')\n", + "\n", + "# Then, create a reconstructor object and view the initialized model\n", + "\n", + "recon = cdtools.reconstructors.AdamReconstructor(model, dataset)\n", + "model.inspect(dataset);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81a768b5", + "metadata": {}, + "outputs": [], + "source": [ + "# Workaround reconstruction pattern for interactive plotting in jupyter:\n", + "\n", + "# With this pattern, it is safe to interrupt the kernel. Doing so will\n", + "# trigger an update of the plots, at which point the current state can\n", + "# be viewed. Then this cell can be re-run to continue the reconstruction\n", + "while model.epoch < 50:\n", + " for loss in recon.optimize(1, lr=0.02, batch_size=10):\n", + " print(model.report())\n", + " model.inspect(dataset, min_interval=10)\n", + "\n", + "while model.epoch < 100:\n", + " for loss in recon.optimize(1, lr=0.005, batch_size=10):\n", + " print(model.report())\n", + " model.inspect(dataset, min_interval=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "237e9286-b6cf-41dc-aeb0-89abfe59b37c", + "metadata": {}, + "outputs": [], + "source": [ + "# Save out the results\n", + "\n", + "model.save_to_h5('lab_ptycho_reconstruction.h5', dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f1565b6", + "metadata": {}, + "outputs": [], + "source": [ + "# Finalize the plotting and create the comparison plot\n", + "\n", + "# This orthogonalizes the recovered probe modes. It is best to do so\n", + "# after saving the results, if you intend to initialize any further\n", + "# reconstructions with the probe.\n", + "model.tidy_probes()\n", + "\n", + "# Final plotting\n", + "model.inspect(dataset)\n", + "model.compare(dataset);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63ec3aa4-0d1c-4775-9fb2-3002d404faa4", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From a4a38530d1248a99010eacd27c26db48ddf2fa98 Mon Sep 17 00:00:00 2001 From: allevitan Date: Sat, 21 Mar 2026 14:57:16 +0100 Subject: [PATCH 11/23] Update the plot_image functions to show sliders --- src/cdtools/models/fancy_ptycho.py | 12 +- src/cdtools/tools/plotting/plotting.py | 240 +++++++++++++++++-------- 2 files changed, 178 insertions(+), 74 deletions(-) diff --git a/src/cdtools/models/fancy_ptycho.py b/src/cdtools/models/fancy_ptycho.py index 44448b36..ccb12ef0 100644 --- a/src/cdtools/models/fancy_ptycho.py +++ b/src/cdtools/models/fancy_ptycho.py @@ -930,6 +930,7 @@ def plot_illumination_intensity(self, fig, dataset): convention='probe', invert_xaxis=True ) + plt.gca().set_aspect('equal') def plot_translations_and_originals(self, fig, dataset): @@ -951,7 +952,8 @@ def plot_translations_and_originals(self, fig, dataset): color='k', marker='.' ) - plt.legend() + plt.gca().set_aspect('equal') + plt.legend(loc='upper right') plot_panel_list = [ @@ -968,6 +970,7 @@ def plot_translations_and_originals(self, fig, dataset): self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, + additional_axis_labels=['Mode #',], units=self.units), 'condition': lambda self: not self.exponentiate_obj, }, @@ -978,6 +981,7 @@ def plot_translations_and_originals(self, fig, dataset): self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, + additional_axis_labels=['Mode #',], units=self.units), 'condition': lambda self: not self.exponentiate_obj, }, @@ -988,6 +992,7 @@ def plot_translations_and_originals(self, fig, dataset): self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, + additional_axis_labels=['Mode #',], units=self.units, cmap='cividis'), 'condition': lambda self: self.exponentiate_obj, @@ -999,6 +1004,7 @@ def plot_translations_and_originals(self, fig, dataset): self.obj[self.obj_view_slice], fig=fig, basis=self.obj_basis, + additional_axis_labels=['Mode #',], units=self.units), 'condition': lambda self: self.exponentiate_obj, }, @@ -1011,6 +1017,7 @@ def plot_translations_and_originals(self, fig, dataset): fig=fig, title='Basis Probes', basis=self.probe_basis, + additional_axis_labels=['Mode #',], units=self.units), }, { @@ -1022,6 +1029,7 @@ def plot_translations_and_originals(self, fig, dataset): fig=fig, title='Basis Probes', basis=self.probe_basis, + additional_axis_labels=['Mode #',], units=self.units), }, ], @@ -1040,6 +1048,7 @@ def plot_translations_and_originals(self, fig, dataset): else tools.propagators.far_field(self.probe)), fig=fig, title='Basis Probes, Fourier', + additional_axis_labels=['Mode #',], ), }, { @@ -1050,6 +1059,7 @@ def plot_translations_and_originals(self, fig, dataset): else tools.propagators.far_field(self.probe)), fig=fig, title='Basis Probes, Fourier', + additional_axis_labels=['Mode #',], ), }, { diff --git a/src/cdtools/tools/plotting/plotting.py b/src/cdtools/tools/plotting/plotting.py index 362489c2..a6fdab5e 100644 --- a/src/cdtools/tools/plotting/plotting.py +++ b/src/cdtools/tools/plotting/plotting.py @@ -106,6 +106,7 @@ def plot_image( vmax=None, interpolation=None, title=None, + additional_axis_labels=None, **kwargs ): """Plots an image with a colorbar and on an appropriate spatial grid @@ -121,6 +122,10 @@ def plot_image( will be called on each slice of data before it is plotted. This is used internally to enable the plot_real, plot_image, plot_phase, etc. functions. + If the image has more than 2 dimensions, a horizontal slider is created + for each extra axis with length > 1. Up/down arrow keys navigate through + all extra axes odometer-style (last axis changes fastest). + Parameters ---------- @@ -146,6 +151,10 @@ def plot_image( Default is max(plot_func(im)), the maximum value for the colormap interpolation : str What interpolation to use for imshow + additional_axis_labels : list of str, optional + Labels for each extra axis (all dimensions except the last two). + If shorter than the number of extra axes, remaining labels default to + "Axis N". If not set, all labels default to "Axis N". \\**kwargs All other args are passed to fig.add_subplot(111, \\**kwargs) @@ -165,32 +174,40 @@ def plot_image( im = im.detach().cpu().numpy() if fig is None: - fig = plt.figure() - # This nukes everything and updates either the appropriate image from the - # stack of images, or the only image if only a single image has been - # given - def make_plot(idx): - if title is not None: - ax_title = title - else: - try: - ax_title = fig.axes[0].get_title() - except IndexError: - ax_title = '' - - # If im only has two dimensions, this reshape will add a leading - # dimension, and update will be called on index 0. If it has 3 or more - # dimensions, then all the leading dimensions will be compressed into - # one long dimension which can be scrolled through. - s = im.shape - reshaped_im = im.reshape(-1,s[-2],s[-1]) - num_images = reshaped_im.shape[0] - fig.plot_idx = idx % num_images - - to_plot = plot_func(reshaped_im[fig.plot_idx]) + fig = plt.figure(constrained_layout=True) + + # Determine extra (non-image) dimensions and build per-axis slider map + extra_dims = im.shape[:-2] + n_extra = len(extra_dims) + + # I have it say e.g. "0th Axis" instead of "Axis 0", because the latter one + # looks kind of confusing based on the layout that a Slider widget gets + def ordinal(n): + suffix = {1: 'st', 2: 'nd', 3: 'rd'} + return f"{n}{suffix.get(n % 10, 'th') if n % 100 not in (11, 12, 13) else 'th'}" + if additional_axis_labels is None: + additional_axis_labels = [f'{ordinal(i)} Axis' for i in range(n_extra)] + else: + additional_axis_labels = list(additional_axis_labels) + [ + f'{ordinal(i)} Axis' for i in range(len(additional_axis_labels), n_extra) + ] + + # Only axes with length > 1 get sliders + slider_axis_map = [ + (i, extra_dims[i], additional_axis_labels[i]) + for i in range(n_extra) if extra_dims[i] > 1 + ] + n_sliders = len(slider_axis_map) + + def make_plot(idx_list): + # Always update fig._make_plot so slider callbacks get the latest closure + fig._make_plot = make_plot + fig.plot_idx = list(idx_list) + selected = im[tuple(fig.plot_idx)] if n_extra > 0 else im + to_plot = plot_func(selected) # By only updating the data, and not redrawing the fig, we - # don't "reset" the home positions of the other + # don't "reset" the home positions of the toolbar if hasattr(fig, '_current_im'): fig._current_im.set_data(to_plot) fig._current_im.autoscale() @@ -201,33 +218,68 @@ def make_plot(idx): if fig.canvas.toolbar is not None: fig.canvas.toolbar.home() fig.canvas.toolbar.update() + # Sync sliders to the new index without triggering callbacks + if hasattr(fig, '_sliders'): + fig._updating = True + for j, (axis_idx, _, _) in enumerate(slider_axis_map): + fig._sliders[j].set_val(fig.plot_idx[axis_idx]) + fig._updating = False + # Restore image axis as the "current" axis so callers using + # plt.gca() / plt.title() target the right axes + if hasattr(fig, '_plot_ax'): + plt.sca(fig._plot_ax) + return fig - if num_images > 1: - base = title if title is not None else '('.join(ax_title.split('(')[:-1])[:-1] - fig.axes[0].set_title(base + f' ({fig.plot_idx+1} of {num_images})') + if title is not None: + ax_title = title + else: + try: + ax_title = fig.axes[0].get_title() + except IndexError: + ax_title = '' - return fig - fig.clear() - ax = fig.add_subplot(111, **kwargs) + + # gs = fig.add_gridspec( + # 1 + n_sliders, 1, + # height_ratios=[1] + [0.04] * n_sliders, + # ) + # ax = fig.add_subplot(gs[0, 0], **kwargs) + # ax_sliders = [fig.add_subplot(gs[i + 1, 0]) for i in range(n_sliders)] + import matplotlib + gs = matplotlib.gridspec.GridSpec( + 2, 1, + height_ratios=[1] + [0.1 * n_sliders], + figure=fig + ) + + ax_gs = matplotlib.gridspec.GridSpecFromSubplotSpec( + 1,1, subplot_spec=gs[0,0]) + ax = fig.add_subplot(ax_gs[0,0], **kwargs) + if n_sliders > 0: + slider_gs = matplotlib.gridspec.GridSpecFromSubplotSpec( + n_sliders,1, subplot_spec=gs[1,0]) + ax_sliders = [fig.add_subplot(slider_gs[i, 0]) for i in range(n_sliders)] + + fig._plot_ax = ax mpl_im = ax.imshow( to_plot, - cmap = cmap, - interpolation = interpolation, + cmap=cmap, + interpolation=interpolation, vmin=vmin, vmax=vmax, ) fig._current_im = mpl_im ax.set_facecolor('k') - + if basis is not None: # we've closed over basis, so we can't edit it - if isinstance(basis,t.Tensor): + if isinstance(basis, t.Tensor): np_basis = basis.detach().cpu().numpy() else: np_basis = basis - + np_basis = np_basis * get_units_factor(units) if isinstance(view_basis, str) and view_basis.lower() == 'ortho': @@ -237,44 +289,44 @@ def make_plot(idx): # y-axis lies in the x-y plane of the basis, perpendicular # to the x-axis - basis_norm = np.linalg.norm(np_basis, axis = 0) - + basis_norm = np.linalg.norm(np_basis, axis=0) + normed_basis = np_basis / basis_norm - normed_z = np.cross(normed_basis[:,1], normed_basis[:,0]) + normed_z = np.cross(normed_basis[:, 1], normed_basis[:, 0]) normed_z /= np.linalg.norm(normed_z) - normed_yprime = np.cross(normed_z, normed_basis[:,1]) + normed_yprime = np.cross(normed_z, normed_basis[:, 1]) normed_yprime /= np.linalg.norm(normed_yprime) - + np_view_basis = np.stack( - [normed_yprime, normed_basis[:,1]], axis=1) + [normed_yprime, normed_basis[:, 1]], axis=1) else: # We've also closed over view_basis, so we can't update it - if isinstance(view_basis,t.Tensor): + if isinstance(view_basis, t.Tensor): np_view_basis = view_basis.detach().cpu().numpy() else: np_view_basis = view_basis - + # We always normalize the view basis - view_basis_norm = np.linalg.norm(np_view_basis, axis = 0) + view_basis_norm = np.linalg.norm(np_view_basis, axis=0) np_view_basis = np_view_basis / view_basis_norm # Holy cow, this works! transform_matrix = \ - np.linalg.lstsq(np_view_basis[:,::-1], np_basis[:,::-1], + np.linalg.lstsq(np_view_basis[:, ::-1], np_basis[:, ::-1], rcond=None)[0] - [[a,c],[b,d]] = transform_matrix + [[a, c], [b, d]] = transform_matrix - transform = mtransforms.Affine2D.from_values(a,b,c,d,0,0) + transform = mtransforms.Affine2D.from_values(a, b, c, d, 0, 0) - trans_data = transform + plt.gca().transData + trans_data = transform + ax.transData mpl_im.set_transform(trans_data) - corners = np.array([[-0.5,-0.5], - [im.shape[-1]-0.5,-0.5], - [-0.5, im.shape[-2]-0.5], - [im.shape[-1]-0.5, im.shape[-2]-0.5]]) - corners = np.matmul(transform_matrix,corners.transpose()) + corners = np.array([[-0.5, -0.5], + [im.shape[-1] - 0.5, -0.5], + [-0.5, im.shape[-2] - 0.5], + [im.shape[-1] - 0.5, im.shape[-2] - 0.5]]) + corners = np.matmul(transform_matrix, corners.transpose()) mins = np.min(corners, axis=1) maxes = np.max(corners, axis=1) ax.set_xlim([mins[0], maxes[0]]) @@ -282,10 +334,10 @@ def make_plot(idx): ax.invert_yaxis() if show_cbar: - cbar = fig.colorbar(mpl_im, ax=ax, fraction=0.05, pad=0.05) + cbar = fig.colorbar(mpl_im, ax=ax, fraction=0.05, pad=0.05, location='right') if cmap_label is not None: cbar.set_label(cmap_label) - + if basis is not None: ax.set_xlabel('X (' + units + ')') ax.set_ylabel('Y (' + units + ')') @@ -295,44 +347,86 @@ def make_plot(idx): if title is not None: ax.set_title(ax_title) - if num_images >= 3: - ax.set_title(ax_title + f' ({fig.plot_idx+1} of {num_images})') - + + # Create sliders for axes with length > 1 + sliders = [] + for j, (axis_idx, axis_len, label) in enumerate(slider_axis_map): + s = Slider(ax_sliders[j], label, 0, axis_len - 1, + valstep=1, valfmt='%d', + valinit=fig.plot_idx[axis_idx]) + sliders.append(s) + fig._sliders = sliders + + # Slider callbacks guarded by _updating flag to prevent re-entry. + # Uses fig._make_plot so subsequent plot_image calls update the closure. + def make_slider_cb(axis_idx): + def cb(val): + if getattr(fig, '_updating', False): + return + new_idx = list(fig.plot_idx) + new_idx[axis_idx] = int(val) + fig._make_plot(new_idx) + plt.draw() + return cb + + for (axis_idx, _, _), slider in zip(slider_axis_map, sliders): + slider.on_changed(make_slider_cb(axis_idx)) + if fig.canvas.toolbar is not None: fig.canvas.toolbar.update() + # Restore image axis as "current" so callers using plt.gca() / plt.title() + # target the image axes, not the last slider axis added + plt.sca(ax) return fig - if hasattr(fig, 'plot_idx'): + if hasattr(fig, 'plot_idx') and len(fig.plot_idx) == n_extra: result_fig = make_plot(fig.plot_idx) else: - result_fig = make_plot(0) - - update = make_plot + result_fig = make_plot([0] * n_extra) def on_action(event): # Protection for multi-subfigure situation if event.inaxes not in fig.axes: return - if not hasattr(event, 'button'): - event.button = None if not hasattr(event, 'key'): event.key = None + if not getattr(fig, '_sliders', []): + return + + direction = None + if event.key == 'up': + direction = -1 + elif event.key == 'down': + direction = 1 + if direction is None: + return - if event.key == 'up' or event.button == 'up': - update(fig.plot_idx - 1) - elif event.key == 'down' or event.button == 'down': - update(fig.plot_idx + 1) + # Odometer-style: last entry in slider_axis_map changes fastest + new_idx = list(fig.plot_idx) + carry = direction + for axis_idx, axis_len, _ in reversed(slider_axis_map): + new_val = new_idx[axis_idx] + carry + if new_val < 0: + new_idx[axis_idx] = axis_len - 1 + carry = -1 + elif new_val >= axis_len: + new_idx[axis_idx] = 0 + carry = 1 + else: + new_idx[axis_idx] = new_val + carry = 0 + break + + make_plot(new_idx) plt.draw() - if len(im.shape) >=3: - if not hasattr(fig,'my_callbacks'): + if n_sliders > 0: + if not hasattr(fig, 'my_callbacks'): fig.my_callbacks = [] - for cid in fig.my_callbacks: fig.canvas.mpl_disconnect(cid) fig.my_callbacks = [] - fig.my_callbacks.append(fig.canvas.mpl_connect('key_press_event',on_action)) - fig.my_callbacks.append(fig.canvas.mpl_connect('scroll_event',on_action)) + fig.my_callbacks.append(fig.canvas.mpl_connect('key_press_event', on_action)) return result_fig From f04ac9f0ec249bbd576405f0195de5513cf0fcf9 Mon Sep 17 00:00:00 2001 From: allevitan Date: Sat, 21 Mar 2026 17:35:14 +0100 Subject: [PATCH 12/23] Make the colorized plot look nicer, with a more perceptually uniform mapping and a colorbar for accurate reading of phases --- src/cdtools/tools/plotting/plotting.py | 43 ++++++++++++++++---------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/src/cdtools/tools/plotting/plotting.py b/src/cdtools/tools/plotting/plotting.py index a6fdab5e..bec51226 100644 --- a/src/cdtools/tools/plotting/plotting.py +++ b/src/cdtools/tools/plotting/plotting.py @@ -31,7 +31,7 @@ ] -def colorize(z): +def colorize(z, use_cmocean=True): """ Returns RGB values for a complex color plot given a complex array This function returns a set of RGB values that can be used directly in a call to imshow based on an input complex numpy array (not a @@ -48,17 +48,26 @@ def colorize(z): """ amp = np.abs(z) - rmin = 0 - rmax = np.max(amp) - amp = np.where(amp < rmin, rmin, amp) - amp = np.where(amp > rmax, rmax, amp) - ph = np.angle(z, deg=1) + 90 - # HSV are values in range [0,1] - h = (ph % 360) / 360 - s = 0.85 * np.ones_like(h) - v = (amp - rmin) / (rmax - rmin) + scaled_amp = amp / np.max(amp) + ph = np.angle(z, deg=1) + + if not use_cmocean: + # HSV are values in range [0,1] + h = ((ph + 90) % 360) / 360 + s = 0.85 * np.ones_like(h) + v = scaled_amp + return hsv_to_rgb(np.dstack((h,s,v))) + else: + base_rgb_values = [] + for channel in range(3): + base_rgb_values.append(np.interp(ph%360, + np.linspace(0, 360, cm_data.shape [0]), + cm_data[:,channel])) + base_rgb_values = np.dstack(base_rgb_values) + rgb_values = base_rgb_values * scaled_amp[...,None] + return rgb_values + - return hsv_to_rgb(np.dstack((h,s,v))) def get_units_factor(units): @@ -642,7 +651,9 @@ def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', title=None, **kwar """ plot_func = lambda x: colorize(x) return plot_image(im, plot_func=plot_func, fig=fig, basis=basis, - units=units, show_cbar=False, title=title, **kwargs) + cmap=cmocean_phase, vmin=-np.pi, vmax=np.pi, + cmap_label='Phase (rad)', + units=units, show_cbar=True, title=title, **kwargs) def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, invert_xaxis=True, clear_fig=True, label=None, color=None, marker='.', **kwargs): @@ -1337,8 +1348,8 @@ def on_pick(event): [ 0.65121289, 0.47406244, 0.05044367], [ 0.65830839, 0.46993917, 0.04941288]] -rgb = np.array(cm_data) -rgb_with_alpha = np.zeros((rgb.shape[0],4)) -rgb_with_alpha[:,:3] = rgb +cm_data = np.array(cm_data) +rgb_with_alpha = np.zeros((cm_data.shape[0],4)) +rgb_with_alpha[:,:3] = cm_data rgb_with_alpha[:,3] = 1. #set alpha channel to 1 -cmocean_phase = colors.ListedColormap(rgb_with_alpha, N=rgb.shape[0]) +cmocean_phase = colors.ListedColormap(rgb_with_alpha, N=cm_data.shape[0]) From 29656fd78e35f9096b10ff6eb74c01a9feb68bc5 Mon Sep 17 00:00:00 2001 From: allevitan Date: Sat, 21 Mar 2026 21:27:16 +0100 Subject: [PATCH 13/23] Improve the colorized plotting further, and add a colorbar which will be useful for publishing figures now that it's not just a simple hsv lookup. Also fix a bug with nonresponsive windows when all model plots are closed, but the dataset plots are still showing. --- examples/fancy_ptycho.py | 1 + src/cdtools/datasets/ptycho_2d_dataset.py | 2 - src/cdtools/models/base.py | 2 +- src/cdtools/models/fancy_ptycho.py | 22 ++++---- src/cdtools/reconstructors/base.py | 14 ++--- src/cdtools/tools/plotting/plotting.py | 62 +++++++++++++++++------ 6 files changed, 65 insertions(+), 38 deletions(-) diff --git a/examples/fancy_ptycho.py b/examples/fancy_ptycho.py index eb9de73d..51d0e0f4 100644 --- a/examples/fancy_ptycho.py +++ b/examples/fancy_ptycho.py @@ -14,6 +14,7 @@ propagation_distance=5e-3, # Propagate the initial probe guess by 5 mm units='mm', # Set the units for the live plots obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix, + panel_plot_mode=False ) if t.cuda.is_available(): diff --git a/src/cdtools/datasets/ptycho_2d_dataset.py b/src/cdtools/datasets/ptycho_2d_dataset.py index adfe866e..5a84ca14 100644 --- a/src/cdtools/datasets/ptycho_2d_dataset.py +++ b/src/cdtools/datasets/ptycho_2d_dataset.py @@ -245,8 +245,6 @@ def get_images(idx): return np.log10((meas_data * mask) + log_offset) else: return meas_data * mask - - translations = self.translations.detach().cpu().numpy() # This takes about twice as long as it would to just do it all at # once, but it avoids creating another self.patterns-sized array diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index 36d91565..bc7f73d5 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -812,7 +812,7 @@ def _inspect_panel(self, plot_panel_list, dataset=None, replot_all=False): except KeyboardInterrupt: raise except Exception: - pass + raise rendered.append(fig) diff --git a/src/cdtools/models/fancy_ptycho.py b/src/cdtools/models/fancy_ptycho.py index ccb12ef0..dbcfc638 100644 --- a/src/cdtools/models/fancy_ptycho.py +++ b/src/cdtools/models/fancy_ptycho.py @@ -924,7 +924,7 @@ def plot_illumination_intensity(self, fig, dataset): self.corrected_translations(dataset), self.get_probe_intensities(), fig=fig, - cmap='magma', + cmap='viridis', cmap_label='Intensity (a.u.)', units=self.units, convention='probe', @@ -1009,25 +1009,26 @@ def plot_translations_and_originals(self, fig, dataset): 'condition': lambda self: self.exponentiate_obj, }, { - 'title': 'Basis Probes, Colorized', + 'title': 'Probe Modes, Colorized', 'subplot': (0,1), 'plot_func': lambda self, fig: p.plot_colorized( (self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe)), fig=fig, - title='Basis Probes', + title='Probe Modes, Real Space', basis=self.probe_basis, additional_axis_labels=['Mode #',], + amplitude_scaling=np.sqrt, units=self.units), }, { - 'title': 'Basis Probes, Amplitude', + 'title': 'Probe Modes, Amplitude', 'subplot': (1,1), 'plot_func': lambda self, fig: p.plot_amplitude( (self.probe if not self.fourier_probe else tools.propagators.inverse_far_field(self.probe)), fig=fig, - title='Basis Probes', + title='Probe Modes, Real Space', basis=self.probe_basis, additional_axis_labels=['Mode #',], units=self.units), @@ -1041,24 +1042,25 @@ def plot_translations_and_originals(self, fig, dataset): 'grid': (2,3), 'plots': [ { - 'title': 'Basis Probes, Fourier Colorized', + 'title': 'Probe Modes, Fourier Colorized', 'subplot': (0,0), 'plot_func': lambda self, fig: p.plot_colorized( (self.probe if self.fourier_probe else tools.propagators.far_field(self.probe)), fig=fig, - title='Basis Probes, Fourier', + title='Probe Modes, Fourier Space', additional_axis_labels=['Mode #',], + amplitude_scaling = np.sqrt, ), }, { - 'title': 'Basis Probes, Fourier Amplitude', + 'title': 'Probe Modes, Fourier Amplitude', 'subplot': (1,0), 'plot_func': lambda self, fig: p.plot_amplitude( (self.probe if self.fourier_probe else tools.propagators.far_field(self.probe)), fig=fig, - title='Basis Probes, Fourier', + title='Probe Modes, Fourier Space', additional_axis_labels=['Mode #',], ), }, @@ -1070,7 +1072,7 @@ def plot_translations_and_originals(self, fig, dataset): { 'title': 'Detector Background', 'subplot': (1,1), - 'plot_func': lambda self, fig: p.plot_amplitude(self.background**2, fig=fig, cmap='magma', cmap_label='Intensity (detector units)'), + 'plot_func': lambda self, fig: p.plot_amplitude(self.background**2, fig=fig, cmap='viridis', cmap_label='Intensity (detector units)'), }, { 'title': 'Corrected Translations', diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py index a9698ee2..90a5287e 100644 --- a/src/cdtools/reconstructors/base.py +++ b/src/cdtools/reconstructors/base.py @@ -233,7 +233,7 @@ def closure(): def optimize(self, iterations: int, batch_size: int = 1, - custom_data_loader: torch.utils.data.DataLoader = None, + custom_data_loader: t.utils.data.DataLoader = None, regularization_factor: Union[float, List[float]] = None, thread: bool = True, calculation_width: int = 10, @@ -358,14 +358,10 @@ def target(): try: calc.start() while calc.is_alive(): - figs = getattr(self.model, 'figs', []) - open_fig = next( - (f for f in figs if plt.fignum_exists(f.number)), - None, - ) - if open_fig is not None: - open_fig.canvas.flush_events() - # We need a low value for smooth figure responses + open_figs = plt.get_fignums() + with plt.rc_context({'figure.raise_window': False}): + for fignum in open_figs: + plt.figure(fignum).canvas.flush_events() time.sleep(0.001) except KeyboardInterrupt as e: diff --git a/src/cdtools/tools/plotting/plotting.py b/src/cdtools/tools/plotting/plotting.py index bec51226..ca8e53f1 100644 --- a/src/cdtools/tools/plotting/plotting.py +++ b/src/cdtools/tools/plotting/plotting.py @@ -31,7 +31,7 @@ ] -def colorize(z, use_cmocean=True): +def colorize(z, use_cmocean=False, amplitude_scaling=lambda x: x): """ Returns RGB values for a complex color plot given a complex array This function returns a set of RGB values that can be used directly in a call to imshow based on an input complex numpy array (not a @@ -41,6 +41,8 @@ def colorize(z, use_cmocean=True): ---------- z : array A complex-valued array + use_cmocean : bool + If true, uses the cmocean_phase colormap instead of hue Returns ------- rgb : list(array) @@ -48,24 +50,25 @@ def colorize(z, use_cmocean=True): """ amp = np.abs(z) - scaled_amp = amp / np.max(amp) + scaled_amp = amplitude_scaling(amp / np.max(amp)) ph = np.angle(z, deg=1) - if not use_cmocean: - # HSV are values in range [0,1] - h = ((ph + 90) % 360) / 360 - s = 0.85 * np.ones_like(h) - v = scaled_amp - return hsv_to_rgb(np.dstack((h,s,v))) - else: + if use_cmocean: base_rgb_values = [] for channel in range(3): - base_rgb_values.append(np.interp(ph%360, + base_rgb_values.append(np.interp((ph + 180)%360, np.linspace(0, 360, cm_data.shape [0]), cm_data[:,channel])) base_rgb_values = np.dstack(base_rgb_values) rgb_values = base_rgb_values * scaled_amp[...,None] return rgb_values + else: + # HSV are values in range [0,1] + h = ((ph + 90) % 360) / 360 + s = 0.85 * np.ones_like(h) + v = scaled_amp + return hsv_to_rgb(np.dstack((h,s,v))) + @@ -116,6 +119,7 @@ def plot_image( interpolation=None, title=None, additional_axis_labels=None, + updateable_colorbar=True, **kwargs ): """Plots an image with a colorbar and on an appropriate spatial grid @@ -219,7 +223,8 @@ def make_plot(idx_list): # don't "reset" the home positions of the toolbar if hasattr(fig, '_current_im'): fig._current_im.set_data(to_plot) - fig._current_im.autoscale() + if updateable_colorbar: + fig._current_im.autoscale() # We need to go to the "home" position before updating it # to include the new data, because otherwise it will store # other axes (potentially zoomed in) positions as "home", @@ -344,6 +349,8 @@ def make_plot(idx_list): if show_cbar: cbar = fig.colorbar(mpl_im, ax=ax, fraction=0.05, pad=0.05, location='right') + if not updateable_colorbar: + cbar.ax.set_navigate(False) if cmap_label is not None: cbar.set_label(cmap_label) @@ -619,7 +626,7 @@ def plot_phase( def plot_amplitude_surfacenorm(): pass -def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', title=None, **kwargs): +def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', title=None, amplitude_scaling=lambda x: x, **kwargs): """ Plots the colorized version of a complex array with dimensions NxM The darkness corresponds to the intensity of the image, and the color @@ -649,11 +656,34 @@ def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', title=None, **kwar used_fig : matplotlib.figure.Figure The figure object that was actually plotted to. """ - plot_func = lambda x: colorize(x) - return plot_image(im, plot_func=plot_func, fig=fig, basis=basis, + plot_func = lambda x: colorize(x, use_cmocean=True, + amplitude_scaling=amplitude_scaling) + plot_fig = plot_image(im, plot_func=plot_func, fig=fig, basis=basis, cmap=cmocean_phase, vmin=-np.pi, vmax=np.pi, cmap_label='Phase (rad)', - units=units, show_cbar=True, title=title, **kwargs) + units=units, show_cbar=True, title=title, + updateable_colorbar=False, **kwargs) + + # Find the colorbar - this is a bit hacky + cbar_ax = [ax for ax in plot_fig.get_axes() if hasattr(ax, '_colorbar')][0] + + # --- Replace the colorbar image --- + # The internal image is a QuadMesh living on cbar.ax + #qm = cbar.ax.collections + # Build a 2D array to match the colorbar's range, here (0,1) in x + # and (pi, pi) in y + yg = np.linspace(-np.pi, np.pi, 256) # colormap values + xg = np.linspace(0, 1, 64) # second dimension + YY, XX = np.meshgrid(yg, xg, indexing='ij') + dummy_im = XX * np.exp(1j*YY) + cbar_im = plot_func(dummy_im) + for artist in list(cbar_ax.get_children()): + if 'QuadMesh' in type(artist).__name__ or \ + 'AxesImage' in type(artist).__name__: + artist.remove() + + cbar_ax.imshow(cbar_im, origin='lower', aspect='auto', + extent=[xg[0], xg[-1], -np.pi, np.pi]) def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, invert_xaxis=True, clear_fig=True, label=None, color=None, marker='.', **kwargs): @@ -880,7 +910,7 @@ def update_colorbar(im): nanomap_units_factor = get_units_factor(nanomap_units) nanomap = axes[0].scatter(nanomap_units_factor * translations[:,0], nanomap_units_factor * translations[:,1], - s=s,c=values, picker=True) + s=s,c=values, picker=True, cmap=cmap) axes[0].invert_xaxis() axes[0].set_facecolor('k') From f9793882808224faaf91ae498bff142f0fe1d6d8 Mon Sep 17 00:00:00 2001 From: Abe Levitan Date: Sat, 21 Mar 2026 21:42:16 +0100 Subject: [PATCH 14/23] Fix an annoying warning coming from double-setting the figsize --- src/cdtools/tools/plotting/plotting.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/cdtools/tools/plotting/plotting.py b/src/cdtools/tools/plotting/plotting.py index ca8e53f1..1b30a766 100644 --- a/src/cdtools/tools/plotting/plotting.py +++ b/src/cdtools/tools/plotting/plotting.py @@ -853,7 +853,11 @@ def plot_nanomap_with_images(translations, get_image_func, values=None, mask=Non if fig is None: fig = plt.figure(figsize=(8,5.3), constrained_layout=True) else: - fig = plt.figure(fig.number, figsize=(8,5.3), constrained_layout=True) + if plt.fignum_exists(fig.number): + fig = plt.figure(fig.number) + else: + fig = plt.figure(fig.number, + figsize=(8,5.3), constrained_layout=True) fig.clear() if hasattr(fig, 'nanomap_cids'): for cid in fig.nanomap_cids: From 4201b415735c55382f715a556f81c78222cb5e20 Mon Sep 17 00:00:00 2001 From: Abe Levitan Date: Tue, 24 Mar 2026 09:13:51 +0100 Subject: [PATCH 15/23] Made more adjustments to the plotting system to avoid using constained_layout, which was causing hangups during live plotting and didn't look as nice --- examples/fancy_ptycho.py | 5 +- examples/near_field_ptycho.py | 1 + src/cdtools/models/base.py | 51 +++++-- src/cdtools/models/fancy_ptycho.py | 8 +- src/cdtools/reconstructors/base.py | 3 +- src/cdtools/tools/plotting/plotting.py | 197 ++++++++++++++++++------- 6 files changed, 185 insertions(+), 80 deletions(-) diff --git a/examples/fancy_ptycho.py b/examples/fancy_ptycho.py index 51d0e0f4..1ca88fea 100644 --- a/examples/fancy_ptycho.py +++ b/examples/fancy_ptycho.py @@ -14,7 +14,6 @@ propagation_distance=5e-3, # Propagate the initial probe guess by 5 mm units='mm', # Set the units for the live plots obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix, - panel_plot_mode=False ) if t.cuda.is_available(): @@ -35,7 +34,7 @@ print(model.report()) # Because plotting can be expensive, setting a minimum plotting interval # (in seconds) can avoid excessive replots. - model.inspect(dataset, min_interval=5) + model.inspect(dataset, min_interval=10) # It's common to chain several different reconstruction loops. Here, we # started with an aggressive refinement to find the probe in the previous @@ -43,7 +42,7 @@ # and larger minibatch for loss in recon.optimize(50, lr=0.005, batch_size=50): print(model.report()) - model.inspect(dataset, min_interval=5) + model.inspect(dataset, min_interval=10) # This orthogonalizes the recovered probe modes model.tidy_probes() diff --git a/examples/near_field_ptycho.py b/examples/near_field_ptycho.py index 4534ee58..b5748131 100644 --- a/examples/near_field_ptycho.py +++ b/examples/near_field_ptycho.py @@ -6,6 +6,7 @@ dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename) dataset.inspect() +plt.show() # Setting near_field equal to True uses an angular spectrum propagator in # lieu of the default Fourier-transform propagator for far-field ptychography. diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index bc7f73d5..d3cb122e 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -695,13 +695,11 @@ def _inspect_individual_figures( if not self.has_inspect_been_called: fig = plt.figure(plot['title'], - figsize=figsize, - constrained_layout=True) + figsize=figsize) else: with plt.rc_context({'figure.raise_window': False}): fig = plt.figure(plot['title'], - figsize = figsize, - constrained_layout=True) + figsize = figsize) try: plot['plot_func'](self, fig) @@ -764,18 +762,16 @@ def _inspect_panel(self, plot_panel_list, dataset=None, replot_all=False): continue if not self.has_inspect_been_called: - fig = plt.figure(panel_def['title'], figsize=figsize, - constrained_layout=True) + fig = plt.figure(panel_def['title'], figsize=figsize) else: with plt.rc_context({'figure.raise_window': False}): - fig = plt.figure(panel_def['title'], figsize=figsize, - constrained_layout=True) + fig = plt.figure(panel_def['title'], figsize=figsize) + for subfig in fig.subfigs: + if hasattr(subfig, '_sliders'): + for slider in subfig._sliders: + slider.disconnect_events() fig.clear() - - fig.get_layout_engine().set( - rect=(0.02, 0.02, 0.96, 0.96), - ) gs = fig.add_gridspec( nrows, ncols, @@ -813,7 +809,7 @@ def _inspect_panel(self, plot_panel_list, dataset=None, replot_all=False): raise except Exception: raise - + rendered.append(fig) if self._is_backend_interactive(): @@ -821,7 +817,6 @@ def _inspect_panel(self, plot_panel_list, dataset=None, replot_all=False): return rendered - def plot_loss_history(self, fig=None, clear_fig=True): """Plots the loss history on a semilogy axis @@ -847,7 +842,33 @@ def plot_loss_history(self, fig=None, clear_fig=True): if len(fig.axes) >= 1: ax = fig.axes[0] else: - ax = fig.add_subplot(111) + try: + total_width, total_height = fig.get_size_inches() + except AttributeError: + # Only support one layer of nested subfigures + main_fig = fig.figure # get enclosing figure + bbox = fig.bbox + main_fig_bbox = main_fig.bbox + fig_w, fig_h = main_fig.get_size_inches() + total_width = fig.bbox.width * fig_w / main_fig.bbox.width + total_height = fig.bbox.height * fig_h / main_fig.bbox.height + except AttributeError: + # Fall back to default figsize + total_width, total_height = (6.4, 4.8) + + pad_left = 0.6 / total_height + # De-adjusts for an ad-hoc offset introduced by matplotlib + pad_right = 0.6 / total_width - 0.05 + + pad_bottom = 0.5 / total_height + pad_top = 0.4 / total_height + + im_ax_bottom = pad_bottom + im_ax_height = 1 - pad_top - im_ax_bottom + + ax = fig.add_axes( + [pad_left, im_ax_bottom, 1-pad_left-pad_right, im_ax_height] + ) ax.semilogy(self.loss_history) plt.title('Loss History') diff --git a/src/cdtools/models/fancy_ptycho.py b/src/cdtools/models/fancy_ptycho.py index dbcfc638..3edb9ac5 100644 --- a/src/cdtools/models/fancy_ptycho.py +++ b/src/cdtools/models/fancy_ptycho.py @@ -961,7 +961,7 @@ def plot_translations_and_originals(self, fig, dataset): 'title': 'Main Results', 'plot_level': 1, 'grid': (2,2), - 'figure_size': (9,7), + 'figure_size': (8.4,6.8), 'plots': [ { 'title': 'Object Phase', @@ -1038,7 +1038,7 @@ def plot_translations_and_originals(self, fig, dataset): { 'title': 'Advanced Monitoring', 'plot_level': 2, - 'figure_size': (12,7), + 'figure_size': (12.6,6.8), 'grid': (2,3), 'plots': [ { @@ -1089,7 +1089,7 @@ def plot_translations_and_originals(self, fig, dataset): { 'title': 'Unstable Probe Refinement Details', 'plot_level': 2, - 'figure_size': (9,3.5), + 'figure_size': (8.4,3.4), 'grid': (1,2), 'condition': lambda self: len(self.weights.shape) >= 2, 'plots': [ @@ -1109,7 +1109,7 @@ def plot_translations_and_originals(self, fig, dataset): 'condition': lambda self: len(self.weights.shape) >= 2 }, { - 'title': 'Average Weight Matrix Amplitudes', + 'title': 'Mean Weight Matrix Amplitudes', 'subplot': (0,1), 'plot_func': lambda self, fig: p.plot_amplitude( np.nanmean(np.abs(self.weights.data.cpu().numpy()), axis=0), diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py index 90a5287e..4ad7931c 100644 --- a/src/cdtools/reconstructors/base.py +++ b/src/cdtools/reconstructors/base.py @@ -362,7 +362,8 @@ def target(): with plt.rc_context({'figure.raise_window': False}): for fignum in open_figs: plt.figure(fignum).canvas.flush_events() - time.sleep(0.001) + + time.sleep(0.01) except KeyboardInterrupt as e: stop_event.set() diff --git a/src/cdtools/tools/plotting/plotting.py b/src/cdtools/tools/plotting/plotting.py index 1b30a766..70bc9442 100644 --- a/src/cdtools/tools/plotting/plotting.py +++ b/src/cdtools/tools/plotting/plotting.py @@ -13,7 +13,7 @@ from matplotlib.widgets import Slider from matplotlib import ticker, patheffects from matplotlib import transforms as mtransforms -from matplotlib import colors +from matplotlib import colors, gridspec __all__ = [ @@ -104,6 +104,7 @@ def get_units_factor(units): factor=1e12 return factor + def plot_image( im, plot_func=lambda x: x, @@ -179,15 +180,10 @@ def plot_image( # convert to numpy if isinstance(im, t.Tensor): - # If final dimension is 2, assume it is a complex array. If not, - # assume it represents a real array - if im.shape[-1] == 2: - im = im.detach().cpu().numpy() - else: - im = im.detach().cpu().numpy() + im = im.detach().cpu().numpy() if fig is None: - fig = plt.figure(constrained_layout=True) + fig = plt.figure() # Determine extra (non-image) dimensions and build per-axis slider map extra_dims = im.shape[:-2] @@ -198,11 +194,13 @@ def plot_image( def ordinal(n): suffix = {1: 'st', 2: 'nd', 3: 'rd'} return f"{n}{suffix.get(n % 10, 'th') if n % 100 not in (11, 12, 13) else 'th'}" + if additional_axis_labels is None: additional_axis_labels = [f'{ordinal(i)} Axis' for i in range(n_extra)] else: additional_axis_labels = list(additional_axis_labels) + [ - f'{ordinal(i)} Axis' for i in range(len(additional_axis_labels), n_extra) + f'{ordinal(i)} Axis' + for i in range(len(additional_axis_labels), n_extra) ] # Only axes with length > 1 get sliders @@ -216,8 +214,9 @@ def make_plot(idx_list): # Always update fig._make_plot so slider callbacks get the latest closure fig._make_plot = make_plot fig.plot_idx = list(idx_list) - selected = im[tuple(fig.plot_idx)] if n_extra > 0 else im - to_plot = plot_func(selected) + + selected_im = im[tuple(fig.plot_idx)] if n_extra > 0 else im + to_plot = plot_func(selected_im) # By only updating the data, and not redrawing the fig, we # don't "reset" the home positions of the toolbar @@ -232,16 +231,18 @@ def make_plot(idx_list): if fig.canvas.toolbar is not None: fig.canvas.toolbar.home() fig.canvas.toolbar.update() + # Sync sliders to the new index without triggering callbacks if hasattr(fig, '_sliders'): fig._updating = True for j, (axis_idx, _, _) in enumerate(slider_axis_map): fig._sliders[j].set_val(fig.plot_idx[axis_idx]) fig._updating = False - # Restore image axis as the "current" axis so callers using - # plt.gca() / plt.title() target the right axes + + # Restore image axis as the "current" axis if hasattr(fig, '_plot_ax'): - plt.sca(fig._plot_ax) + plt.sca(fig._current_im.ax) + return fig if title is not None: @@ -252,31 +253,45 @@ def make_plot(idx_list): except IndexError: ax_title = '' - fig.clear() - - # gs = fig.add_gridspec( - # 1 + n_sliders, 1, - # height_ratios=[1] + [0.04] * n_sliders, - # ) - # ax = fig.add_subplot(gs[0, 0], **kwargs) - # ax_sliders = [fig.add_subplot(gs[i + 1, 0]) for i in range(n_sliders)] - import matplotlib - gs = matplotlib.gridspec.GridSpec( - 2, 1, - height_ratios=[1] + [0.1 * n_sliders], - figure=fig - ) + try: + total_width, total_height = fig.get_size_inches() + except AttributeError: + # Only support one layer of nested subfigures + main_fig = fig.figure # get enclosing figure + bbox = fig.bbox + main_fig_bbox = main_fig.bbox + fig_w, fig_h = main_fig.get_size_inches() + total_width = fig.bbox.width * fig_w / main_fig.bbox.width + total_height = fig.bbox.height * fig_h / main_fig.bbox.height + except AttributeError: + # Fall back to default figsize + total_width, total_height = (6.4, 4.8) + + pad_left = 0.6 / total_height + # De-adjusts for an ad-hoc offset introduced by matplotlib + pad_right = 0.6 / total_width - 0.05 - ax_gs = matplotlib.gridspec.GridSpecFromSubplotSpec( - 1,1, subplot_spec=gs[0,0]) - ax = fig.add_subplot(ax_gs[0,0], **kwargs) - if n_sliders > 0: - slider_gs = matplotlib.gridspec.GridSpecFromSubplotSpec( - n_sliders,1, subplot_spec=gs[1,0]) - ax_sliders = [fig.add_subplot(slider_gs[i, 0]) for i in range(n_sliders)] + pad_bottom = 0.5 / total_height + pad_top = 0.4 / total_height - fig._plot_ax = ax + im_ax_bottom = pad_bottom + n_sliders * 0.05 + 0.025 + im_ax_height = 1 - pad_top - im_ax_bottom + ax = fig.add_axes( + [pad_left, im_ax_bottom, 1-pad_left-pad_right, im_ax_height] + ) + + pad_left_slider = 1.2 / total_width + pad_right_slider = 0.8 / total_width + if n_sliders > 0: + ax_sliders = [] + for n in range(n_sliders): + ax_sliders.append( + fig.add_axes([pad_left_slider, 0.025 + n * 0.05, + 1-pad_left_slider-pad_right_slider, 0.05])) + + ax_sliders = ax_sliders[::-1] + mpl_im = ax.imshow( to_plot, cmap=cmap, @@ -287,6 +302,8 @@ def make_plot(idx_list): fig._current_im = mpl_im ax.set_facecolor('k') + # Lots of logic to deal with all sorts of wild non-orthogonal + # image basis options. Leave this alone. if basis is not None: # we've closed over basis, so we can't edit it if isinstance(basis, t.Tensor): @@ -348,7 +365,11 @@ def make_plot(idx_list): ax.invert_yaxis() if show_cbar: - cbar = fig.colorbar(mpl_im, ax=ax, fraction=0.05, pad=0.05, location='right') + cbar = fig.colorbar(mpl_im, ax=ax, + fraction=0.15, + pad=0.05, + location='right') + ax.set_anchor('C') if not updateable_colorbar: cbar.ax.set_navigate(False) if cmap_label is not None: @@ -390,9 +411,9 @@ def cb(val): if fig.canvas.toolbar is not None: fig.canvas.toolbar.update() - # Restore image axis as "current" so callers using plt.gca() / plt.title() - # target the image axes, not the last slider axis added + plt.sca(ax) + return fig if hasattr(fig, 'plot_idx') and len(fig.plot_idx) == n_extra: @@ -726,12 +747,38 @@ def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, inver if clear_fig: fig.clear() - + if len(fig.axes) >= 1: ax = fig.axes[0] else: - ax = fig.add_subplot(111, **kwargs) - + try: + total_width, total_height = fig.get_size_inches() + except AttributeError: + # Only support one layer of nested subfigures + main_fig = fig.figure # get enclosing figure + bbox = fig.bbox + main_fig_bbox = main_fig.bbox + fig_w, fig_h = main_fig.get_size_inches() + total_width = fig.bbox.width * fig_w / main_fig.bbox.width + total_height = fig.bbox.height * fig_h / main_fig.bbox.height + except AttributeError: + # Fall back to default figsize + total_width, total_height = (6.4, 4.8) + + pad_left = 0.6 / total_height + # De-adjusts for an ad-hoc offset introduced by matplotlib + pad_right = 0.6 / total_width - 0.05 + + pad_bottom = 0.5 / total_height + pad_top = 0.4 / total_height + + im_ax_bottom = pad_bottom + im_ax_height = 1 - pad_top - im_ax_bottom + + ax = fig.add_axes( + [pad_left, im_ax_bottom, 1-pad_left-pad_right, im_ax_height] + ) + if isinstance(translations, t.Tensor): translations = translations.detach().cpu().numpy() @@ -797,10 +844,9 @@ def plot_nanomap( fig = plt.figure() fig.clear() - ax = fig.add_subplot(111) factor = get_units_factor(units) - bbox = fig.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) + if isinstance(translations, t.Tensor): trans = translations.detach().cpu().numpy() else: @@ -814,9 +860,38 @@ def plot_nanomap( if convention.lower() != 'probe': trans = trans * -1 - s = bbox.width * bbox.height / trans.shape[0] * 72**2 #72 is points per inch - s /= 4 # A rough value to make the size work out + try: + total_width, total_height = fig.get_size_inches() + except AttributeError: + # Only support one layer of nested subfigures + main_fig = fig.figure # get enclosing figure + bbox = fig.bbox + main_fig_bbox = main_fig.bbox + fig_w, fig_h = main_fig.get_size_inches() + total_width = fig.bbox.width * fig_w / main_fig.bbox.width + total_height = fig.bbox.height * fig_h / main_fig.bbox.height + except AttributeError: + # Fall back to default figsize + total_width, total_height = (6.4, 4.8) + + pad_left = 0.6 / total_height + # De-adjusts for an ad-hoc offset introduced by matplotlib + pad_right = 0.6 / total_width - 0.05 + + pad_bottom = 0.5 / total_height + pad_top = 0.4 / total_height + + im_ax_bottom = pad_bottom + im_ax_height = 1 - pad_top - im_ax_bottom + + ax = fig.add_axes( + [pad_left, im_ax_bottom, 1-pad_left-pad_right, im_ax_height] + ) + + s = total_width * total_height / trans.shape[0] * 72**2 + s /= 4 # A rough value to make the size work out + scatter_plot = ax.scatter( factor * trans[:,0],factor * trans[:,1],s=s,c=values, cmap=cmap) if invert_xaxis: @@ -825,7 +900,14 @@ def plot_nanomap( ax.set_facecolor('k') ax.set_xlabel('Translation x (' + units + ')') ax.set_ylabel('Translation y (' + units + ')') - cbar = fig.colorbar(scatter_plot, ax=ax, fraction=0.05, pad=0.05) + cbar = fig.colorbar( + scatter_plot, + ax=ax, + fraction=0.15, + pad=0.05, + location='right', + ) + ax.set_anchor('C') if cmap_label is not None: cbar.set_label(cmap_label) @@ -851,20 +933,21 @@ def plot_nanomap_with_images(translations, get_image_func, values=None, mask=Non # mode, i.e. on a figure that already has this thing showing. if fig is None: - fig = plt.figure(figsize=(8,5.3), constrained_layout=True) + fig = plt.figure(figsize=(8,5.3)) else: if plt.fignum_exists(fig.number): fig = plt.figure(fig.number) else: fig = plt.figure(fig.number, - figsize=(8,5.3), constrained_layout=True) + figsize=(8,5.3)) fig.clear() if hasattr(fig, 'nanomap_cids'): for cid in fig.nanomap_cids: fig.canvas.mpl_disconnect(cid) # Does figsize work with the fig.subplots, or just for plt.subplots? - gs = fig.add_gridspec(2, 2, height_ratios=[0.9,0.1], width_ratios=[1,1]) + gs = fig.add_gridspec(2, 2, height_ratios=[0.92,0.08], width_ratios=[1,1], + bottom=0.04) axes = [fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1])] axslider = fig.add_subplot(gs[1, :]) # full width @@ -918,14 +1001,14 @@ def update_colorbar(im): axes[0].invert_xaxis() axes[0].set_facecolor('k') - axes[0].set_xlabel('Translation x ('+nanomap_units+')', labelpad=1) - axes[0].set_ylabel('Translation y ('+nanomap_units+')', labelpad=1) + axes[0].set_xlabel('Translation x ('+nanomap_units+')') + axes[0].set_ylabel('Translation y ('+nanomap_units+')') axes[0].set_aspect('equal') cb1 = plt.colorbar(nanomap, ax=axes[0], orientation='horizontal', format='%.2e', - ticks=ticker.LinearLocator(numticks=5))#, - #pad=0.17,fraction=0.1) - cb1.ax.set_title(nanomap_colorbar_title, size="medium")#, pad=5) + ticks=ticker.LinearLocator(numticks=5), + pad=0.19,fraction=0.1) + cb1.ax.set_title(nanomap_colorbar_title, size="medium", pad=5) cb1.ax.tick_params(labelrotation=20) if values is None: # This seems to do a good job of leaving the appropriate space @@ -978,8 +1061,8 @@ def update_colorbar(im): cb2 = plt.colorbar(meas, ax=axes[1], orientation='horizontal', format='%.2e', - ticks=ticker.LinearLocator(numticks=5))#, - #pad=-0.17)#,fraction=0.1) + ticks=ticker.LinearLocator(numticks=5), + pad=0.19,fraction=0.1) cb2.ax.tick_params(labelrotation=20) cb2.ax.set_title(image_colorbar_title, size="medium", pad=5) cb2.ax.callbacks.connect('xlim_changed', lambda ax: update_colorbar(meas)) From 5c5c0be7c67096e6e46ec6b96c0aab1566048271 Mon Sep 17 00:00:00 2001 From: allevitan Date: Tue, 24 Mar 2026 11:00:43 +0100 Subject: [PATCH 16/23] Reformat plotting function signatures to one-parameter-per-line style All public functions in plotting.py now use the expanded multi-line signature format with trailing commas for consistency and readability. Co-Authored-By: Claude Sonnet 4.6 --- src/cdtools/tools/plotting/plotting.py | 232 +++++++++++++++++-------- 1 file changed, 162 insertions(+), 70 deletions(-) diff --git a/src/cdtools/tools/plotting/plotting.py b/src/cdtools/tools/plotting/plotting.py index 70bc9442..30fbbac9 100644 --- a/src/cdtools/tools/plotting/plotting.py +++ b/src/cdtools/tools/plotting/plotting.py @@ -13,7 +13,7 @@ from matplotlib.widgets import Slider from matplotlib import ticker, patheffects from matplotlib import transforms as mtransforms -from matplotlib import colors, gridspec +from matplotlib import colors __all__ = [ @@ -31,7 +31,11 @@ ] -def colorize(z, use_cmocean=False, amplitude_scaling=lambda x: x): +def colorize( + z, + use_cmocean=True, + amplitude_scaling=lambda x: x, +): """ Returns RGB values for a complex color plot given a complex array This function returns a set of RGB values that can be used directly in a call to imshow based on an input complex numpy array (not a @@ -43,6 +47,9 @@ def colorize(z, use_cmocean=False, amplitude_scaling=lambda x: x): A complex-valued array use_cmocean : bool If true, uses the cmocean_phase colormap instead of hue + amplitude_scaling : callable + A function applied to the normalized amplitude before colorizing. + Default is the identity (no scaling). Returns ------- rgb : list(array) @@ -53,24 +60,25 @@ def colorize(z, use_cmocean=False, amplitude_scaling=lambda x: x): scaled_amp = amplitude_scaling(amp / np.max(amp)) ph = np.angle(z, deg=1) + # The offsets for both options are chosen to match, and to place red + # at the cut (at -pi and pi) if use_cmocean: base_rgb_values = [] for channel in range(3): - base_rgb_values.append(np.interp((ph + 180)%360, + # Flipping the cm_data matches the usual + # order of colors from hsv + base_rgb_values.append(np.interp((ph + 100)%360, np.linspace(0, 360, cm_data.shape [0]), - cm_data[:,channel])) + cm_data[::-1,channel])) base_rgb_values = np.dstack(base_rgb_values) rgb_values = base_rgb_values * scaled_amp[...,None] return rgb_values else: # HSV are values in range [0,1] - h = ((ph + 90) % 360) / 360 + h = ((ph + 170) % 360) / 360 s = 0.85 * np.ones_like(h) v = scaled_amp return hsv_to_rgb(np.dstack((h,s,v))) - - - def get_units_factor(units): @@ -121,7 +129,7 @@ def plot_image( title=None, additional_axis_labels=None, updateable_colorbar=True, - **kwargs + **kwargs, ): """Plots an image with a colorbar and on an appropriate spatial grid @@ -170,7 +178,8 @@ def plot_image( If shorter than the number of extra axes, remaining labels default to "Axis N". If not set, all labels default to "Axis N". \\**kwargs - All other args are passed to fig.add_subplot(111, \\**kwargs) + All other args are passed to plt.figure(\\**kwargs), if no figure is + provided Returns ------- @@ -183,7 +192,7 @@ def plot_image( im = im.detach().cpu().numpy() if fig is None: - fig = plt.figure() + fig = plt.figure(**kwargs) # Determine extra (non-image) dimensions and build per-axis slider map extra_dims = im.shape[:-2] @@ -193,7 +202,12 @@ def plot_image( # looks kind of confusing based on the layout that a Slider widget gets def ordinal(n): suffix = {1: 'st', 2: 'nd', 3: 'rd'} - return f"{n}{suffix.get(n % 10, 'th') if n % 100 not in (11, 12, 13) else 'th'}" + def get_suffix(n): + if n % 100 not in (11, 12, 13): + suffix.get(n % 10, 'th') + else: + return 'th' + return f"{n}{get_suffix(n)}" if additional_axis_labels is None: additional_axis_labels = [f'{ordinal(i)} Axis' for i in range(n_extra)] @@ -258,8 +272,6 @@ def make_plot(idx_list): except AttributeError: # Only support one layer of nested subfigures main_fig = fig.figure # get enclosing figure - bbox = fig.bbox - main_fig_bbox = main_fig.bbox fig_w, fig_h = main_fig.get_size_inches() total_width = fig.bbox.width * fig_w / main_fig.bbox.width total_height = fig.bbox.height * fig_h / main_fig.bbox.height @@ -394,8 +406,8 @@ def make_plot(idx_list): sliders.append(s) fig._sliders = sliders - # Slider callbacks guarded by _updating flag to prevent re-entry. - # Uses fig._make_plot so subsequent plot_image calls update the closure. + # Slider callbacks guarded by the _updating flag to prevent re-entry. + # Uses fig._make_plot so further plot_image calls update the closure. def make_slider_cb(axis_idx): def cb(val): if getattr(fig, '_updating', False): @@ -463,12 +475,23 @@ def on_action(event): for cid in fig.my_callbacks: fig.canvas.mpl_disconnect(cid) fig.my_callbacks = [] - fig.my_callbacks.append(fig.canvas.mpl_connect('key_press_event', on_action)) + fig.my_callbacks.append( + fig.canvas.mpl_connect('key_press_event', on_action) + ) return result_fig -def plot_real(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_label='Real Part (a.u.)', title=None, **kwargs): +def plot_real( + im, + fig=None, + basis=None, + units='$\\mu$m', + cmap='viridis', + cmap_label='Real Part (a.u.)', + title=None, + **kwargs, +): """Plots the real part of a complex array with dimensions NxM If a figure is given explicitly, it will clear that existing figure and @@ -491,8 +514,10 @@ def plot_real(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_ Default is 'viridis', the colormap to plot with cmap_label : str What to label the colorbar when plotting + title : str, optional + Title for the axes. \\**kwargs - All other args are passed to fig.add_subplot(111, \\**kwargs) + All other args are passed through to plotting.plot_image Returns ------- @@ -505,8 +530,16 @@ def plot_real(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_ title=title, **kwargs) - -def plot_imag(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_label='Imaginary Part (a.u.)', title=None, **kwargs): +def plot_imag( + im, + fig=None, + basis=None, + units='$\\mu$m', + cmap='viridis', + cmap_label='Imaginary Part (a.u.)', + title=None, + **kwargs, +): """Plots the imaginary part of a complex array with dimensions NxM If a figure is given explicitly, it will clear that existing figure and @@ -529,8 +562,10 @@ def plot_imag(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_ Default is 'viridis', the colormap to plot with cmap_label : str What to label the colorbar when plotting + title : str, optional + Title for the axes. \\**kwargs - All other args are passed to fig.add_subplot(111, \\**kwargs) + All other args are passed through to plotting.plot_image Returns ------- @@ -543,7 +578,16 @@ def plot_imag(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_ title=title, **kwargs) -def plot_amplitude(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', cmap_label='Amplitude (a.u.)', title=None, **kwargs): +def plot_amplitude( + im, + fig=None, + basis=None, + units='$\\mu$m', + cmap='viridis', + cmap_label='Amplitude (a.u.)', + title=None, + **kwargs, +): """Plots the amplitude of a complex array with dimensions NxM If a figure is given explicitly, it will clear that existing figure and @@ -566,8 +610,10 @@ def plot_amplitude(im, fig = None, basis=None, units='$\\mu$m', cmap='viridis', Default is 'viridis', the colormap to plot with cmap_label : str What to label the colorbar when plotting + title : str, optional + Title for the axes. \\**kwargs - All other args are passed to fig.add_subplot(111, \\**kwargs) + All other args are passed through to plotting.plot_image Returns ------- @@ -590,7 +636,7 @@ def plot_phase( vmin=None, vmax=None, title=None, - **kwargs + **kwargs, ): """ Plots the phase of a complex array with dimensions NxM @@ -600,8 +646,8 @@ def plot_phase( If a basis is explicitly passed, the image will be plotted in real-space coordinates - If the cmap is entered as 'phase', it will plot the cmocean phase colormap, - and by default set the limits to [-pi,pi]. + If the cmap is entered as 'phase', it will plot the cmocean phase + colormap, and by default set the limits to [-pi,pi]. Parameters ---------- @@ -623,7 +669,7 @@ def plot_phase( Default is max(angle(im)), the maximum value for the colormap \\**kwargs - All other args are passed to fig.add_subplot(111, \\**kwargs) + All other args are passed through to plotting.plot_image Returns ------- @@ -644,10 +690,16 @@ def plot_phase( **kwargs) -def plot_amplitude_surfacenorm(): - pass - -def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', title=None, amplitude_scaling=lambda x: x, **kwargs): +def plot_colorized( + im, + fig=None, + basis=None, + units='$\\mu$m', + title=None, + use_cmocean=True, + amplitude_scaling=lambda x: x, + **kwargs, +): """ Plots the colorized version of a complex array with dimensions NxM The darkness corresponds to the intensity of the image, and the color @@ -669,6 +721,13 @@ def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', title=None, amplit Optional, the 3x2 probe basis units : str The length units to mark on the plot, default is um + title : str, optional + Title for the axes. + use_cmocean : bool + If true, uses the cmocean_phase colormap instead of hue + amplitude_scaling : callable + A function applied to the normalized amplitude before colorizing. + Default is the identity (no scaling). \\**kwargs All other args are passed to fig.add_subplot(111, \\**kwargs) @@ -677,7 +736,7 @@ def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', title=None, amplit used_fig : matplotlib.figure.Figure The figure object that was actually plotted to. """ - plot_func = lambda x: colorize(x, use_cmocean=True, + plot_func = lambda x: colorize(x, use_cmocean=use_cmocean, amplitude_scaling=amplitude_scaling) plot_fig = plot_image(im, plot_func=plot_func, fig=fig, basis=basis, cmap=cmocean_phase, vmin=-np.pi, vmax=np.pi, @@ -686,7 +745,8 @@ def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', title=None, amplit updateable_colorbar=False, **kwargs) # Find the colorbar - this is a bit hacky - cbar_ax = [ax for ax in plot_fig.get_axes() if hasattr(ax, '_colorbar')][0] + cbar_ax = [ax for ax in plot_fig.get_axes() + if hasattr(ax, '_colorbar')][0] # --- Replace the colorbar image --- # The internal image is a QuadMesh living on cbar.ax @@ -707,7 +767,18 @@ def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', title=None, amplit extent=[xg[0], xg[-1], -np.pi, np.pi]) -def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, invert_xaxis=True, clear_fig=True, label=None, color=None, marker='.', **kwargs): +def plot_translations( + translations, + fig=None, + units='$\\mu$m', + lines=True, + invert_xaxis=True, + clear_fig=True, + label=None, + color=None, + marker='.', + **kwargs, +): """Plots a set of probe translations in a nicely formatted way Parameters @@ -721,13 +792,15 @@ def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, inver lines : bool Whether to plot lines indicating the path taken invert_xaxis : bool - Default is True. This flips the x axis to match the convention from .cxi files of viewing the image from the beam's perspective + Default is True. This flips the x axis to match the convention from + .cxi files of viewing the image from the beam's perspective clear_fig : bool Default is True. Whether to clear the figure before plotting. label : str Default is None. A label to give the plotted markers for a legend. color : str - Default is None. The color to plot the markers in. By default, will follow the matplotlib color cycle. + Default is None. The color to plot the markers in. By default, will + follow the matplotlib color cycle. color : str Default is '.'. The marker style to plot with. \\**kwargs @@ -756,8 +829,6 @@ def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, inver except AttributeError: # Only support one layer of nested subfigures main_fig = fig.figure # get enclosing figure - bbox = fig.bbox - main_fig_bbox = main_fig.bbox fig_w, fig_h = main_fig.get_size_inches() total_width = fig.bbox.width * fig_w / main_fig.bbox.width total_height = fig.bbox.height * fig_h / main_fig.bbox.height @@ -804,14 +875,14 @@ def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, inver def plot_nanomap( - translations, - values, - fig=None, - cmap='viridis', - cmap_label=None, - units='$\\mu$m', - convention='probe', - invert_xaxis=True + translations, + values, + fig=None, + cmap='viridis', + cmap_label=None, + units='$\\mu$m', + convention='probe', + invert_xaxis=True, ): """Plots a set of nanomap data in a flexible way @@ -830,9 +901,11 @@ def plot_nanomap( units : str Default is um, units to report in (assuming input in m) convention : str - Default is 'probe', alternative is 'obj'. Whether the translations refer to the probe or object. + Default is 'probe', alternative is 'obj'. Whether the translations + refer to the probe or object. invert_xaxis : bool - Default is True. This flips the x axis to match the convention from .cxi files of viewing the image from the beam's perspective + Default is True. This flips the x axis to match the convention from + .cxi files of viewing the image from the beam's perspective Returns ------- @@ -866,8 +939,6 @@ def plot_nanomap( except AttributeError: # Only support one layer of nested subfigures main_fig = fig.figure # get enclosing figure - bbox = fig.bbox - main_fig_bbox = main_fig.bbox fig_w, fig_h = main_fig.get_size_inches() total_width = fig.bbox.width * fig_w / main_fig.bbox.width total_height = fig.bbox.height * fig_h / main_fig.bbox.height @@ -914,7 +985,22 @@ def plot_nanomap( return fig -def plot_nanomap_with_images(translations, get_image_func, values=None, mask=None, basis=None, fig=None, nanomap_units='$\\mu$m', image_units='$\\mu$m', convention='probe', image_title='Image', image_colorbar_title='Image Amplitude', nanomap_colorbar_title='Integrated Intensity', cmap='viridis', **kwargs): +def plot_nanomap_with_images( + translations, + get_image_func, + values=None, + mask=None, + basis=None, + fig=None, + nanomap_units='$\\mu$m', + image_units='$\\mu$m', + convention='probe', + image_title='Image', + image_colorbar_title='Image Amplitude', + nanomap_colorbar_title='Integrated Intensity', + cmap='viridis', + **kwargs, +): """Plots a nanomap, with an image or stack of images for each point In many situations, ptychography data or the output of ptychography @@ -954,15 +1040,16 @@ def plot_nanomap_with_images(translations, get_image_func, values=None, mask=Non # This gets the set of sizes for the points in the nanomap def calculate_sizes(idx): - bbox = axes[0].get_window_extent().transformed(fig.dpi_scale_trans.inverted()) - s0 = bbox.width * bbox.height / translations.shape[0] * 72**2 #72 is points per inch + bbox = axes[0].get_window_extent().transformed( + fig.dpi_scale_trans.inverted()) + # 72 is points per inch + s0 = bbox.width * bbox.height / translations.shape[0] * 72**2 s0 /= 4 # A rough value to make the size work out s = np.ones(translations.shape[0]) * s0 s[idx] *= 4 return s - def update_colorbar(im): # # This solves the problem of the colorbar being changed @@ -1019,7 +1106,7 @@ def update_colorbar(im): # Now we set up the second plot, which shows the individual # diffraction patterns axes[1].set_title(image_title) - #Plot in a basis if it exists, otherwise dont + # Plot in a basis if it exists, otherwise dont if basis is not None: axes[1].set_xlabel('X (' + image_units + ')') axes[1].set_ylabel('Y (' + image_units + ')') @@ -1080,9 +1167,9 @@ def update(idx, im_idx=None): # Get the new data for this index im = get_image_func(idx) if len(im.shape) >= 3: - if im_idx == None and hasattr(axes[1],'image_idx'): + if im_idx is None and hasattr(axes[1],'image_idx'): im_idx = axes[1].image_idx - elif im_idx == None: + elif im_idx is None: im_idx=0 axes[1].image_idx = im_idx axes[1].text_box.set_text(str(im_idx)) @@ -1097,8 +1184,7 @@ def update(idx, im_idx=None): ax_im.set_data(im) ax_im.norecurse=False update_colorbar(ax_im) - #plt.draw() - + # plt.draw() # # Now we define the functions to handle various kinds of events @@ -1107,7 +1193,10 @@ def update(idx, im_idx=None): # We start by creating the slider here, so it can be used # by the update hooks. - slider = Slider(axslider, 'Image #', 0, translations.shape[0]-1, valstep=1, valfmt="%d") + slider = Slider( + axslider,'Image #', 0, translations.shape[0]-1, + valstep=1, valfmt="%d" + ) # This handles scroll wheel and keypress events def on_action(event): @@ -1119,13 +1208,13 @@ def on_action(event): im = im.reshape(-1,im.shape[-2],im.shape[-1]) im_idx = axes[1].image_idx - if (event.key == 'up' - or (hasattr(event, 'button') and event.button == 'up') - or event.key == 'left'): + if event.key == "left" or ( + hasattr(event, "button") and event.button == "left" + ): im_idx = (im_idx - 1) % im.shape[0] - if (event.key == 'down' - or (hasattr(event, 'button') and event.button == 'down') - or event.key == 'right'): + if event.key == "right" or ( + hasattr(event, "button") and event.button == "right" + ): im_idx = (im_idx + 1) % im.shape[0] axes[1].image_idx=im_idx @@ -1139,9 +1228,13 @@ def on_action(event): if not hasattr(event, 'key'): event.key = None - if event.key == 'up' or event.button == 'up' or event.key == 'left': + if event.key == "left" or ( + hasattr(event, "button") and event.button == "left" + ): idx = slider.val - 1 - elif event.key == 'down' or event.button == 'down' or event.key == 'right': + elif event.key == "right" or ( + hasattr(event, "button") and event.button == "right" + ): idx = slider.val + 1 else: # This prevents errors from being thrown on irrelevant key @@ -1159,7 +1252,6 @@ def on_pick(event): if event.mouseevent.button == 1: slider.set_val(event.ind[0]) - # Here we connect the various update functions cid1 = fig.canvas.mpl_connect('pick_event',on_pick) cid2 = fig.canvas.mpl_connect('key_press_event',on_action) From 25c7b66c7c5f56de7221510ec36f34fa71e75331 Mon Sep 17 00:00:00 2001 From: allevitan Date: Tue, 24 Mar 2026 11:10:23 +0100 Subject: [PATCH 17/23] Do a final review of all the examples to ensure they work and are well structured --- .flake8 | 2 +- examples/gold_ball_split.py | 6 +++--- examples/gold_ball_synthesize.py | 6 +++--- examples/near_field_ptycho.py | 1 - examples/simple_ptycho.py | 2 +- examples/tutorial_finale.py | 6 ++++-- examples/tutorial_simple_ptycho.py | 28 ++++++++++++++++++++-------- pyproject.toml | 3 ++- src/cdtools/models/base.py | 11 +++++++++++ src/cdtools/models/fancy_ptycho.py | 10 ++++++++++ src/cdtools/models/simple_ptycho.py | 25 ++++++++++--------------- 11 files changed, 65 insertions(+), 35 deletions(-) diff --git a/.flake8 b/.flake8 index 9214cb9e..3e08f47c 100644 --- a/.flake8 +++ b/.flake8 @@ -1,2 +1,2 @@ [flake8] -ignore = E501, W503 \ No newline at end of file +ignore = E501, W503, E731 \ No newline at end of file diff --git a/examples/gold_ball_split.py b/examples/gold_ball_split.py index 9fc5b083..fde19c6a 100644 --- a/examples/gold_ball_split.py +++ b/examples/gold_ball_split.py @@ -32,9 +32,9 @@ model.weights.requires_grad = False - device = 'cuda' - model.to(device=device) - dataset.get_as(device=device) + if t.cuda.is_available(): + model.to(device='cuda') + dataset.get_as(device='cuda') # Create the reconstructor recon = cdtools.reconstructors.AdamReconstructor(model, dataset) diff --git a/examples/gold_ball_synthesize.py b/examples/gold_ball_synthesize.py index 3a2c68aa..41319056 100644 --- a/examples/gold_ball_synthesize.py +++ b/examples/gold_ball_synthesize.py @@ -5,11 +5,11 @@ # We load all three reconstructions half_1 = cdtools.tools.data.h5_to_nested_dict( - f'example_reconstructions/gold_balls_half_1.h5') + 'example_reconstructions/gold_balls_half_1.h5') half_2 = cdtools.tools.data.h5_to_nested_dict( - f'example_reconstructions/gold_balls_half_2.h5') + 'example_reconstructions/gold_balls_half_2.h5') full = cdtools.tools.data.h5_to_nested_dict( - f'example_reconstructions/gold_balls_full.h5') + 'example_reconstructions/gold_balls_full.h5') # This defines the region of recovered object to use for the analysis. pad = 260 diff --git a/examples/near_field_ptycho.py b/examples/near_field_ptycho.py index b5748131..4534ee58 100644 --- a/examples/near_field_ptycho.py +++ b/examples/near_field_ptycho.py @@ -6,7 +6,6 @@ dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename) dataset.inspect() -plt.show() # Setting near_field equal to True uses an angular spectrum propagator in # lieu of the default Fourier-transform propagator for far-field ptychography. diff --git a/examples/simple_ptycho.py b/examples/simple_ptycho.py index af8aeb11..8218a229 100644 --- a/examples/simple_ptycho.py +++ b/examples/simple_ptycho.py @@ -24,7 +24,7 @@ dataset.get_as(device='cuda') model.inspect(dataset) -print('hi') + # We run the reconstruction for loss in model.Adam_optimize(30, dataset, batch_size=10): # We print a quick report of the optimization status diff --git a/examples/tutorial_finale.py b/examples/tutorial_finale.py index 9f698b76..4e95ec63 100644 --- a/examples/tutorial_finale.py +++ b/examples/tutorial_finale.py @@ -1,5 +1,6 @@ from tutorial_basic_ptycho_dataset import BasicPtychoDataset from tutorial_simple_ptycho import SimplePtycho +import torch as t from h5py import File from matplotlib import pyplot as plt @@ -11,8 +12,9 @@ model = SimplePtycho.from_dataset(dataset) -model.to(device='cuda') -dataset.get_as(device='cuda') +if t.cuda.is_available(): + model.to(device='cuda') + dataset.get_as(device='cuda') for loss in model.Adam_optimize(10, dataset): model.inspect(dataset) diff --git a/examples/tutorial_simple_ptycho.py b/examples/tutorial_simple_ptycho.py index 7e8a3e5d..67544d74 100644 --- a/examples/tutorial_simple_ptycho.py +++ b/examples/tutorial_simple_ptycho.py @@ -108,14 +108,26 @@ def loss(self, real_data, sim_data): # This lists all the plots to display on a call to model.inspect() plot_list = [ - ('Probe Amplitude', - lambda self, fig: p.plot_amplitude(self.probe, fig=fig, basis=self.probe_basis)), - ('Probe Phase', - lambda self, fig: p.plot_phase(self.probe, fig=fig, basis=self.probe_basis)), - ('Object Amplitude', - lambda self, fig: p.plot_amplitude(self.obj, fig=fig, basis=self.probe_basis)), - ('Object Phase', - lambda self, fig: p.plot_phase(self.obj, fig=fig, basis=self.probe_basis)) + { + 'title': 'Probe Amplitude', + 'plot_func': lambda self, fig: + p.plot_amplitude(self.probe, fig, basis=self.probe_basis), + }, + { + 'title': 'Probe Phase', + 'plot_func': lambda self, fig: + p.plot_phase(self.probe, fig, basis=self.probe_basis) + }, + { + 'title': 'Object Amplitude', + 'plot_func': lambda self, fig: + p.plot_amplitude(self.obj, fig, basis=self.probe_basis) + }, + { + 'title': 'Object Phase', + 'plot_func': lambda self, fig: + p.plot_phase(self.obj, fig, basis=self.probe_basis) + }, ] def save_results(self, dataset): diff --git a/pyproject.toml b/pyproject.toml index 9ba1438c..aacaecfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,4 +50,5 @@ docs = [ ] [tool.ruff] -line-length = 79 \ No newline at end of file +line-length = 79 +ignore = ["E501", "E731"] \ No newline at end of file diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index d3cb122e..e7516593 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -61,6 +61,16 @@ class CDIModel(t.nn.Module): """ def __init__(self, panel_plot_mode=False, plot_level=np.inf): + """Initializes the CDIModel base class. + + Parameters + ---------- + panel_plot_mode : bool, default: False + If True, plot_panel_list entries are rendered as multi-subplot + figures. If False, each subplot is rendered as its own figure. + plot_level : float, default: np.inf + Only plots whose plot_level <= this value are shown. + """ super(CDIModel, self).__init__() self.loss_history = [] @@ -640,6 +650,7 @@ def inspect(self, dataset=None, replot_all=False, min_interval=None): def _is_backend_interactive( self ): + """Returns True if the current matplotlib backend is interactive.""" backend = matplotlib.get_backend().lower() try: # matplotlib >= 3.9 diff --git a/src/cdtools/models/fancy_ptycho.py b/src/cdtools/models/fancy_ptycho.py index 3edb9ac5..b3635917 100644 --- a/src/cdtools/models/fancy_ptycho.py +++ b/src/cdtools/models/fancy_ptycho.py @@ -848,6 +848,15 @@ def tidy_probes(self): def get_probe_intensities(self): + """Returns the effective probe intensity at each scan position. + + Handles both the simple (1D weights) and OPRP (2D weights) cases. + + Returns + ------- + probe_intensities : np.ndarray + Array of probe intensities, one per scan position. + """ if not hasattr(self, 'weights'): raise NotImplementedError( "I don't know how to handle having no weights") @@ -920,6 +929,7 @@ def get_probes(idx): def plot_illumination_intensity(self, fig, dataset): + """Plots the probe intensity nanomap. Only used to make a plot for the plot list.""" p.plot_nanomap( self.corrected_translations(dataset), self.get_probe_intensities(), diff --git a/src/cdtools/models/simple_ptycho.py b/src/cdtools/models/simple_ptycho.py index 234ed5d7..734ba20c 100644 --- a/src/cdtools/models/simple_ptycho.py +++ b/src/cdtools/models/simple_ptycho.py @@ -110,28 +110,23 @@ def loss(self, real_data, sim_data): plot_list = [ { 'title': 'Probe Amplitude', - 'subplot' : (0, 0), 'plot_func': lambda self, fig: - p.plot_amplitude(self.probe, fig, - basis=self.probe_basis), - }, { + p.plot_amplitude(self.probe, fig, basis=self.probe_basis), + }, + { 'title': 'Probe Phase', - 'subplot' : (0, 1), 'plot_func': lambda self, fig: - p.plot_phase(self.probe, fig, - basis=self.probe_basis) - }, { + p.plot_phase(self.probe, fig, basis=self.probe_basis) + }, + { 'title': 'Object Amplitude', - 'subplot' : (1, 0), 'plot_func': lambda self, fig: - p.plot_amplitude(self.obj, fig, - basis=self.probe_basis) - }, { + p.plot_amplitude(self.obj, fig, basis=self.probe_basis) + }, + { 'title': 'Object Phase', - 'subplot' : (1, 1), 'plot_func': lambda self, fig: - p.plot_phase(self.obj, fig, - basis=self.probe_basis) + p.plot_phase(self.obj, fig, basis=self.probe_basis) }, ] From 29e9fddd722ec78eefe541981d28041de5fcdbba Mon Sep 17 00:00:00 2001 From: allevitan Date: Tue, 24 Mar 2026 16:01:45 +0100 Subject: [PATCH 18/23] Add test coverage for plot_translations, plot_nanomap, and plot_nanomap_with_images Co-Authored-By: Claude Sonnet 4.6 --- tests/tools/test_plotting.py | 99 +++++++++++++++++++++++++++++++++--- 1 file changed, 92 insertions(+), 7 deletions(-) diff --git a/tests/tools/test_plotting.py b/tests/tools/test_plotting.py index baca33dd..8a1567f0 100644 --- a/tests/tools/test_plotting.py +++ b/tests/tools/test_plotting.py @@ -13,13 +13,22 @@ def test_plot_amplitude(show_plot): plotting.plot_amplitude(im, basis=np.array([[0, -1], [-1, 0], [0, 0]]), title='Test Amplitude') if show_plot: plt.show() - - # Test with numpy array - im = scipy.datasets.ascent().astype(np.complex128) + plt.close('all') + + # Test with numpy array and an extra dimension + im = np.stack([scipy.datasets.ascent().astype(np.complex128)]*3, axis=0) plotting.plot_amplitude(im, title='Test Amplitude') if show_plot: plt.show() - + plt.close('all') + + # Test with pytorch tensor and two extra dimensions + im = t.as_tensor(np.stack([im]*5, axis=0)) + plotting.plot_amplitude(im, title='Test Amplitude', + additional_axis_labels=['Hi','There']) + if show_plot: + plt.show() + plt.close('all') def test_plot_phase(show_plot): # Test with tensor @@ -27,12 +36,14 @@ def test_plot_phase(show_plot): plotting.plot_phase(im, title='Test Phase') if show_plot: plt.show() - + plt.close('all') + # Test with numpy array im = initializers.gaussian([512, 512], [200, 200], amplitude=100, curvature=[.1, .1]).numpy() plotting.plot_phase(im, title='Test Phase', basis=np.array([[0, -1], [-1, 0], [0, 0]])) if show_plot: plt.show() + plt.close('all') def test_plot_colorized(show_plot): @@ -42,9 +53,83 @@ def test_plot_colorized(show_plot): plotting.plot_colorized(im, title='Test Colorize', basis=np.array([[0, -1], [-1, 0], [0, 0]])) if show_plot: plt.show() + plt.close('all') - # Test with numpy array + # Test with numpy array and hsv im = im.numpy() - plotting.plot_colorized(im, title='Test Colorize') + plotting.plot_colorized(im, title='Test Colorize', use_cmocean=False) + if show_plot: + plt.show() + plt.close('all') + + +def test_plot_translations(show_plot): + rng = np.random.default_rng(0) + trans_np = rng.uniform(-5e-6, 5e-6, (20, 2)) + trans_t = t.as_tensor(trans_np) + + # numpy, defaults + plotting.plot_translations(trans_np) + if show_plot: + plt.show() + plt.close('all') + + # torch tensor and reuse figure + fig = plotting.plot_translations(trans_t) + plotting.plot_translations(trans_np, lines=False, color='red', label='scan', fig=fig, clear_fig=False) + if show_plot: + plt.show() + plt.close('all') + +def test_plot_nanomap(show_plot): + rng = np.random.default_rng(0) + trans_np = rng.uniform(-5e-6, 5e-6, (20, 2)) + values_np = np.random.default_rng(1).uniform(0, 1, 20) + trans_t = t.as_tensor(trans_np) + values_t = t.as_tensor(values_np) + + # numpy, defaults + plotting.plot_nanomap(trans_np, values_np) + if show_plot: + plt.show() + plt.close('all') + + # torch tensors + plotting.plot_nanomap(trans_t, values_t, units='nm', cmap_label='Intensity', convention='sample') + if show_plot: + plt.show() + plt.close('all') + + +def test_plot_nanomap_with_images(show_plot): + rng = np.random.default_rng(0) + trans_np = rng.uniform(-5e-6, 5e-6, (20, 2)) + values_np = np.random.default_rng(1).uniform(0, 1, 20) + # plot_nanomap_with_images requires tensor translations + trans_t = t.as_tensor(trans_np) + values_t = t.as_tensor(values_np) + + def get_image_2d(i): + return np.random.default_rng(i).uniform(0, 1, (32, 32)) + + def get_image_3d(i): + return np.random.default_rng(i).uniform(0, 1, (4, 32, 32)) + + # basic call, no values + plotting.plot_nanomap_with_images(trans_np, get_image_2d) + if show_plot: + plt.show() + plt.close('all') + + # with explicit values + plotting.plot_nanomap_with_images(trans_t, get_image_2d, values=values_np) + if show_plot: + plt.show() + plt.close('all') + + # 3D image stack + fig = plt.figure(figsize=(11,7)) + plotting.plot_nanomap_with_images(trans_np, get_image_3d, values=values_t, fig=fig) if show_plot: plt.show() + plt.close('all') From ad2b09f2f0cdedfde81e89f76c7923c2d7069d6b Mon Sep 17 00:00:00 2001 From: allevitan Date: Tue, 24 Mar 2026 16:19:26 +0100 Subject: [PATCH 19/23] Update the tests to work better when checking the model plotting, and make sure to cover panel_plot_mode=False --- src/cdtools/tools/plotting/plotting.py | 16 +++++++++-- tests/models/test_fancy_ptycho.py | 28 +++++++++++------- tests/models/test_simple_ptycho.py | 8 ++++-- tests/test_reconstructors.py | 40 +++++++++++++++++--------- 4 files changed, 64 insertions(+), 28 deletions(-) diff --git a/src/cdtools/tools/plotting/plotting.py b/src/cdtools/tools/plotting/plotting.py index 30fbbac9..a4847b79 100644 --- a/src/cdtools/tools/plotting/plotting.py +++ b/src/cdtools/tools/plotting/plotting.py @@ -204,7 +204,7 @@ def ordinal(n): suffix = {1: 'st', 2: 'nd', 3: 'rd'} def get_suffix(n): if n % 100 not in (11, 12, 13): - suffix.get(n % 10, 'th') + return suffix.get(n % 10, 'th') else: return 'th' return f"{n}{get_suffix(n)}" @@ -1074,7 +1074,11 @@ def update_colorbar(im): # First we set up the left-hand plot, which shows an overview map axes[0].set_title('Relative Displacement Map') - translations = translations.detach().cpu().numpy() + if isinstance(translations, t.Tensor): + translations = translations.detach().cpu().numpy() + + if isinstance(values, t.Tensor): + values = values.detach().cpu().numpy() if convention.lower() != 'probe': translations = translations * -1 @@ -1082,9 +1086,13 @@ def update_colorbar(im): s = calculate_sizes(0) nanomap_units_factor = get_units_factor(nanomap_units) + + # Suppresses a warning from ax.scatter() + if values is None: + cmap = None nanomap = axes[0].scatter(nanomap_units_factor * translations[:,0], nanomap_units_factor * translations[:,1], - s=s,c=values, picker=True, cmap=cmap) + s=s, c=values, picker=True, cmap=cmap) axes[0].invert_xaxis() axes[0].set_facecolor('k') @@ -1101,7 +1109,9 @@ def update_colorbar(im): # This seems to do a good job of leaving the appropriate space # where the colorbar should have been to avoid stretching the # nanomap plot, while still not showing the (now useless) colorbar. + pos = axes[0].get_position() cb1.remove() + axes[0].set_position(pos) # Now we set up the second plot, which shows the individual # diffraction patterns diff --git a/tests/models/test_fancy_ptycho.py b/tests/models/test_fancy_ptycho.py index db05a75e..4de9f4af 100644 --- a/tests/models/test_fancy_ptycho.py +++ b/tests/models/test_fancy_ptycho.py @@ -1,5 +1,7 @@ import pytest +import time import torch as t +from matplotlib import pyplot as plt import cdtools @@ -67,6 +69,7 @@ def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): units='mm', obj_view_crop=-50, use_qe_mask=True, # test this in the case where no qe mask is defined + panel_plot_mode=True, # test with panel plot mode ) print('Running reconstruction on provided reconstruction_device,', @@ -76,24 +79,26 @@ def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): for loss in model.Adam_optimize(50, dataset, lr=0.02, batch_size=10): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) for loss in model.Adam_optimize(50, dataset, lr=0.005, batch_size=50): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) for loss in model.Adam_optimize(25, dataset, lr=0.001, batch_size=50): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) model.tidy_probes() if show_plot: model.inspect(dataset) model.compare(dataset) + time.sleep(3) + plt.close('all') # If this fails, the reconstruction has gotten worse assert model.loss_history[-1] < 0.0013 @@ -110,6 +115,7 @@ def test_near_field_ptycho(near_field_ptycho_cxi, reconstruction_device, show_pl n_modes=1, near_field=True, propagation_distance=3.65e-3, # 3.65 downstream from focus + panel_plot_mode=False, # test without panel plot mode ) print('Running reconstruction on provided reconstruction_device,', @@ -119,19 +125,21 @@ def test_near_field_ptycho(near_field_ptycho_cxi, reconstruction_device, show_pl for loss in model.Adam_optimize(100, dataset, lr=0.04, batch_size=10): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) for loss in model.Adam_optimize(50, dataset, lr=0.005, batch_size=50): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) model.tidy_probes() if show_plot: model.inspect(dataset) model.compare(dataset) + time.sleep(3) + plt.close('all') # If this fails, the reconstruction has gotten worse assert model.loss_history[-1] < 0.005 diff --git a/tests/models/test_simple_ptycho.py b/tests/models/test_simple_ptycho.py index b6b18680..225e463a 100644 --- a/tests/models/test_simple_ptycho.py +++ b/tests/models/test_simple_ptycho.py @@ -1,5 +1,7 @@ import pytest +import time import torch as t +from matplotlib import pyplot as plt import cdtools @@ -18,12 +20,14 @@ def test_simple_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): for loss in model.Adam_optimize(100, dataset, batch_size=10): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) if show_plot: model.inspect(dataset) model.compare(dataset) + time.sleep(3) + plt.close('all') # If this fails, the reconstruction got worse assert model.loss_history[-1] < 0.013 diff --git a/tests/test_reconstructors.py b/tests/test_reconstructors.py index 132a8c27..3f9245b6 100644 --- a/tests/test_reconstructors.py +++ b/tests/test_reconstructors.py @@ -1,4 +1,5 @@ import pytest +import time import cdtools import torch as t import numpy as np @@ -36,7 +37,8 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): probe_support_radius=50, propagation_distance=2e-6, units='um', - probe_fourier_crop=pad + probe_fourier_crop=pad, + panel_plot_mode=False, # At least one check without panel plot mode ) model.translation_offsets.data += 0.7 * \ @@ -67,8 +69,8 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): lr=lr_tup[i], batch_size=batch_size_tup[i]): print(model_recon.report()) - if show_plot and model_recon.epoch % 10 == 0: - model_recon.inspect(dataset) + if show_plot: + model_recon.inspect(dataset, min_interval=10) # Check hyperparameter update assert recon.optimizer.param_groups[0]['lr'] == lr_tup[i] @@ -86,6 +88,8 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): if show_plot: model_recon.inspect(dataset) model_recon.compare(dataset) + time.sleep(3) + plt.close('all') # ******* Reconstructions with CDIModel.Adam_optimize ******* print('Running reconstruction using CDIModel.Adam_optimize on provided' + @@ -99,14 +103,16 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): lr=lr_tup[i], batch_size=batch_size_tup[i]): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) model.tidy_probes() if show_plot: model.inspect(dataset) model.compare(dataset) + time.sleep(3) + plt.close('all') # Ensure equivalency between the model reconstructions during the first # pass, where they should be identical @@ -170,8 +176,8 @@ def test_LBFGS_RPI(optical_data_ss_cxi, for loss in recon.optimize(iterations, lr=0.4, regularization_factor=reg_factor_tup[i]): - if show_plot and i == 0: - model_recon.inspect(dataset) + if show_plot: + model_recon.inspect(dataset, min_interval=10) print(model_recon.report()) # Check hyperparameter update (or lack thereof) @@ -180,6 +186,8 @@ def test_LBFGS_RPI(optical_data_ss_cxi, if show_plot: model_recon.inspect(dataset) model_recon.compare(dataset) + time.sleep(3) + plt.close('all') # Check model pointing assert id(model_recon) == id(recon.model) @@ -193,13 +201,15 @@ def test_LBFGS_RPI(optical_data_ss_cxi, dataset, lr=0.4, regularization_factor=reg_factor_tup[i]): # noqa - if show_plot and i == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) print(model.report()) if show_plot: model.inspect(dataset) model.compare(dataset) + time.sleep(3) + plt.close('all') # Check loss equivalency between the two reconstructions assert np.allclose(model.loss_history[:epoch_tup[0]], model_recon.loss_history[:epoch_tup[0]]) @@ -271,8 +281,8 @@ def test_SGD_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): lr=lr, batch_size=batch_size): print(model_recon.report()) - if show_plot and model_recon.epoch % 10 == 0: - model_recon.inspect(dataset) + if show_plot: + model_recon.inspect(dataset, min_interval=10) # Check hyperparameter update assert recon.optimizer.param_groups[0]['lr'] == lr @@ -290,6 +300,8 @@ def test_SGD_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): if show_plot: model_recon.inspect(dataset) model_recon.compare(dataset) + time.sleep(3) + plt.close('all') # ******* Reconstructions with cdtools.CDIModel.SGD_optimize ******* print('Running reconstruction using CDIModel.SGD_optimize on provided' + @@ -301,14 +313,16 @@ def test_SGD_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): lr=lr, batch_size=batch_size): print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) + if show_plot: + model.inspect(dataset, min_interval=10) model.tidy_probes() if show_plot: model.inspect(dataset) model.compare(dataset) + time.sleep(3) + plt.close('all') # Ensure equivalency between the model reconstructions assert np.allclose(model_recon.loss_history[-1], model.loss_history[-1]) From 603486e8241a6c6c0a8f020b9959523214528e5b Mon Sep 17 00:00:00 2001 From: allevitan Date: Tue, 24 Mar 2026 16:55:56 +0100 Subject: [PATCH 20/23] Change the default colormap for exponentiated objects to match that for un-exponentiated objects, and also update test_fancy_ptycho to show all the plots (not just plot level 2) --- src/cdtools/models/fancy_ptycho.py | 7 +++++-- tests/models/test_fancy_ptycho.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/cdtools/models/fancy_ptycho.py b/src/cdtools/models/fancy_ptycho.py index b3635917..58f4f2a2 100644 --- a/src/cdtools/models/fancy_ptycho.py +++ b/src/cdtools/models/fancy_ptycho.py @@ -1004,7 +1004,8 @@ def plot_translations_and_originals(self, fig, dataset): basis=self.obj_basis, additional_axis_labels=['Mode #',], units=self.units, - cmap='cividis'), + cmap='cividis', + ), 'condition': lambda self: self.exponentiate_obj, }, { @@ -1015,7 +1016,9 @@ def plot_translations_and_originals(self, fig, dataset): fig=fig, basis=self.obj_basis, additional_axis_labels=['Mode #',], - units=self.units), + units=self.units, + cmap='viridis_r', + ), 'condition': lambda self: self.exponentiate_obj, }, { diff --git a/tests/models/test_fancy_ptycho.py b/tests/models/test_fancy_ptycho.py index 4de9f4af..081d85b7 100644 --- a/tests/models/test_fancy_ptycho.py +++ b/tests/models/test_fancy_ptycho.py @@ -69,7 +69,8 @@ def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): units='mm', obj_view_crop=-50, use_qe_mask=True, # test this in the case where no qe mask is defined - panel_plot_mode=True, # test with panel plot mode + panel_plot_mode=True, # test with panel plot mode, + plot_level=4, # test with all plots ) print('Running reconstruction on provided reconstruction_device,', From f55b35f0dde79261eb60d772fb2aeee66fb5f55f Mon Sep 17 00:00:00 2001 From: Abe Levitan Date: Tue, 24 Mar 2026 16:56:41 +0100 Subject: [PATCH 21/23] Stop stopping at each plot to show it --- tests/tools/test_plotting.py | 30 ++++++------------------------ 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/tests/tools/test_plotting.py b/tests/tools/test_plotting.py index 8a1567f0..6b5b7519 100644 --- a/tests/tools/test_plotting.py +++ b/tests/tools/test_plotting.py @@ -11,21 +11,16 @@ def test_plot_amplitude(show_plot): # Test with tensor im = t.as_tensor(scipy.datasets.ascent(), dtype=t.complex128) plotting.plot_amplitude(im, basis=np.array([[0, -1], [-1, 0], [0, 0]]), title='Test Amplitude') - if show_plot: - plt.show() - plt.close('all') # Test with numpy array and an extra dimension im = np.stack([scipy.datasets.ascent().astype(np.complex128)]*3, axis=0) plotting.plot_amplitude(im, title='Test Amplitude') - if show_plot: - plt.show() - plt.close('all') # Test with pytorch tensor and two extra dimensions im = t.as_tensor(np.stack([im]*5, axis=0)) plotting.plot_amplitude(im, title='Test Amplitude', additional_axis_labels=['Hi','There']) + if show_plot: plt.show() plt.close('all') @@ -34,13 +29,11 @@ def test_plot_phase(show_plot): # Test with tensor im = initializers.gaussian([512, 512], [200, 200], amplitude=100, curvature=[.1, .1]) plotting.plot_phase(im, title='Test Phase') - if show_plot: - plt.show() - plt.close('all') # Test with numpy array im = initializers.gaussian([512, 512], [200, 200], amplitude=100, curvature=[.1, .1]).numpy() plotting.plot_phase(im, title='Test Phase', basis=np.array([[0, -1], [-1, 0], [0, 0]])) + if show_plot: plt.show() plt.close('all') @@ -51,13 +44,11 @@ def test_plot_colorized(show_plot): gaussian = initializers.gaussian([512, 512], [200, 200], amplitude=100, curvature=[.1, .1]) im = gaussian * t.as_tensor(scipy.datasets.ascent(), dtype=t.complex64) plotting.plot_colorized(im, title='Test Colorize', basis=np.array([[0, -1], [-1, 0], [0, 0]])) - if show_plot: - plt.show() - plt.close('all') # Test with numpy array and hsv im = im.numpy() plotting.plot_colorized(im, title='Test Colorize', use_cmocean=False) + if show_plot: plt.show() plt.close('all') @@ -70,13 +61,11 @@ def test_plot_translations(show_plot): # numpy, defaults plotting.plot_translations(trans_np) - if show_plot: - plt.show() - plt.close('all') # torch tensor and reuse figure fig = plotting.plot_translations(trans_t) plotting.plot_translations(trans_np, lines=False, color='red', label='scan', fig=fig, clear_fig=False) + if show_plot: plt.show() plt.close('all') @@ -90,12 +79,10 @@ def test_plot_nanomap(show_plot): # numpy, defaults plotting.plot_nanomap(trans_np, values_np) - if show_plot: - plt.show() - plt.close('all') # torch tensors plotting.plot_nanomap(trans_t, values_t, units='nm', cmap_label='Intensity', convention='sample') + if show_plot: plt.show() plt.close('all') @@ -117,19 +104,14 @@ def get_image_3d(i): # basic call, no values plotting.plot_nanomap_with_images(trans_np, get_image_2d) - if show_plot: - plt.show() - plt.close('all') # with explicit values plotting.plot_nanomap_with_images(trans_t, get_image_2d, values=values_np) - if show_plot: - plt.show() - plt.close('all') # 3D image stack fig = plt.figure(figsize=(11,7)) plotting.plot_nanomap_with_images(trans_np, get_image_3d, values=values_t, fig=fig) + if show_plot: plt.show() plt.close('all') From 4172a9a11a04e5c4e309dafd5d7c66e70f304c44 Mon Sep 17 00:00:00 2001 From: allevitan Date: Tue, 24 Mar 2026 19:30:14 +0100 Subject: [PATCH 22/23] A few small changes to revert unimportant edits and fix linting issues --- examples/fancy_ptycho.py | 2 +- examples/near_field_ptycho.py | 2 +- examples/simple_ptycho.py | 2 +- src/cdtools/models/fancy_ptycho.py | 5 +---- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/examples/fancy_ptycho.py b/examples/fancy_ptycho.py index 1ca88fea..d6dfc748 100644 --- a/examples/fancy_ptycho.py +++ b/examples/fancy_ptycho.py @@ -13,7 +13,7 @@ probe_support_radius=120, # Force the probe to 0 outside a radius of 120 pix propagation_distance=5e-3, # Propagate the initial probe guess by 5 mm units='mm', # Set the units for the live plots - obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix, + obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix ) if t.cuda.is_available(): diff --git a/examples/near_field_ptycho.py b/examples/near_field_ptycho.py index 4534ee58..af0012d2 100644 --- a/examples/near_field_ptycho.py +++ b/examples/near_field_ptycho.py @@ -27,7 +27,7 @@ propagation_distance=3.65e-3, # 3.65 downstream from focus units='um', # Set the units for the live plots obj_view_crop=-35, - panel_plot_mode=True, + panel_plot_mode=True, # Set to False to get individual figures ) if t.cuda.is_available(): diff --git a/examples/simple_ptycho.py b/examples/simple_ptycho.py index 8218a229..41c0c6c2 100644 --- a/examples/simple_ptycho.py +++ b/examples/simple_ptycho.py @@ -26,7 +26,7 @@ model.inspect(dataset) # We run the reconstruction -for loss in model.Adam_optimize(30, dataset, batch_size=10): +for loss in model.Adam_optimize(100, dataset, batch_size=10): # We print a quick report of the optimization status print(model.report()) # And liveplot the updates to the model as they happen diff --git a/src/cdtools/models/fancy_ptycho.py b/src/cdtools/models/fancy_ptycho.py index 58f4f2a2..3387b15a 100644 --- a/src/cdtools/models/fancy_ptycho.py +++ b/src/cdtools/models/fancy_ptycho.py @@ -7,8 +7,6 @@ from matplotlib import pyplot as plt from datetime import datetime import numpy as np -from scipy import linalg as sla -from copy import copy __all__ = ['FancyPtycho'] @@ -557,7 +555,7 @@ def interaction(self, index, translations, *args): else: try: Ws = t.ones(len(index)) # I'm positive this introduced a bug - except: + except TypeError: Ws = 1 if self.weights is None or len(self.weights[0].shape) == 0: @@ -782,7 +780,6 @@ def tidy_probes(self): # First we treat the incoherent but stable case, where the weights are # just one per-shot overall weight if self.weights.dim() == 1: - probe = self.probe.detach().cpu().numpy() ortho_probes = analysis.orthogonalize_probes(self.probe.detach()) self.probe.data = ortho_probes return From 0bd8df498cc514ec660a0259c616184fede57b4e Mon Sep 17 00:00:00 2001 From: allevitan Date: Tue, 24 Mar 2026 21:39:04 +0100 Subject: [PATCH 23/23] Updated the documentation to reflect the changes to the plotting system --- docs/source/examples.rst | 30 +++++++++++++++++++++------ docs/source/tutorial.rst | 40 ++++++++++++++++++++++++------------ examples/gold_ball_ptycho.py | 2 +- 3 files changed, 52 insertions(+), 20 deletions(-) diff --git a/docs/source/examples.rst b/docs/source/examples.rst index a3ef41eb..4c101c90 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -31,9 +31,9 @@ When reading this script, note the basic workflow. After the data is loaded, a m Next, the model is moved to the GPU using the :code:`model.to` function. Any device understood by :code:`torch.Tensor.to` can be specified here. The next line is a bit more subtle - the dataset is told to move patterns to the GPU before passing them to the model using the :code:`dataset.get_as` function. This function does not move the stored patterns to the GPU. If there is sufficient GPU memory, the patterns can also be pre-moved to the GPU using :code:`dataset.to`, but the speedup is empirically quite small. -Once the device is selected, a reconstruction is run using :code:`model.Adam_optimize`. This is a generator function which will yield at the end of every epoch, to allow some monitoring code to be run. +Once the device is selected, a reconstruction is run using :code:`model.Adam_optimize`. This is a generator function which will yield at the end of every epoch, to allow some monitoring code to be run. Inside the loop, :code:`model.inspect(dataset)` is called every epoch to live-update a set of plots showing the current state of the model parameters. -Finally, the results can be studied using :code:`model.inspect(dataset)`, which creates or updates a set of plots showing the current state of the model parameters. :code:`model.compare(dataset)` is also called, which shows how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset. +Finally, :code:`model.compare(dataset)` is called to show how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset. Fancy Ptycho @@ -63,11 +63,13 @@ By default, FancyPtycho will also optimize over the following model parameters, These corrections can be turned off (on) by calling :code:`model..requires_grad = False #(True)`. -Note as well two other changes that are made in this script, when compared to `simple_ptycho.py`. First, a `Reconstructor` object is explicitly created, in this case an `AdamReconstructor`. This object stores a model, dataset, and pytorch optimizer. It is then used to orchestrate the later reconstruction using a call to `Reconstructor.optimize()`. +Note as well two other changes that are made in this script, when compared to :code:`simple_ptycho.py`. First, a :code:`Reconstructor` object is explicitly created, in this case an :code:`AdamReconstructor`. This object stores a model, dataset, and pytorch optimizer. It is then used to orchestrate the later reconstruction using a call to :code:`Reconstructor.optimize()`. -We use this pattern, instead of the simpler call to `model.Adam_optimize()`, because having the reconstructor store the optimizer as well as the model and dataset allows the moment estimates to persist between multiple rounds of optimization. This leads to the second change: In this script, we run two optimization loops. The first loop aggressively refines the probe, with a low minibatch size and a high learning rate. The second loop has a smaller learning rate and a larger batch size, which allow for a more precise final estimation of the object. +We use this pattern, instead of the simpler call to :code:`model.Adam_optimize()`, because having the reconstructor store the optimizer as well as the model and dataset allows the moment estimates to persist between multiple rounds of optimization. This leads to the second change: In this script, we run two optimization loops. The first loop aggressively refines the probe, with a low minibatch size and a high learning rate. The second loop has a smaller learning rate and a larger batch size, which allow for a more precise final estimation of the object. -In this case, we used one reconstructor, but it is possible to create additional reconstructors to zero out all the persistant information in the optimizer, if desired, or even to instantiate multiple reconstructors on the same model with different optimization algorithms (e.g. `model.LBFGS_optimize()`). +In this case, we used one reconstructor, but it is possible to create additional reconstructors to zero out all the persistant information in the optimizer, if desired, or even to instantiate multiple reconstructors on the same model with different optimization algorithms (e.g. :code:`model.LBFGS_optimize()`). + +Note also the use of :code:`min_interval=10` in the calls to :code:`model.inspect(dataset)`. Because generating plots can be expensive, passing a minimum interval (in seconds) prevents excessive replots. Finally, the call to :code:`model.inspect(dataset, replot_all=True)` at the end of the script reopens any plot windows that the user may have closed during the reconstruction, so that all results are visible at the end. Gold Ball Ptycho @@ -77,11 +79,27 @@ This script shows how the FancyPtycho model might be used in a realistic situati .. literalinclude:: ../../examples/gold_ball_ptycho.py -Note, in particular, the use of :code:`model.save_on_exception` and :code:`model.save_to_h5` to save the results of the reconstruction. If a different file format is required, :code:`model.save_results` will save to a pure-python dictionary. +Note first the explicit addition of the :code:`plot_level=2` argument in the call to :code:`FancyPtycho.from_dataset`. This value controls which plots are generated. With :code:`plot_level=1`, only the main results are shown - :code:`plot_level=2` shows some more advanced monitoring of the error correction terms (background, position error, etc.), and :code:`plot_level=3` shows all registered plots. + +Note also the use of :code:`model.save_on_exception` and :code:`model.save_to_h5` to save the results of the reconstruction. If a different file format is required, :code:`model.save_results` will save to a pure-python dictionary which can be processed further. Finally, note that there are several small adjustments made to the script to counteract particular sources of error that are present in this dataset, for example the raster grid pathology caused by the scan pattern used. Also note that not every mixin is needed every time - in this case, we turn off optimization of the :code:`weights` parameter. +Near-Field Ptycho +----------------- + +This script shows how the FancyPtycho model can be used on a typical near-field ptychography (also known as Fresnel ptychography) dataset. + +.. literalinclude:: ../../examples/near_field_ptycho.py + +The major change here is the setting of the :code:`near_field=True` argument to :code:`FancyPtycho.from_dataset`. This changes the propagator to a near-field propagator. As noted in the comments, if :code:`propagation_distance` is not set, the model will assume a standard near-field geomtry with flat illumination. + +If :code:`propagation_distance` is set, it will assume a Fresnel scaling theorem-type geometry, with :code:`propagation_distance` as the focus-to-sample distance, and the distance set in the dataset object as the sample-to-detector distance. + +Finally, note the addition of the :code:`panel_plot_mode=True` argument. This is the default mode, and returns the plots in a panel format, good for easily monitoring the progress of a reconstruction. If individual plots are needed for use in presentations, papers, or otherwise, setting :code:`panel_plot_mode=False` will plot each output in it's own window. + + Gold Ball Split --------------- diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index 1533e838..8a2fa863 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -367,28 +367,42 @@ The forward propagator maps the exit wave to the wave at the surface of the dete Plotting ++++++++ -The base CDIModel class has a function, :code:`model.inspect()`, which looks for a class variable called :code:`plot_list` and plots everything contained within. The plot list should be formatted as a list of tuples, with each tuple containing: +The base CDIModel class has a function, :code:`model.inspect()`, which looks for a class variable called :code:`plot_list` and plots everything contained within. The plot list should be formatted as a list of dictionaries, with each dictionary containing: + +* :code:`'title'`: the title of the plot +* :code:`'plot_func'`: a function that takes in the model (and optionally a figure) and generates the relevant plot +* :code:`'condition'` (optional): a function that takes in the model and returns whether or not the plot should be generated -* The title of the plot -* A function that takes in the model and generates the relevant plot -* Optional, a function that takes in the model and returns whether or not the plot should be generated - .. code-block:: python # This lists all the plots to display on a call to model.inspect() plot_list = [ - ('Probe Amplitude', - lambda self, fig: p.plot_amplitude(self.probe, fig=fig, basis=self.probe_basis)), - ('Probe Phase', - lambda self, fig: p.plot_phase(self.probe, fig=fig, basis=self.probe_basis)), - ('Object Amplitude', - lambda self, fig: p.plot_amplitude(self.obj, fig=fig, basis=self.probe_basis)), - ('Object Phase', - lambda self, fig: p.plot_phase(self.obj, fig=fig, basis=self.probe_basis)) + { + 'title': 'Probe Amplitude', + 'plot_func': lambda self, fig: + p.plot_amplitude(self.probe, fig=fig, basis=self.probe_basis), + }, + { + 'title': 'Probe Phase', + 'plot_func': lambda self, fig: + p.plot_phase(self.probe, fig=fig, basis=self.probe_basis), + }, + { + 'title': 'Object Amplitude', + 'plot_func': lambda self, fig: + p.plot_amplitude(self.obj, fig=fig, basis=self.probe_basis), + }, + { + 'title': 'Object Phase', + 'plot_func': lambda self, fig: + p.plot_phase(self.obj, fig=fig, basis=self.probe_basis), + }, ] In this case, we've made use of the convenience plotting functions defined in :code:`tools.plotting`. +More advanced models like :code:`FancyPtycho` also define a :code:`plot_panel_list`, which groups related plots together into multi-subplot figures. The :code:`panel_plot_mode` argument (passed at construction time) controls whether these panels are rendered as combined multi-subplot figures or as individual windows. For a simple model like :code:`SimplePtycho`, :code:`plot_list` is sufficient. + Saving ++++++ diff --git a/examples/gold_ball_ptycho.py b/examples/gold_ball_ptycho.py index 944c760b..aeb510c1 100644 --- a/examples/gold_ball_ptycho.py +++ b/examples/gold_ball_ptycho.py @@ -27,7 +27,7 @@ propagation_distance=2e-6, units='um', probe_fourier_crop=pad, - panel_plot_mode=True, + plot_level=2, )