From ee46fa942a35cc1433af93749ebfd797fcd202f2 Mon Sep 17 00:00:00 2001 From: Oliver Backhouse Date: Tue, 12 Aug 2025 08:56:01 +0100 Subject: [PATCH 1/6] JAX support --- dyson/__init__.py | 4 +- dyson/_backend.py | 90 ++++++++++++++++++++++++++++++++ dyson/expressions/adc.py | 15 ++++++ dyson/expressions/ccsd.py | 19 +++++-- dyson/expressions/fci.py | 4 ++ dyson/expressions/gw.py | 3 ++ dyson/expressions/hamiltonian.py | 9 +++- dyson/expressions/hf.py | 4 ++ dyson/grids/frequency.py | 4 +- dyson/representations/lehmann.py | 4 +- dyson/solvers/dynamic/corrvec.py | 5 +- dyson/solvers/static/chempot.py | 4 +- dyson/solvers/static/davidson.py | 14 +++-- dyson/solvers/static/density.py | 12 ++++- dyson/solvers/static/exact.py | 5 +- dyson/util/linalg.py | 5 +- pyproject.toml | 5 ++ tests/conftest.py | 15 ++++++ tests/test_chempot.py | 2 +- tests/test_corrvec.py | 2 +- tests/test_cpgf.py | 2 +- tests/test_davidson.py | 2 +- tests/test_density.py | 2 +- tests/test_downfolded.py | 2 +- tests/test_expressions.py | 30 +++++++---- 25 files changed, 218 insertions(+), 45 deletions(-) create mode 100644 dyson/_backend.py diff --git a/dyson/__init__.py b/dyson/__init__.py index 6f3ab4c..72c7272 100644 --- a/dyson/__init__.py +++ b/dyson/__init__.py @@ -122,9 +122,7 @@ __version__ = "1.0.0" -import numpy -import scipy - +from dyson._backend import set_backend, numpy, scipy from dyson.printing import console, quiet from dyson.representations import Lehmann, Spectral, Dynamic from dyson.solvers import ( diff --git a/dyson/_backend.py b/dyson/_backend.py new file mode 100644 index 0000000..c98aec6 --- /dev/null +++ b/dyson/_backend.py @@ -0,0 +1,90 @@ +"""Backend management for :mod:`dyson`.""" + +from __future__ import annotations + +import functools +import importlib +import os + +from types import ModuleType +from typing import Callable, Any + + +try: + import jax + jax.config.update("jax_enable_x64", True) +except ImportError: + pass + +_BACKEND = os.environ.get("DYSON_BACKEND", "numpy") +_BACKEND_WARNINGS = os.environ.get("DYSON_BACKEND_WARNINGS", "0") == "1" + +_MODULE_CACHE: dict[tuple[str, str], ModuleType] = {} +_BACKENDS = { + "numpy": { + "numpy": "numpy", + "scipy": "scipy", + }, + "jax": { + "numpy": "jax.numpy", + "scipy": "jax.scipy", + }, +} + + +def set_backend(backend: str) -> None: + """Set the backend for :mod:`dyson`.""" + global _BACKEND + if backend not in _BACKENDS: + raise ValueError( + f"Invalid backend: {backend}. Available backends are: {list(_BACKENDS.keys())}" + ) + _BACKEND = backend + + +def cast_returned_array(func: Callable[[Any], Any]) -> Callable[[Any], Any]: + """Decorate a function to coerce its returned array to the backend type.""" + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + result = func(*args, **kwargs) + if isinstance(result, tuple): + return tuple(numpy.asarray(r) for r in result) + return numpy.asarray(result) + + return wrapper + + +class ProxyModule(ModuleType): + """Dynamic proxy module for backend-specific imports.""" + + def __init__(self, key: str) -> None: + """Initialise the object.""" + super().__init__(f"{__name__}.{key}") + self._key = key + + def __getattr__(self, attr: str) -> ModuleType: + """Get the attribute from the backend module.""" + mod = self._load() + return getattr(mod, attr) + + def _load(self) -> ModuleType: + """Load the backend module.""" + # Check the cache + key = (self._key, _BACKEND) + if key in _MODULE_CACHE: + return _MODULE_CACHE[key] + + # Load the module + keys = self._key.split(".") + module = _BACKENDS[_BACKEND][keys[0]] + if len(keys) > 1: + module += "." + ".".join(keys[1:]) + _MODULE_CACHE[key] = importlib.import_module(module) + + return _MODULE_CACHE[key] + + +numpy = ProxyModule("numpy") +scipy = ProxyModule("scipy") +scipy.optimize = ProxyModule("scipy.optimize") # SciPy doesn't seem to export this diff --git a/dyson/expressions/adc.py b/dyson/expressions/adc.py index 8eb7c72..d63d33c 100644 --- a/dyson/expressions/adc.py +++ b/dyson/expressions/adc.py @@ -19,6 +19,7 @@ from dyson import numpy as np from dyson import util +from dyson._backend import cast_returned_array from dyson.expressions.expression import BaseExpression, ExpressionCollection from dyson.representations.enums import Reduction @@ -89,6 +90,7 @@ def from_mf(cls, mf: RHF) -> BaseADC: adc_obj.kernel_gs() return cls.from_adc(adc_obj) + @cast_returned_array def apply_hamiltonian(self, vector: Array) -> Array: """Apply the Hamiltonian to a vector. @@ -115,6 +117,7 @@ def apply_hamiltonian_left(self, vector: Array) -> Array: """ raise NotImplementedError("Left application of Hamiltonian is not implemented for ADC.") + @cast_returned_array def diagonal(self) -> Array: """Get the diagonal of the Hamiltonian. @@ -234,6 +237,12 @@ def build_se_moments(self, nmom: int, reduction: Reduction = Reduction.NONE) -> ooov = ooov.reshape(eo.size, eo.size, eo.size, ev.size) left = ooov * 2 - ooov.swapaxes(1, 2) + # Cast arrays + eo = np.asarray(eo) + ev = np.asarray(ev) + ooov = np.asarray(ooov) + left = np.asarray(left) + # Get the subscript based on the reduction if Reduction(reduction) == Reduction.NONE: subscript = "ikla,jkla->ij" @@ -302,6 +311,12 @@ def build_se_moments(self, nmom: int, reduction: Reduction = Reduction.NONE) -> vvvo = vvvo.reshape(ev.size, ev.size, ev.size, eo.size) left = vvvo * 2 - vvvo.swapaxes(1, 2) + # Cast arrays + eo = np.asarray(eo) + ev = np.asarray(ev) + vvvo = np.asarray(vvvo) + left = np.asarray(left) + # Get the subscript based on the reduction if Reduction(reduction) == Reduction.NONE: subscript = "acdi,bcdi->ab" diff --git a/dyson/expressions/ccsd.py b/dyson/expressions/ccsd.py index ae85e11..475ebf0 100644 --- a/dyson/expressions/ccsd.py +++ b/dyson/expressions/ccsd.py @@ -21,6 +21,7 @@ from dyson import numpy as np from dyson import util +from dyson._backend import cast_returned_array from dyson.expressions.expression import BaseExpression, ExpressionCollection from dyson.representations.enums import Reduction @@ -63,10 +64,10 @@ def __init__( imds: Intermediate integrals. """ self._mol = mol - self._t1 = t1 - self._t2 = t2 - self._l1 = l1 - self._l2 = l2 + self._t1 = np.asarray(t1) + self._t2 = np.asarray(t2) + self._l1 = np.asarray(l1) + self._l2 = np.asarray(l2) self._imds = imds self._precompute_imds() @@ -199,6 +200,7 @@ def _precompute_imds(self) -> None: """Precompute intermediate integrals.""" self._imds.make_ip() + @cast_returned_array def vector_to_amplitudes(self, vector: Array, *args: Any) -> tuple[Array, Array]: """Convert a vector to amplitudes. @@ -211,6 +213,7 @@ def vector_to_amplitudes(self, vector: Array, *args: Any) -> tuple[Array, Array] """ return self.PYSCF_EOM.vector_to_amplitudes_ip(vector, self.nphys, self.nocc) + @cast_returned_array def amplitudes_to_vector(self, t1: Array, t2: Array) -> Array: """Convert amplitudes to a vector. @@ -223,6 +226,7 @@ def amplitudes_to_vector(self, t1: Array, t2: Array) -> Array: """ return self.PYSCF_EOM.amplitudes_to_vector_ip(t1, t2) + @cast_returned_array def apply_hamiltonian_right(self, vector: Array) -> Array: """Apply the Hamiltonian to a vector on the right. @@ -240,6 +244,7 @@ def apply_hamiltonian_right(self, vector: Array) -> Array: """ return -self.PYSCF_EOM.lipccsd_matvec(self, vector, imds=self._imds) + @cast_returned_array def apply_hamiltonian_left(self, vector: Array) -> Array: """Apply the Hamiltonian to a vector on the left. @@ -260,6 +265,7 @@ def apply_hamiltonian_left(self, vector: Array) -> Array: apply_hamiltonian = apply_hamiltonian_right apply_hamiltonian.__doc__ = BaseCCSD.apply_hamiltonian.__doc__ + @cast_returned_array def diagonal(self) -> Array: """Get the diagonal of the Hamiltonian. @@ -367,6 +373,7 @@ def _precompute_imds(self) -> None: """Precompute intermediate integrals.""" self._imds.make_ea() + @cast_returned_array def vector_to_amplitudes(self, vector: Array, *args: Any) -> tuple[Array, Array]: """Convert a vector to amplitudes. @@ -379,6 +386,7 @@ def vector_to_amplitudes(self, vector: Array, *args: Any) -> tuple[Array, Array] """ return self.PYSCF_EOM.vector_to_amplitudes_ea(vector, self.nphys, self.nocc) + @cast_returned_array def amplitudes_to_vector(self, t1: Array, t2: Array) -> Array: """Convert amplitudes to a vector. @@ -391,6 +399,7 @@ def amplitudes_to_vector(self, t1: Array, t2: Array) -> Array: """ return self.PYSCF_EOM.amplitudes_to_vector_ea(t1, t2) + @cast_returned_array def apply_hamiltonian_right(self, vector: Array) -> Array: """Apply the Hamiltonian to a vector on the right. @@ -402,6 +411,7 @@ def apply_hamiltonian_right(self, vector: Array) -> Array: """ return self.PYSCF_EOM.eaccsd_matvec(self, vector, imds=self._imds) + @cast_returned_array def apply_hamiltonian_left(self, vector: Array) -> Array: """Apply the Hamiltonian to a vector on the left. @@ -416,6 +426,7 @@ def apply_hamiltonian_left(self, vector: Array) -> Array: apply_hamiltonian = apply_hamiltonian_right apply_hamiltonian.__doc__ = BaseCCSD.apply_hamiltonian.__doc__ + @cast_returned_array def diagonal(self) -> Array: """Get the diagonal of the Hamiltonian. diff --git a/dyson/expressions/fci.py b/dyson/expressions/fci.py index ca7d203..18af3a4 100644 --- a/dyson/expressions/fci.py +++ b/dyson/expressions/fci.py @@ -12,6 +12,7 @@ from pyscf import ao2mo, fci +from dyson._backend import cast_returned_array from dyson.expressions.expression import BaseExpression, ExpressionCollection from dyson.representations.enums import Reduction @@ -103,6 +104,7 @@ def from_mf(cls, mf: RHF) -> BaseFCI: ci.kernel(h1e, h2e, mf.mol.nao, mf.mol.nelec) return cls.from_fci(ci, h1e, h2e) + @cast_returned_array def apply_hamiltonian(self, vector: Array) -> Array: """Apply the Hamiltonian to a vector. @@ -123,6 +125,7 @@ def apply_hamiltonian(self, vector: Array) -> Array: result -= (self.e_fci + self.chempot) * vector return self.SIGN * result + @cast_returned_array def diagonal(self) -> Array: """Get the diagonal of the Hamiltonian. @@ -131,6 +134,7 @@ def diagonal(self) -> Array: """ return self.SIGN * (self._diagonal - (self.e_fci + self.chempot)) + @cast_returned_array def get_excitation_vector(self, orbital: int) -> Array: r"""Obtain the vector corresponding to a fermionic operator acting on the ground state. diff --git a/dyson/expressions/gw.py b/dyson/expressions/gw.py index e23abf9..168578a 100644 --- a/dyson/expressions/gw.py +++ b/dyson/expressions/gw.py @@ -21,6 +21,7 @@ from dyson import numpy as np from dyson import util +from dyson._backend import cast_returned_array from dyson.expressions.expression import BaseExpression, ExpressionCollection from dyson.representations.enums import Reduction @@ -151,6 +152,7 @@ def non_dyson(self) -> bool: class TDAGW_Dyson(BaseGW_Dyson): """GW expressions with Tamm--Dancoff (TDA) approximation for the Dyson Green's function.""" + @cast_returned_array def apply_hamiltonian(self, vector: Array) -> Array: """Apply the Hamiltonian to a vector. @@ -215,6 +217,7 @@ def apply_hamiltonian_left(self, vector: Array) -> Array: """ raise NotImplementedError("Left application of Hamiltonian is not implemented for TDA-GW.") + @cast_returned_array def diagonal(self) -> Array: """Get the diagonal of the Hamiltonian. diff --git a/dyson/expressions/hamiltonian.py b/dyson/expressions/hamiltonian.py index a43c91c..e89dbfc 100644 --- a/dyson/expressions/hamiltonian.py +++ b/dyson/expressions/hamiltonian.py @@ -7,6 +7,7 @@ from dyson import numpy as np from dyson import util +from dyson._backend import cast_returned_array from dyson.expressions.expression import BaseExpression from dyson.representations.enums import Reduction @@ -37,8 +38,8 @@ def __init__( :meth:`~dyson.expressions.expression.BaseExpression.get_excitation_ket`. """ self._hamiltonian = hamiltonian - self._bra = bra - self._ket = ket + self._bra = np.asarray(bra) if bra is not None else None + self._ket = np.asarray(ket) if bra is not None else None if isinstance(hamiltonian, np.ndarray): self.hermitian_upfolded = np.allclose(hamiltonian, hamiltonian.T.conj()) @@ -58,6 +59,7 @@ def from_mf(cls, mf: RHF) -> Hamiltonian: """ raise NotImplementedError("Cannot create Hamiltonian expression from mean-field object.") + @cast_returned_array def apply_hamiltonian(self, vector: Array) -> Array: """Apply the Hamiltonian to a vector. @@ -69,6 +71,7 @@ def apply_hamiltonian(self, vector: Array) -> Array: """ return self._hamiltonian @ vector + @cast_returned_array def apply_hamiltonian_left(self, vector: Array) -> Array: """Apply the Hamiltonian to a vector on the left. @@ -80,6 +83,7 @@ def apply_hamiltonian_left(self, vector: Array) -> Array: """ return vector @ self._hamiltonian + @cast_returned_array def diagonal(self) -> Array: """Get the diagonal of the Hamiltonian. @@ -88,6 +92,7 @@ def diagonal(self) -> Array: """ return self._hamiltonian.diagonal() + @cast_returned_array def build_matrix(self) -> Array: """Build the Hamiltonian matrix. diff --git a/dyson/expressions/hf.py b/dyson/expressions/hf.py index f212ba3..b43d35e 100644 --- a/dyson/expressions/hf.py +++ b/dyson/expressions/hf.py @@ -11,6 +11,7 @@ from dyson import numpy as np from dyson import util +from dyson._backend import cast_returned_array from dyson.expressions.expression import BaseExpression, ExpressionCollection from dyson.representations.enums import Reduction @@ -109,6 +110,7 @@ def nconfig(self) -> int: class HF_1h(BaseHF): # pylint: disable=invalid-name """HF expressions for the hole Green's function.""" + @cast_returned_array def diagonal(self) -> Array: """Get the diagonal of the Hamiltonian. @@ -150,6 +152,7 @@ def nsingle(self) -> int: class HF_1p(BaseHF): # pylint: disable=invalid-name """HF expressions for the particle Green's function.""" + @cast_returned_array def diagonal(self) -> Array: """Get the diagonal of the Hamiltonian. @@ -191,6 +194,7 @@ def nsingle(self) -> int: class HF_Dyson(BaseHF): # pylint: disable=invalid-name """HF expressions for the Dyson Green's function.""" + @cast_returned_array def diagonal(self) -> Array: """Get the diagonal of the Hamiltonian. diff --git a/dyson/grids/frequency.py b/dyson/grids/frequency.py index fc13f28..3f77eae 100644 --- a/dyson/grids/frequency.py +++ b/dyson/grids/frequency.py @@ -5,10 +5,8 @@ from abc import abstractmethod from typing import TYPE_CHECKING -import scipy.special - from dyson import numpy as np -from dyson import util +from dyson import util, scipy from dyson.grids.grid import BaseGrid from dyson.representations.enums import Component, Ordering, Reduction diff --git a/dyson/representations/lehmann.py b/dyson/representations/lehmann.py index 78484bb..7d87fb8 100644 --- a/dyson/representations/lehmann.py +++ b/dyson/representations/lehmann.py @@ -5,10 +5,8 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, cast -import scipy.linalg - from dyson import numpy as np -from dyson import util +from dyson import util, scipy from dyson.representations.enums import Reduction from dyson.representations.representation import BaseRepresentation from dyson.typing import Array diff --git a/dyson/solvers/dynamic/corrvec.py b/dyson/solvers/dynamic/corrvec.py index 19bd6cb..dd97f4d 100644 --- a/dyson/solvers/dynamic/corrvec.py +++ b/dyson/solvers/dynamic/corrvec.py @@ -13,6 +13,7 @@ from dyson import console, printing from dyson import numpy as np +from dyson._backend import _BACKEND from dyson.grids.frequency import RealFrequencyGrid from dyson.representations.dynamic import Dynamic from dyson.representations.enums import Component, Ordering, Reduction @@ -26,7 +27,9 @@ from dyson.representations.lehmann import Lehmann from dyson.typing import Array -# TODO: Can we use DIIS? +if _BACKEND == "jax": + # No LGMRES in JAX, so use GMRES + from dyson.scipy.sprase.linalg import gmres as lgmres class CorrectionVector(DynamicSolver): diff --git a/dyson/solvers/static/chempot.py b/dyson/solvers/static/chempot.py index 52a8694..9c40592 100644 --- a/dyson/solvers/static/chempot.py +++ b/dyson/solvers/static/chempot.py @@ -6,9 +6,7 @@ import warnings from typing import TYPE_CHECKING -import scipy.optimize - -from dyson import console, printing, util +from dyson import console, printing, util, scipy from dyson import numpy as np from dyson.representations.lehmann import Lehmann, shift_energies from dyson.solvers.solver import StaticSolver diff --git a/dyson/solvers/static/davidson.py b/dyson/solvers/static/davidson.py index 929a746..7bb8acb 100644 --- a/dyson/solvers/static/davidson.py +++ b/dyson/solvers/static/davidson.py @@ -14,10 +14,11 @@ import warnings from typing import TYPE_CHECKING -from pyscf import lib +from pyscf.lib.linalg_helper import davidson1, davidson_nosym1 from dyson import console, printing, util from dyson import numpy as np +from dyson._backend import _BACKEND from dyson.representations.lehmann import Lehmann from dyson.representations.spectral import Spectral from dyson.solvers.solver import StaticSolver @@ -29,6 +30,13 @@ from dyson.expressions.expression import BaseExpression from dyson.typing import Array +if _BACKEND == "jax": + # Try to get the JAX version of the Davidson algorithm, only available for Hermitian case + try: + from pyscfad.lib.linalg_helper import davidson1 + except: + pass + def _pick_real_eigenvalues( eigvals: Array, @@ -249,7 +257,7 @@ def _callback(env: dict[str, Any]) -> None: # Call the Davidson function if self.hermitian: - converged, eigvals, eigvecs = lib.linalg_helper.davidson1( + converged, eigvals, eigvecs = davidson1( lambda vectors: [self.matvec(vector) for vector in vectors], self.get_guesses(), self.diagonal, @@ -268,7 +276,7 @@ def _callback(env: dict[str, Any]) -> None: else: with util.catch_warnings(UserWarning): - converged, eigvals, left, right = lib.linalg_helper.davidson_nosym1( + converged, eigvals, left, right = davidson_nosym1( lambda vectors: [self.matvec(vector) for vector in vectors], self.get_guesses(), self.diagonal, diff --git a/dyson/solvers/static/density.py b/dyson/solvers/static/density.py index ce4e22b..4f91d45 100644 --- a/dyson/solvers/static/density.py +++ b/dyson/solvers/static/density.py @@ -4,10 +4,11 @@ from typing import TYPE_CHECKING -from pyscf import lib +from pyscf.lib.diis import DIIS from dyson import console, printing from dyson import numpy as np +from dyson._backend import _BACKEND from dyson.representations.lehmann import Lehmann from dyson.solvers.solver import StaticSolver from dyson.solvers.static.chempot import AufbauPrinciple, AuxiliaryShift @@ -42,6 +43,13 @@ def __call__( """ ... +if _BACKEND == "jax": + # Try to get the JAX version of DIIS + try: + from pyscfad.lib.diis import DIIS + except ImportError: + pass + def get_fock_matrix_function(mf: scf.hf.RHF) -> StaticFunction: """Get a function to compute the Fock matrix for a given density matrix. @@ -290,7 +298,7 @@ def kernel(self) -> Spectral: self_energy = result.get_self_energy() # Initialise DIIS for the inner loop - diis = lib.diis.DIIS() + diis = DIIS() diis.space = self.diis_min_space diis.max_space = self.diis_max_space diis.incore = True diff --git a/dyson/solvers/static/exact.py b/dyson/solvers/static/exact.py index 4111644..fff9596 100644 --- a/dyson/solvers/static/exact.py +++ b/dyson/solvers/static/exact.py @@ -70,8 +70,8 @@ def project_eigenvectors( left, right = util.biorthonormalise(left, right) # Return the physical vectors to the original basis - left[:, :nphys] = left[:, :nphys] @ unorth.T.conj() - right[:, :nphys] = right[:, :nphys] @ unorth + left = util.rotate_subspace(left.T, unorth.conj()).T + right = util.rotate_subspace(right.T, unorth).T # Rotate the eigenvectors eigvecs = np.array([left.T.conj() @ eigvecs[0], right.T.conj() @ eigvecs[1]]) @@ -79,6 +79,7 @@ def project_eigenvectors( return eigvecs + def orthogonalise_self_energy( static: Array, self_energy: Lehmann, diff --git a/dyson/util/linalg.py b/dyson/util/linalg.py index 7f8d9c6..ad298ba 100644 --- a/dyson/util/linalg.py +++ b/dyson/util/linalg.py @@ -6,9 +6,8 @@ import warnings from typing import TYPE_CHECKING, cast -import scipy.linalg - from dyson import numpy as np +from dyson import scipy from dyson.util import cache_by_id if TYPE_CHECKING: @@ -175,7 +174,7 @@ def _sort_eigvals(eigvals: Array, eigvecs: Array, threshold: float = 1e-11) -> t first ordering based on the rounded real and imaginary parts of the eigenvalues, and then sorting by the true real and imaginary parts. """ - decimals = round(-np.log10(threshold)) + decimals = int(round(-np.log10(threshold))) real_approx = np.round(eigvals.real, decimals=decimals) imag_approx = np.round(eigvals.imag, decimals=decimals) idx = np.lexsort((eigvals.imag, eigvals.real, imag_approx, real_approx)) diff --git a/pyproject.toml b/pyproject.toml index d5b864e..1af39e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,11 @@ dependencies = [ ] [project.optional-dependencies] +jax = [ + "jax>=0.5.0", + "jaxlib>=0.5.0", + "pyscfad>=0.1.0", +] dev = [ "ruff>=0.10.0", "mypy>=1.5.0", diff --git a/tests/conftest.py b/tests/conftest.py index 5fbd1ad..1a20421 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ from pyscf import gto, scf from dyson import numpy as np +from dyson._backend import set_backend from dyson.expressions import ADC2, CCSD, FCI, HF, TDAGW, ADC2x from dyson.representations.lehmann import Lehmann from dyson.representations.spectral import Spectral @@ -73,6 +74,13 @@ def pytest_generate_tests(metafunc): # type: ignore expressions.append(method) ids.append(name) metafunc.parametrize("expression_method", expressions, ids=ids) + if "backend" in metafunc.fixturenames: + try: + import jax # noqa: F401 + metafunc.parametrize("backend", ["numpy", "jax"], scope="function") + except ImportError: + # If JAX is not available, only use numpy backend + metafunc.parametrize("backend", ["numpy"], scope="function") class Helper: @@ -128,6 +136,13 @@ def has_orthonormal_couplings(greens_function: Lehmann, tol: float = 1e-8) -> bo ) +@pytest.fixture(scope="session") +def backend(request) -> str: + """Fixture to set the backend for the tests.""" + set_backend(request.param) + return request.param + + @pytest.fixture(scope="session") def helper() -> Helper: """Fixture for the :class:`Helper` class.""" diff --git a/tests/test_chempot.py b/tests/test_chempot.py index cf0aeaa..07804b7 100644 --- a/tests/test_chempot.py +++ b/tests/test_chempot.py @@ -4,10 +4,10 @@ from typing import TYPE_CHECKING -import numpy as np import pytest from dyson.solvers import AufbauPrinciple, AuxiliaryShift +from dyson import numpy as np from .conftest import _get_central_result diff --git a/tests/test_corrvec.py b/tests/test_corrvec.py index a39171a..10dbcaa 100644 --- a/tests/test_corrvec.py +++ b/tests/test_corrvec.py @@ -4,9 +4,9 @@ from typing import TYPE_CHECKING -import numpy as np import pytest +from dyson import numpy as np from dyson.grids import RealFrequencyGrid from dyson.solvers import CorrectionVector diff --git a/tests/test_cpgf.py b/tests/test_cpgf.py index 0f0ba11..827623f 100644 --- a/tests/test_cpgf.py +++ b/tests/test_cpgf.py @@ -4,9 +4,9 @@ from typing import TYPE_CHECKING -import numpy as np import pytest +from dyson import numpy as np from dyson.expressions.hf import BaseHF from dyson.grids import RealFrequencyGrid from dyson.solvers import CPGF diff --git a/tests/test_davidson.py b/tests/test_davidson.py index 1dc347c..0f6b21f 100644 --- a/tests/test_davidson.py +++ b/tests/test_davidson.py @@ -4,9 +4,9 @@ from typing import TYPE_CHECKING -import numpy as np import pytest +from dyson import numpy as np from dyson.representations.lehmann import Lehmann from dyson.representations.spectral import Spectral from dyson.solvers import Davidson diff --git a/tests/test_density.py b/tests/test_density.py index b87682d..8976bbf 100644 --- a/tests/test_density.py +++ b/tests/test_density.py @@ -4,9 +4,9 @@ from typing import TYPE_CHECKING -import numpy as np import pytest +from dyson import numpy as np from dyson.representations.spectral import Spectral from dyson.solvers import DensityRelaxation from dyson.solvers.static.density import get_fock_matrix_function diff --git a/tests/test_downfolded.py b/tests/test_downfolded.py index 32be3d7..72e2b41 100644 --- a/tests/test_downfolded.py +++ b/tests/test_downfolded.py @@ -4,9 +4,9 @@ from typing import TYPE_CHECKING -import numpy as np import pytest +from dyson import numpy as np from dyson.representations.spectral import Spectral from dyson.solvers import Downfolded diff --git a/tests/test_expressions.py b/tests/test_expressions.py index ec521fd..1d88200 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -5,10 +5,10 @@ import itertools from typing import TYPE_CHECKING -import numpy as np import pyscf import pytest +from dyson import numpy as np from dyson import util from dyson.expressions import ADC2, CCSD, FCI, HF, TDAGW, ADC2x from dyson.solvers import Davidson, Exact @@ -21,7 +21,7 @@ from .conftest import ExactGetter, Helper -def test_init(mf: scf.hf.RHF, expression_cls: type[BaseExpression]) -> None: +def test_init(backend: str, mf: scf.hf.RHF, expression_cls: type[BaseExpression]) -> None: """Test the instantiation of the expression from a mean-field object.""" expression = expression_cls.from_mf(mf) assert expression.mol is mf.mol @@ -30,7 +30,7 @@ def test_init(mf: scf.hf.RHF, expression_cls: type[BaseExpression]) -> None: assert expression.nvir == mf.mol.nao - mf.mol.nelectron // 2 -def test_hamiltonian(mf: scf.hf.RHF, expression_cls: type[BaseExpression]) -> None: +def test_hamiltonian(backend: str, mf: scf.hf.RHF, expression_cls: type[BaseExpression]) -> None: """Test the Hamiltonian of the expression.""" expression = expression_cls.from_mf(mf) if expression.nconfig > 1024: @@ -41,6 +41,7 @@ def test_hamiltonian(mf: scf.hf.RHF, expression_cls: type[BaseExpression]) -> No if expression_cls in ADC2._classes: # ADC(2)-x diagonal is set to ADC(2) diagonal in PySCF for better Davidson convergence assert np.allclose(np.diag(hamiltonian), diagonal) + assert isinstance(hamiltonian, np.ndarray) assert hamiltonian.shape == expression.shape assert (expression.nconfig + expression.nsingle) == diagonal.size @@ -51,12 +52,14 @@ def test_hamiltonian(mf: scf.hf.RHF, expression_cls: type[BaseExpression]) -> No except NotImplementedError: vh = None + assert isinstance(hv, np.ndarray) assert np.allclose(hv, hamiltonian @ vector) if vh is not None: + assert isinstance(vh, np.ndarray) assert np.allclose(vh, vector @ hamiltonian) -def test_gf_moments(mf: scf.hf.RHF, expression_cls: type[BaseExpression]) -> None: +def test_gf_moments(backend: str, mf: scf.hf.RHF, expression_cls: type[BaseExpression]) -> None: """Test the Green's function moments of the expression.""" # Get the quantities required from the expression expression = expression_cls.from_mf(mf) @@ -69,17 +72,22 @@ def test_gf_moments(mf: scf.hf.RHF, expression_cls: type[BaseExpression]) -> Non for i, j in itertools.product(range(expression.nphys), repeat=2): bra = expression.get_excitation_bra(j) ket = expression.get_excitation_ket(i) + assert isinstance(bra, np.ndarray) + assert isinstance(ket, np.ndarray) moments[0, j, i] += bra.conj() @ ket moments[1, j, i] += bra.conj() @ hamiltonian @ ket # Compare the moments to the reference ref = expression.build_gf_moments(2) + assert isinstance(ref, np.ndarray) + assert isinstance(moments, np.ndarray) assert np.allclose(ref[0], moments[0]) assert np.allclose(ref[1], moments[1]) def test_static( + backend: str, helper: Helper, mf: scf.hf.RHF, expression_cls: type[BaseExpression], @@ -98,11 +106,13 @@ def test_static( greens_function = exact.result.get_greens_function() static = exact.result.get_static_self_energy() + assert isinstance(gf_moments, np.ndarray) + assert isinstance(static, np.ndarray) assert helper.have_equal_moments(gf_moments, greens_function, 2) assert np.allclose(static, gf_moments[1]) -def test_hf(mf: scf.hf.RHF) -> None: +def test_hf(backend: str, mf: scf.hf.RHF) -> None: """Test the HF expression.""" hf_h = HF.h.from_mf(mf) hf_p = HF.p.from_mf(mf) @@ -136,7 +146,7 @@ def test_hf(mf: scf.hf.RHF) -> None: assert np.allclose(result.get_greens_function().as_perturbed_mo_energy(), mf.mo_energy) -def test_ccsd(mf: scf.hf.RHF) -> None: +def test_ccsd(backend: str, mf: scf.hf.RHF) -> None: """Test the CCSD expression.""" ccsd_h = CCSD.h.from_mf(mf) ccsd_p = CCSD.p.from_mf(mf) @@ -179,7 +189,7 @@ def test_ccsd(mf: scf.hf.RHF) -> None: assert np.allclose(gf_moment_0, np.eye(mf.mol.nao)) -def test_fci(mf: scf.hf.RHF) -> None: +def test_fci(backend: str, mf: scf.hf.RHF) -> None: """Test the FCI expression.""" fci_h = FCI.h.from_mf(mf) fci_p = FCI.p.from_mf(mf) @@ -214,7 +224,7 @@ def test_fci(mf: scf.hf.RHF) -> None: assert np.allclose(gf_moments_ccsd[1], gf_moments_h[1]) -def test_adc2(mf: scf.hf.RHF) -> None: +def test_adc2(backend: str, mf: scf.hf.RHF) -> None: """Test the ADC(2) expression.""" adc_h = ADC2.h.from_mf(mf) adc_p = ADC2.p.from_mf(mf) @@ -249,7 +259,7 @@ def test_adc2(mf: scf.hf.RHF) -> None: assert np.allclose(gf_moment_0, np.eye(mf.mol.nao)) -def test_adc2x(mf: scf.hf.RHF) -> None: +def test_adc2x(backend: str, mf: scf.hf.RHF) -> None: """Test the ADC(2)-x expression.""" adc_h = ADC2x.h.from_mf(mf) adc_p = ADC2x.p.from_mf(mf) @@ -285,7 +295,7 @@ def test_adc2x(mf: scf.hf.RHF) -> None: assert np.allclose(gf_moment_0, np.eye(mf.mol.nao)) -def test_tdagw(mf: scf.hf.RHF, exact_cache: ExactGetter) -> None: +def test_tdagw(backend: str, mf: scf.hf.RHF, exact_cache: ExactGetter) -> None: """Test the TDAGW expression.""" tdagw = TDAGW["dyson"].from_mf(mf) dft = mf.to_rks() From 2d176c2710e106f33725faaf388ed6c24665fe1e Mon Sep 17 00:00:00 2001 From: Oliver Backhouse Date: Tue, 12 Aug 2025 09:28:53 +0100 Subject: [PATCH 2/6] Linting --- dyson/_backend.py | 19 +++++++++++-------- dyson/grids/frequency.py | 2 +- dyson/representations/lehmann.py | 2 +- dyson/solvers/dynamic/corrvec.py | 6 ++++-- dyson/solvers/static/chempot.py | 2 +- dyson/solvers/static/davidson.py | 2 +- dyson/solvers/static/density.py | 3 ++- dyson/solvers/static/exact.py | 2 -- dyson/typing.py | 2 +- pyproject.toml | 4 ++++ tests/conftest.py | 1 + tests/test_chempot.py | 2 +- 12 files changed, 28 insertions(+), 19 deletions(-) diff --git a/dyson/_backend.py b/dyson/_backend.py index c98aec6..d1308ac 100644 --- a/dyson/_backend.py +++ b/dyson/_backend.py @@ -5,13 +5,12 @@ import functools import importlib import os - from types import ModuleType -from typing import Callable, Any - +from typing import TYPE_CHECKING, Any, Callable try: import jax + jax.config.update("jax_enable_x64", True) except ImportError: pass @@ -34,7 +33,7 @@ def set_backend(backend: str) -> None: """Set the backend for :mod:`dyson`.""" - global _BACKEND + global _BACKEND # noqa: PLW0603 if backend not in _BACKENDS: raise ValueError( f"Invalid backend: {backend}. Available backends are: {list(_BACKENDS.keys())}" @@ -42,7 +41,7 @@ def set_backend(backend: str) -> None: _BACKEND = backend -def cast_returned_array(func: Callable[[Any], Any]) -> Callable[[Any], Any]: +def cast_returned_array(func: Callable[..., Any]) -> Callable[..., Any]: """Decorate a function to coerce its returned array to the backend type.""" @functools.wraps(func) @@ -85,6 +84,10 @@ def _load(self) -> ModuleType: return _MODULE_CACHE[key] -numpy = ProxyModule("numpy") -scipy = ProxyModule("scipy") -scipy.optimize = ProxyModule("scipy.optimize") # SciPy doesn't seem to export this +if TYPE_CHECKING: + import numpy + import scipy +else: + numpy = ProxyModule("numpy") # type: ignore[assignment] + scipy = ProxyModule("scipy") # type: ignore[assignment] + scipy.optimize = ProxyModule("scipy.optimize") # type: ignore[assignment] diff --git a/dyson/grids/frequency.py b/dyson/grids/frequency.py index 3f77eae..5b75513 100644 --- a/dyson/grids/frequency.py +++ b/dyson/grids/frequency.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING from dyson import numpy as np -from dyson import util, scipy +from dyson import scipy, util from dyson.grids.grid import BaseGrid from dyson.representations.enums import Component, Ordering, Reduction diff --git a/dyson/representations/lehmann.py b/dyson/representations/lehmann.py index 7d87fb8..f0989b0 100644 --- a/dyson/representations/lehmann.py +++ b/dyson/representations/lehmann.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, cast from dyson import numpy as np -from dyson import util, scipy +from dyson import scipy, util from dyson.representations.enums import Reduction from dyson.representations.representation import BaseRepresentation from dyson.typing import Array diff --git a/dyson/solvers/dynamic/corrvec.py b/dyson/solvers/dynamic/corrvec.py index dd97f4d..40a7925 100644 --- a/dyson/solvers/dynamic/corrvec.py +++ b/dyson/solvers/dynamic/corrvec.py @@ -27,9 +27,11 @@ from dyson.representations.lehmann import Lehmann from dyson.typing import Array -if _BACKEND == "jax": +if _BACKEND == "jax" and not TYPE_CHECKING: # No LGMRES in JAX, so use GMRES - from dyson.scipy.sprase.linalg import gmres as lgmres + from dyson import scipy + + lgmres = scipy.sparse.linalg.gmres class CorrectionVector(DynamicSolver): diff --git a/dyson/solvers/static/chempot.py b/dyson/solvers/static/chempot.py index 9c40592..13cfbfa 100644 --- a/dyson/solvers/static/chempot.py +++ b/dyson/solvers/static/chempot.py @@ -6,7 +6,7 @@ import warnings from typing import TYPE_CHECKING -from dyson import console, printing, util, scipy +from dyson import console, printing, scipy, util from dyson import numpy as np from dyson.representations.lehmann import Lehmann, shift_energies from dyson.solvers.solver import StaticSolver diff --git a/dyson/solvers/static/davidson.py b/dyson/solvers/static/davidson.py index 7bb8acb..3ca5948 100644 --- a/dyson/solvers/static/davidson.py +++ b/dyson/solvers/static/davidson.py @@ -30,7 +30,7 @@ from dyson.expressions.expression import BaseExpression from dyson.typing import Array -if _BACKEND == "jax": +if _BACKEND == "jax" and not TYPE_CHECKING: # Try to get the JAX version of the Davidson algorithm, only available for Hermitian case try: from pyscfad.lib.linalg_helper import davidson1 diff --git a/dyson/solvers/static/density.py b/dyson/solvers/static/density.py index 4f91d45..8b7b6b7 100644 --- a/dyson/solvers/static/density.py +++ b/dyson/solvers/static/density.py @@ -43,7 +43,8 @@ def __call__( """ ... -if _BACKEND == "jax": + +if _BACKEND == "jax" and not TYPE_CHECKING: # Try to get the JAX version of DIIS try: from pyscfad.lib.diis import DIIS diff --git a/dyson/solvers/static/exact.py b/dyson/solvers/static/exact.py index fff9596..0111516 100644 --- a/dyson/solvers/static/exact.py +++ b/dyson/solvers/static/exact.py @@ -38,7 +38,6 @@ def project_eigenvectors( is defined by the null space of the projector formed by the outer product of these vectors. """ hermitian = ket is None - nphys = bra.shape[0] if not hermitian and eigvecs.ndim == 2: raise ValueError( "bra and ket both passed implying a non-hermitian system, but eigvecs is 2D." @@ -79,7 +78,6 @@ def project_eigenvectors( return eigvecs - def orthogonalise_self_energy( static: Array, self_energy: Lehmann, diff --git a/dyson/typing.py b/dyson/typing.py index 8cf9534..3ef33f0 100644 --- a/dyson/typing.py +++ b/dyson/typing.py @@ -2,6 +2,6 @@ from __future__ import annotations -from dyson import numpy +import numpy Array = numpy.ndarray diff --git a/pyproject.toml b/pyproject.toml index 1af39e0..1264a86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,6 +119,10 @@ ignore_missing_imports = true module = "pyscf.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "pyscfad.*" +ignore_missing_imports = true + [tool.coverage.run] branch = true source = [ diff --git a/tests/conftest.py b/tests/conftest.py index 1a20421..4a69731 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -77,6 +77,7 @@ def pytest_generate_tests(metafunc): # type: ignore if "backend" in metafunc.fixturenames: try: import jax # noqa: F401 + metafunc.parametrize("backend", ["numpy", "jax"], scope="function") except ImportError: # If JAX is not available, only use numpy backend diff --git a/tests/test_chempot.py b/tests/test_chempot.py index 07804b7..d9a1f80 100644 --- a/tests/test_chempot.py +++ b/tests/test_chempot.py @@ -6,8 +6,8 @@ import pytest -from dyson.solvers import AufbauPrinciple, AuxiliaryShift from dyson import numpy as np +from dyson.solvers import AufbauPrinciple, AuxiliaryShift from .conftest import _get_central_result From 814f8200d04f0ac378515857a3b11820cc8d255f Mon Sep 17 00:00:00 2001 From: Oliver Backhouse Date: Tue, 12 Aug 2025 09:33:13 +0100 Subject: [PATCH 3/6] Fix import --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 4a69731..5a79668 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -76,7 +76,7 @@ def pytest_generate_tests(metafunc): # type: ignore metafunc.parametrize("expression_method", expressions, ids=ids) if "backend" in metafunc.fixturenames: try: - import jax # noqa: F401 + import jax # noqa: PLC0415, F401 metafunc.parametrize("backend", ["numpy", "jax"], scope="function") except ImportError: From 2222d0bdd8a6788384c7fee6070710c422831e3a Mon Sep 17 00:00:00 2001 From: Oliver Backhouse Date: Tue, 12 Aug 2025 09:34:46 +0100 Subject: [PATCH 4/6] Only run tests for available backends --- .github/workflows/ci.yaml | 2 +- dyson/_backend.py | 19 ++++++++++--------- tests/conftest.py | 10 ++-------- 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 7651699..4bf4290 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -31,7 +31,7 @@ jobs: - name: Install dyson run: | python -m pip install wheel - python -m pip install .[dev] + python -m pip install .[dev,jax] - name: Linting run: | ruff check diff --git a/dyson/_backend.py b/dyson/_backend.py index d1308ac..b208df2 100644 --- a/dyson/_backend.py +++ b/dyson/_backend.py @@ -12,23 +12,24 @@ import jax jax.config.update("jax_enable_x64", True) + _HAVE_JAX = True except ImportError: - pass + _HAVE_JAX = False _BACKEND = os.environ.get("DYSON_BACKEND", "numpy") _BACKEND_WARNINGS = os.environ.get("DYSON_BACKEND_WARNINGS", "0") == "1" _MODULE_CACHE: dict[tuple[str, str], ModuleType] = {} -_BACKENDS = { - "numpy": { - "numpy": "numpy", - "scipy": "scipy", - }, - "jax": { +_BACKENDS: dict[str, dict[str, str]] = {} +_BACKENDS["numpy"] = { + "numpy": "numpy", + "scipy": "scipy", +} +if _HAVE_JAX: + _BACKENDS["jax"] = { "numpy": "jax.numpy", "scipy": "jax.scipy", - }, -} + } def set_backend(backend: str) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index 5a79668..8a9ea6e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ from pyscf import gto, scf from dyson import numpy as np -from dyson._backend import set_backend +from dyson._backend import set_backend, _BACKENDS from dyson.expressions import ADC2, CCSD, FCI, HF, TDAGW, ADC2x from dyson.representations.lehmann import Lehmann from dyson.representations.spectral import Spectral @@ -75,13 +75,7 @@ def pytest_generate_tests(metafunc): # type: ignore ids.append(name) metafunc.parametrize("expression_method", expressions, ids=ids) if "backend" in metafunc.fixturenames: - try: - import jax # noqa: PLC0415, F401 - - metafunc.parametrize("backend", ["numpy", "jax"], scope="function") - except ImportError: - # If JAX is not available, only use numpy backend - metafunc.parametrize("backend", ["numpy"], scope="function") + metafunc.parametrize("backend", list(_BACKENDS.keys()), scope="session") class Helper: From bd248c2e87140dfe1a0ba2e25eac7a9c3de0bf36 Mon Sep 17 00:00:00 2001 From: Oliver Backhouse Date: Tue, 12 Aug 2025 09:37:50 +0100 Subject: [PATCH 5/6] Linting --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 8a9ea6e..5111c7b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ from pyscf import gto, scf from dyson import numpy as np -from dyson._backend import set_backend, _BACKENDS +from dyson._backend import _BACKENDS, set_backend from dyson.expressions import ADC2, CCSD, FCI, HF, TDAGW, ADC2x from dyson.representations.lehmann import Lehmann from dyson.representations.spectral import Spectral From 3b6c18a9b04c2d5249ea03b5b906818cf1fc8023 Mon Sep 17 00:00:00 2001 From: Oliver Backhouse Date: Tue, 12 Aug 2025 10:38:26 +0100 Subject: [PATCH 6/6] Fix conjugation --- dyson/solvers/static/exact.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dyson/solvers/static/exact.py b/dyson/solvers/static/exact.py index 0111516..27edfa5 100644 --- a/dyson/solvers/static/exact.py +++ b/dyson/solvers/static/exact.py @@ -70,7 +70,7 @@ def project_eigenvectors( # Return the physical vectors to the original basis left = util.rotate_subspace(left.T, unorth.conj()).T - right = util.rotate_subspace(right.T, unorth).T + right = util.rotate_subspace(right.T, unorth.T).T # Rotate the eigenvectors eigvecs = np.array([left.T.conj() @ eigvecs[0], right.T.conj() @ eigvecs[1]])