From a055f1e738b2a203764d03971cca6e9e8c5bc48b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matti=20Hellstr=C3=B6m?= Date: Tue, 24 Mar 2026 07:31:09 +0100 Subject: [PATCH 1/2] Document plotting helpers and create linear_fit_extrapolate_to_0 and moving_average --- CHANGELOG.md | 3 +- doc/source/components/utils.rst | 52 ++++++++ src/scm/plams/tools/plot.py | 130 +++++++++++++++++-- src/scm/plams/tools/postprocess_results.py | 25 +++- unit_tests/test_tools_plot.py | 88 +++++++++++-- unit_tests/test_tools_postprocess_results.py | 18 +++ 6 files changed, 290 insertions(+), 26 deletions(-) create mode 100644 unit_tests/test_tools_postprocess_results.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 69495f559..62d3bd3e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ This changelog is effective from the 2025 releases. * `view` function to visualize molecules/chemical systems using AMSView * `config.job.on_status_change` callback which fires any time a job status is updated * `plot_image_grid` to plot multiple images (e.g. those generated from `view`) in a grid format +* `linear_fit_extrapolate_to_0` in `scm.plams.tools.plot` and `moving_average` in `scm.plams.tools.postprocess_results` for reuse in plotting and analysis workflows ### Changed * `JobAnalysis` returns an updated copy on modification instead of performing the operation in-place @@ -110,5 +111,3 @@ This changelog is effective from the 2025 releases. * Exception classes `AMSPipeDecodeError`, `AMSPipeError`, `AMSPipeInvalidArgumentError`, `AMSPipeLogicError`, `AMSPipeRuntimeError`, `AMSPipeUnknownArgumentError`, `AMSPipeUnknownMethodError`, `AMSPipeUnknownVersionError`, were moved from scm.plams to scm.amspipe. - - diff --git a/doc/source/components/utils.rst b/doc/source/components/utils.rst index de7ad6d05..9093a1e63 100644 --- a/doc/source/components/utils.rst +++ b/doc/source/components/utils.rst @@ -21,12 +21,20 @@ The class itself serves for "one and only instance" and all methods should be ca Periodic Table ~~~~~~~~~~~~~~~~~~~~~~~~~ +Import path:: + + scm.plams.tools.periodic_table + .. autoclass:: scm.plams.tools.periodic_table.PeriodicTable :exclude-members: __weakref__ Units ~~~~~~~~~~~~~~~~~~~~~~~~~ +Import path:: + + scm.plams.tools.units + .. autoclass:: scm.plams.tools.units.Units :exclude-members: __weakref__ @@ -36,6 +44,10 @@ Geometry tools A small module with simple functions related to 3D geometry operations. +Import path:: + + scm.plams.tools.geometry + .. automodule:: scm.plams.tools.geometry .. _FileFormatConversionTools: @@ -45,6 +57,10 @@ File format conversion tools A small module for converting VASP output to AMS-like output, and for converting ASE .traj trajectory files to the .rkf format. +Import path:: + + scm.plams.tools.converters + .. automodule:: scm.plams.tools.converters .. _ReactionEnergies: @@ -55,6 +71,11 @@ Reaction energies *New in AMS2026*: The ``balance`` function is new in AMS2026. For usage, see the :ref:`BalanceReactionEquationsExample` example. +Import paths:: + + scm.plams.tools.reaction + scm.plams.tools.reaction_energies + .. autofunction:: scm.plams.tools.reaction.balance .. autoclass:: scm.plams.tools.reaction.ReactionEquation @@ -75,6 +96,23 @@ Plotting tools Tools for creating plots with matplotlib. +Import path:: + + scm.plams.tools.plot + +The :mod:`scm.plams.tools.plot` module also contains small reusable helpers for +common analysis tasks. For example, ``linear_fit_extrapolate_to_0`` performs a +linear regression and returns the fitted line extended to ``x = 0``. + +Example:: + + from scm.plams.tools.plot import linear_fit_extrapolate_to_0 + + fit_x, fit_y, slope, intercept = linear_fit_extrapolate_to_0( + [1.0, 2.0, 3.0], + [3.0, 5.0, 7.0], + ) + .. automodule:: scm.plams.tools.plot .. _PostprocessResults: @@ -84,4 +122,18 @@ Postprocess results Tools for postprocessing the results. +Import path:: + + scm.plams.tools.postprocess_results + +The :mod:`scm.plams.tools.postprocess_results` module contains helpers such as +``moving_average`` for smoothing paired ``x``/``y`` data and ``broaden_results`` +for constructing broadened spectra. + +Example:: + + from scm.plams.tools.postprocess_results import moving_average + + avg_x, avg_y = moving_average([1.0, 2.0, 3.0], [3.0, 5.0, 7.0], window=2) + .. automodule:: scm.plams.tools.postprocess_results diff --git a/src/scm/plams/tools/plot.py b/src/scm/plams/tools/plot.py index 7686a3f8c..99af5fc5d 100644 --- a/src/scm/plams/tools/plot.py +++ b/src/scm/plams/tools/plot.py @@ -1,4 +1,15 @@ -from typing import List, Optional, Tuple, Union, TYPE_CHECKING, Dict, Any, Literal, cast +from typing import ( + List, + Optional, + Tuple, + Union, + TYPE_CHECKING, + Dict, + Any, + Literal, + cast, + Sequence, +) import numpy as np from scm.plams.core.errors import MissingOptionalPackageError @@ -21,6 +32,7 @@ from scm.plams.recipes.md.trajectoryanalysis import AMSMSDJob __all__ = [ + "linear_fit_extrapolate_to_0", "plot_band_structure", "plot_phonons_band_structure", "plot_phonons_dos", @@ -34,6 +46,36 @@ ] +def linear_fit_extrapolate_to_0(x: Sequence[float], y: Sequence[float]) -> Tuple[np.ndarray, np.ndarray, float, float]: + """ + Perform a linear regression on ``x`` and ``y`` and return the fit extended to ``x = 0``. + + x: sequence of float + X values for the linear regression. + + y: sequence of float + Y values for the linear regression. + + Returns: tuple + ``fit_x``, ``fit_y``, ``slope``, ``intercept``. + + If ``0`` is already present in ``x``, it is not appended a second time. + """ + try: + from scipy.stats import linregress + except ImportError: + raise MissingOptionalPackageError("scipy") + + result = linregress(x, y) + fit_x_values = list(x) + if 0 not in fit_x_values: + fit_x_values.append(0.0) + fit_x = np.array(fit_x_values, dtype=float) + fit_y = result.slope * fit_x + result.intercept + + return fit_x, fit_y, result.slope, result.intercept + + @requires_optional_package("matplotlib") def plot_band_structure( x: List[float], @@ -263,12 +305,26 @@ def plot_phonons_dos( ax.plot(energy, total_dos, color="black", label="Total DOS", linestyle="-", zorder=1) elif dos_type == "species": - ax.plot(energy, total_dos, color="black", label="Total DOS", linestyle="-", zorder=-1) + ax.plot( + energy, + total_dos, + color="black", + label="Total DOS", + linestyle="-", + zorder=-1, + ) for i, (l, v) in enumerate(dos_per_species.items()): ax.plot(energy, v, label=f"pDOS {l}", dashes=[3, i + 1, 2], zorder=i) elif dos_type == "atoms": - ax.plot(energy, total_dos, color="black", label="Total DOS", linestyle="-", zorder=-1) + ax.plot( + energy, + total_dos, + color="black", + label="Total DOS", + linestyle="-", + zorder=-1, + ) for i, (l, v) in enumerate(dos_per_atom.items()): ax.plot(energy, v, label=f"pDOS {l}", dashes=[3, i + 1, 2], zorder=i) @@ -282,7 +338,10 @@ def plot_phonons_dos( @requires_optional_package("matplotlib") def plot_phonons_thermodynamic_properties( - temperature: List[float], properties: Dict[str, List[float]], units: Dict[str, str], ax: Optional["plt.Axes"] = None + temperature: List[float], + properties: Dict[str, List[float]], + units: Dict[str, str], + ax: Optional["plt.Axes"] = None, ) -> "plt.Axes": """ Plots the phonons thermodynamic properties from DFTB, BAND or QuantumEspresso engines with matplotlib. @@ -310,7 +369,14 @@ def plot_phonons_thermodynamic_properties( _, ax = plt.subplots() for i, (label, prop) in enumerate(properties.items()): - ax.plot(temperature, prop, label=label + " (" + units[label] + ")", linestyle="-", lw=2, zorder=1) + ax.plot( + temperature, + prop, + label=label + " (" + units[label] + ")", + linestyle="-", + lw=2, + zorder=1, + ) plt.legend() @@ -500,16 +566,26 @@ def tolist(x: Any) -> List: data2 = [] for j1, j2 in zip(job1, job2): try: - d1 = cast(Union[List[float], float], j1.results.readrkf(section, variable, file=file)) + d1 = cast( + Union[List[float], float], + j1.results.readrkf(section, variable, file=file), + ) except KeyError: - d1 = cast(Union[List[float], float], j1.results.get_history_property(variable, history_section=section)) + d1 = cast( + Union[List[float], float], + j1.results.get_history_property(variable, history_section=section), + ) d1a = np.ravel(d1) * multiplier try: - d2 = cast(Union[List[float], float], j2.results.readrkf(alt_section, alt_variable, file=file)) + d2 = cast( + Union[List[float], float], + j2.results.readrkf(alt_section, alt_variable, file=file), + ) except KeyError: d2 = cast( - Union[List[float], float], j2.results.get_history_property(alt_variable, history_section=alt_section) + Union[List[float], float], + j2.results.get_history_property(alt_variable, history_section=alt_section), ) d2a = np.ravel(d2) * multiplier @@ -695,7 +771,9 @@ def add_unit(s: str) -> str: @requires_optional_package("matplotlib") def plot_msd( - job: "AMSMSDJob", start_time_fit_fs: Optional[float] = None, ax: Optional["plt.Axes"] = None + job: "AMSMSDJob", + start_time_fit_fs: Optional[float] = None, + ax: Optional["plt.Axes"] = None, ) -> "plt.Axes": """ job: AMSMSDJob @@ -821,11 +899,29 @@ def plot_work_function( # Otherwise: else: - ax.plot([x0, x0 + 0.3 * (x1 - x0)], [Vvacuum[0], Vvacuum[0]], color="black", linestyle="dashed", linewidth=1) + ax.plot( + [x0, x0 + 0.3 * (x1 - x0)], + [Vvacuum[0], Vvacuum[0]], + color="black", + linestyle="dashed", + linewidth=1, + ) ax.text(x0, Vvacuum[0] + 0.1, "Pot. vacuum", fontsize=11, color="black") - ax.plot([x1, x1 - 0.3 * (x1 - x0)], [Vvacuum[1], Vvacuum[1]], color="black", linestyle="dashed", linewidth=1) - ax.text(x1 - 0.3 * (x1 - x0), Vvacuum[1] + 0.1, "Pot. vacuum", fontsize=11, color="black") + ax.plot( + [x1, x1 - 0.3 * (x1 - x0)], + [Vvacuum[1], Vvacuum[1]], + color="black", + linestyle="dashed", + linewidth=1, + ) + ax.text( + x1 - 0.3 * (x1 - x0), + Vvacuum[1] + 0.1, + "Pot. vacuum", + fontsize=11, + color="black", + ) head_length = 0.4 ax.arrow( @@ -838,7 +934,13 @@ def plot_work_function( fc="black", ec="black", ) - ax.text(x0 + 0.02 * (x1 - x0), (Vvacuum[0] + Efermi) / 2, f"WF={WF[0]:.1f} eV", fontsize=11, color="black") + ax.text( + x0 + 0.02 * (x1 - x0), + (Vvacuum[0] + Efermi) / 2, + f"WF={WF[0]:.1f} eV", + fontsize=11, + color="black", + ) ax.arrow( x0 + 1.0 * (x1 - x0), Efermi, diff --git a/src/scm/plams/tools/postprocess_results.py b/src/scm/plams/tools/postprocess_results.py index 76c6503ad..cfc4ff9dd 100644 --- a/src/scm/plams/tools/postprocess_results.py +++ b/src/scm/plams/tools/postprocess_results.py @@ -1,9 +1,32 @@ import numpy as np -from typing import Union, Literal, Tuple, Optional +from typing import Union, Literal, Tuple, Optional, Sequence ArrayOrFloat = Union[np.ndarray, float] +def moving_average(x: Sequence[float], y: Sequence[float], window: int) -> Tuple[np.ndarray, np.ndarray]: + """ + Calculate a moving average of x and y. + + :param x: X values + :type x: Sequence[float] + :param y: Y values + :type y: Sequence[float] + :param window: Moving-average window size + :type window: int + :return: ``x_moving_averaged``, ``y_moving_averaged`` + :rtype: Tuple[np.ndarray, np.ndarray] + """ + if not window: + return np.array(x), np.array(y) + window = min(len(x) - 1, window) + if window <= 1: + return np.array(x), np.array(y) + ret_x = np.convolve(x, np.ones(window) / window, mode="valid") + ret_y = np.convolve(y, np.ones(window) / window, mode="valid") + return ret_x, ret_y + + def _gaussian(x: np.ndarray, A: ArrayOrFloat, x0: ArrayOrFloat, sigma: ArrayOrFloat) -> np.ndarray: return A * np.exp(-((x - x0) ** 2) / (2 * sigma**2)) diff --git a/unit_tests/test_tools_plot.py b/unit_tests/test_tools_plot.py index 8891c0eeb..341e59ad0 100644 --- a/unit_tests/test_tools_plot.py +++ b/unit_tests/test_tools_plot.py @@ -18,6 +18,7 @@ from scm.plams.recipes.md.trajectoryanalysis import AMSMSDJob, AMSMSDResults from scm.plams.tools.plot import ( get_correlation_xy, + linear_fit_extrapolate_to_0, plot_band_structure, plot_phonons_band_structure, plot_phonons_dos, @@ -34,6 +35,23 @@ matplotlib.use("Agg") +def test_linear_fit_extrapolate_to_0(): + fit_x, fit_y, slope, intercept = linear_fit_extrapolate_to_0([1.0, 2.0, 3.0], [3.0, 5.0, 7.0]) + + assert slope == pytest.approx(2.0) + assert intercept == pytest.approx(1.0) + assert fit_x.tolist() == pytest.approx([1.0, 2.0, 3.0, 0.0]) + assert fit_y.tolist() == pytest.approx([3.0, 5.0, 7.0, 1.0]) + + fit_x, fit_y, slope, intercept = linear_fit_extrapolate_to_0([0.0, 1.0, 2.0], [1.0, 3.0, 5.0]) + + assert slope == pytest.approx(2.0) + assert intercept == pytest.approx(1.0) + assert fit_x.tolist() == pytest.approx([0.0, 1.0, 2.0]) + assert fit_y.tolist() == pytest.approx([1.0, 3.0, 5.0]) + assert fit_x.tolist().count(0.0) == 1 + + @pytest.fixture def run_calculations(): run_calculations = False # Manual toggle whether to re-run AMS calculations @@ -48,7 +66,13 @@ def rkf_tools_plot(rkf_folder): # ---------------------------------------------------------- # Testing plot_molecule # ---------------------------------------------------------- -@image_comparison(baseline_images=["plot_molecule"], remove_text=True, extensions=["png"], style="mpl20", tol=30) +@image_comparison( + baseline_images=["plot_molecule"], + remove_text=True, + extensions=["png"], + style="mpl20", + tol=30, +) def test_plot_molecule(): plt.close("all") @@ -74,7 +98,13 @@ def test_plot_molecule(): # ---------------------------------------------------------- # Testing plot_grid_molecules # ---------------------------------------------------------- -@image_comparison(baseline_images=["plot_grid_molecules"], remove_text=True, extensions=["png"], style="mpl20", tol=25) +@image_comparison( + baseline_images=["plot_grid_molecules"], + remove_text=True, + extensions=["png"], + style="mpl20", + tol=25, +) def test_plot_grid_molecules(): plt.close("all") ethanol = from_smiles("CCO") @@ -146,7 +176,13 @@ def iterate_options(molecules, options): # ---------------------------------------------------------- # Testing plot_band_structure # ---------------------------------------------------------- -@image_comparison(baseline_images=["plot_band_structure"], remove_text=True, extensions=["png"], style="mpl20", tol=3) +@image_comparison( + baseline_images=["plot_band_structure"], + remove_text=True, + extensions=["png"], + style="mpl20", + tol=3, +) def test_plot_band_structure(run_calculations, rkf_tools_plot): plt.close("all") @@ -184,7 +220,11 @@ def test_plot_band_structure(run_calculations, rkf_tools_plot): # Testing plot_phonons_band_structure # ---------------------------------------------------------- @image_comparison( - baseline_images=["plot_phonons_band_structure"], remove_text=True, extensions=["png"], style="mpl20", tol=3 + baseline_images=["plot_phonons_band_structure"], + remove_text=True, + extensions=["png"], + style="mpl20", + tol=3, ) def test_plot_phonons_band_structure(run_calculations, rkf_tools_plot): plt.close("all") @@ -220,7 +260,13 @@ def test_plot_phonons_band_structure(run_calculations, rkf_tools_plot): # ---------------------------------------------------------- # Testing plot_phonons_dos # ---------------------------------------------------------- -@image_comparison(baseline_images=["plot_phonons_dos"], remove_text=True, extensions=["png"], style="mpl20", tol=3) +@image_comparison( + baseline_images=["plot_phonons_dos"], + remove_text=True, + extensions=["png"], + style="mpl20", + tol=3, +) def test_plot_phonons_dos(run_calculations, rkf_tools_plot): plt.close("all") @@ -255,7 +301,13 @@ def test_plot_phonons_dos(run_calculations, rkf_tools_plot): # ---------------------------------------------------------- # Testing plot_correlation & get_correlation_xy # ---------------------------------------------------------- -@image_comparison(baseline_images=["plot_correlation"], remove_text=True, extensions=["png"], style="mpl20", tol=1.0) +@image_comparison( + baseline_images=["plot_correlation"], + remove_text=True, + extensions=["png"], + style="mpl20", + tol=1.0, +) def test_plot_correlation(run_calculations, rkf_tools_plot): plt.close("all") @@ -357,7 +409,13 @@ def test_plot_correlation(run_calculations, rkf_tools_plot): # ---------------------------------------------------------- # Testing plot_msd # ---------------------------------------------------------- -@image_comparison(baseline_images=["plot_msd"], remove_text=True, extensions=["png"], style="mpl20", tol=4) +@image_comparison( + baseline_images=["plot_msd"], + remove_text=True, + extensions=["png"], + style="mpl20", + tol=4, +) def test_plot_msd(run_calculations, rkf_tools_plot, xyz_folder): plt.close("all") @@ -394,7 +452,13 @@ def test_plot_msd(run_calculations, rkf_tools_plot, xyz_folder): plot_msd(md_job) -@image_comparison(baseline_images=["plot_msd_pisa"], remove_text=True, extensions=["png"], style="mpl20", tol=10) +@image_comparison( + baseline_images=["plot_msd_pisa"], + remove_text=True, + extensions=["png"], + style="mpl20", + tol=10, +) def test_plot_msd_with_pisa(run_calculations, rkf_tools_plot, xyz_folder): skip_if_no_scm_pisa() @@ -460,7 +524,13 @@ def test_plot_msd_with_pisa(run_calculations, rkf_tools_plot, xyz_folder): # ---------------------------------------------------------- # Testing plot_work_function # ---------------------------------------------------------- -@image_comparison(baseline_images=["plot_work_function"], remove_text=True, extensions=["png"], style="mpl20", tol=11) +@image_comparison( + baseline_images=["plot_work_function"], + remove_text=True, + extensions=["png"], + style="mpl20", + tol=11, +) def test_plot_work_function(run_calculations, rkf_tools_plot): plt.close("all") diff --git a/unit_tests/test_tools_postprocess_results.py b/unit_tests/test_tools_postprocess_results.py new file mode 100644 index 000000000..c4f203b67 --- /dev/null +++ b/unit_tests/test_tools_postprocess_results.py @@ -0,0 +1,18 @@ +#!/usr/bin/env amspython +# coding: utf-8 + +import pytest + +from scm.plams.tools.postprocess_results import moving_average + + +def test_moving_average(): + avg_x, avg_y = moving_average([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0], window=2) + + assert avg_x.tolist() == pytest.approx([1.5, 2.5, 3.5]) + assert avg_y.tolist() == pytest.approx([3.0, 5.0, 7.0]) + + avg_x, avg_y = moving_average([1.0, 2.0], [3.0, 5.0], window=0) + + assert avg_x.tolist() == pytest.approx([1.0, 2.0]) + assert avg_y.tolist() == pytest.approx([3.0, 5.0]) From 0ed1bd8bba3fba22100cda1a2b469cf7f96e5e0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matti=20Hellstr=C3=B6m?= Date: Tue, 24 Mar 2026 10:26:14 +0100 Subject: [PATCH 2/2] use requires_optional_package for scipy in linear_fit_extrapolate_to_0 So-- --- src/scm/plams/tools/plot.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/scm/plams/tools/plot.py b/src/scm/plams/tools/plot.py index 99af5fc5d..c79745518 100644 --- a/src/scm/plams/tools/plot.py +++ b/src/scm/plams/tools/plot.py @@ -46,6 +46,7 @@ ] +@requires_optional_package("scipy") def linear_fit_extrapolate_to_0(x: Sequence[float], y: Sequence[float]) -> Tuple[np.ndarray, np.ndarray, float, float]: """ Perform a linear regression on ``x`` and ``y`` and return the fit extended to ``x = 0``. @@ -61,10 +62,7 @@ def linear_fit_extrapolate_to_0(x: Sequence[float], y: Sequence[float]) -> Tuple If ``0`` is already present in ``x``, it is not appended a second time. """ - try: - from scipy.stats import linregress - except ImportError: - raise MissingOptionalPackageError("scipy") + from scipy.stats import linregress result = linregress(x, y) fit_x_values = list(x)