diff --git a/RELEASES.md b/RELEASES.md index b28bfbda7..c734202c0 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,6 @@ # Releases + ## 0.9.7.dev0 This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation. @@ -12,8 +13,13 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver - Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782) - Geomloss function now handles both scalar and slice indices for i and j (PR #785) - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397) +<<<<<<< HEAD - Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765) - Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765) +======= +- Add cost functions between linear operators following + [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920) (PR #792) +>>>>>>> 8d13c55 (edits as per PR #792) #### Closed issues diff --git a/ot/backend.py b/ot/backend.py index d7fed4e2f..0568f2e2f 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -622,6 +622,46 @@ def clip(self, a, a_min=None, a_max=None): """ raise NotImplementedError() + def real(self, a): + """ + Return the real part of the tensor element-wise. + + This function follows the api from :any:`numpy.real` + + See: https://numpy.org/doc/stable/reference/generated/numpy.real.html + """ + raise NotImplementedError() + + def imag(self, a): + """ + Return the imaginary part of the tensor element-wise. + + This function follows the api from :any:`numpy.imag` + + See: https://numpy.org/doc/stable/reference/generated/numpy.imag.html + """ + raise NotImplementedError() + + def conj(self, a): + """ + Return the complex conjugate, element-wise. + + This function follows the api from :any:`numpy.conj` + + See: https://numpy.org/doc/stable/reference/generated/numpy.conj.html + """ + raise NotImplementedError() + + def arccos(self, a): + """ + Trigonometric inverse cosine, element-wise. + + This function follows the api from :any:`numpy.arccos` + + See: https://numpy.org/doc/stable/reference/generated/numpy.arccos.html + """ + raise NotImplementedError() + def repeat(self, a, repeats, axis=None): r""" Repeats elements of a tensor. @@ -1193,7 +1233,7 @@ def _from_numpy(self, a, type_as=None): elif isinstance(a, float): return a else: - return a.astype(type_as.dtype) + return np.asarray(a, dtype=type_as.dtype) def set_gradients(self, val, inputs, grads): # No gradients for numpy @@ -1313,6 +1353,18 @@ def outer(self, a, b): def clip(self, a, a_min=None, a_max=None): return np.clip(a, a_min, a_max) + def real(self, a): + return np.real(a) + + def imag(self, a): + return np.imag(a) + + def conj(self, a): + return np.conj(a) + + def arccos(self, a): + return np.arccos(a) + def repeat(self, a, repeats, axis=None): return np.repeat(a, repeats, axis) @@ -1604,7 +1656,7 @@ def _from_numpy(self, a, type_as=None): if type_as is None: return jnp.array(a) else: - return self._change_device(jnp.array(a).astype(type_as.dtype), type_as) + return self._change_device(jnp.asarray(a, dtype=type_as.dtype), type_as) def set_gradients(self, val, inputs, grads): from jax.flatten_util import ravel_pytree @@ -1730,6 +1782,18 @@ def outer(self, a, b): def clip(self, a, a_min=None, a_max=None): return jnp.clip(a, a_min, a_max) + def real(self, a): + return jnp.real(a) + + def imag(self, a): + return jnp.imag(a) + + def conj(self, a): + return jnp.conj(a) + + def arccos(self, a): + return jnp.arccos(a) + def repeat(self, a, repeats, axis=None): return jnp.repeat(a, repeats, axis) @@ -1803,7 +1867,9 @@ def randperm(self, size, type_as=None): if not isinstance(size, int): raise ValueError("size must be an integer") if type_as is not None: - return jax.random.permutation(subkey, size).astype(type_as.dtype) + return jnp.asarray( + jax.random.permutation(subkey, size), dtype=type_as.dtype + ) else: return jax.random.permutation(subkey, size) @@ -2227,6 +2293,18 @@ def outer(self, a, b): def clip(self, a, a_min=None, a_max=None): return torch.clamp(a, a_min, a_max) + def real(self, a): + return torch.real(a) + + def imag(self, a): + return torch.imag(a) + + def conj(self, a): + return torch.conj(a) + + def arccos(self, a): + return torch.acos(a) + def repeat(self, a, repeats, axis=None): return torch.repeat_interleave(a, repeats, dim=axis) @@ -2728,6 +2806,18 @@ def outer(self, a, b): def clip(self, a, a_min=None, a_max=None): return cp.clip(a, a_min, a_max) + def real(self, a): + return cp.real(a) + + def imag(self, a): + return cp.imag(a) + + def conj(self, a): + return cp.conj(a) + + def arccos(self, a): + return cp.arccos(a) + def repeat(self, a, repeats, axis=None): return cp.repeat(a, repeats, axis) @@ -2819,7 +2909,7 @@ def randperm(self, size, type_as=None): return self.rng_.permutation(size) else: with cp.cuda.Device(type_as.device): - return self.rng_.permutation(size).astype(type_as.dtype) + return cp.asarray(self.rng_.permutation(size), dtype=type_as.dtype) def coo_matrix(self, data, rows, cols, shape=None, type_as=None): data = self.from_numpy(data) @@ -3162,6 +3252,18 @@ def outer(self, a, b): def clip(self, a, a_min=None, a_max=None): return tnp.clip(a, a_min, a_max) + def real(self, a): + return tnp.real(a) + + def imag(self, a): + return tnp.imag(a) + + def conj(self, a): + return tnp.conj(a) + + def arccos(self, a): + return tnp.arccos(a) + def repeat(self, a, repeats, axis=None): return tnp.repeat(a, repeats, axis) diff --git a/ot/sgot.py b/ot/sgot.py new file mode 100644 index 000000000..c8ee5b91c --- /dev/null +++ b/ot/sgot.py @@ -0,0 +1,428 @@ +# -*- coding: utf-8 -*- +""" +Spectral-Grassmann optimal transport for linear operators. + +This module implements the Spectral-Grassmann Wasserstein framework for +comparing dynamical systems via their learned operator representations. + +It provides tools to extract spectral "atoms" (eigenvalues and associated +eigenspaces) from linear operators and to compute an optimal transport metric +that combines a spectral term on eigenvalues with a Grassmannian term on +eigenspaces. +""" + +# Author: Sienna O'Shea +# Thibaut Germain +# License: MIT License + +import ot +from ot.backend import get_backend + +##################################################################################################################################### +##################################################################################################################################### +### NORMALISATION AND OPERATOR ATOMS ### +##################################################################################################################################### +##################################################################################################################################### + + +def eigenvalue_cost_matrix(Ds, Dt, q=1, eigen_scaling=None, nx=None): + """Compute pairwise eigenvalue distances for source and target domains. + + Parameters + ---------- + Ds: array-like, shape (n_s,) + Source eigenvalues. + Dt: array-like, shape (n_t,) + Target eigenvalues. + eigen_scaling: None or array-like of length 2, optional + Scaling (real_scale, imag_scale) applied to eigenvalues before computing + distances. If None, defaults to (1.0, 1.0). Accepts tuple/list or + array/tensor with two entries. + + Returns + ---------- + C: np.ndarray, shape (n_s, n_t) + Eigenvalue cost matrix. + """ + if nx is None: + nx = get_backend(Ds, Dt) + + if eigen_scaling is None: + real_scale, imag_scale = 1.0, 1.0 + else: + if isinstance(eigen_scaling, (tuple, list)): + real_scale, imag_scale = eigen_scaling + else: + real_scale, imag_scale = eigen_scaling[0], eigen_scaling[1] + + Dsn = nx.real(Ds) * real_scale + 1j * nx.imag(Ds) * imag_scale + Dtn = nx.real(Dt) * real_scale + 1j * nx.imag(Dt) * imag_scale + prod = Dsn[:, None] - Dtn[None, :] + prod = nx.real(prod * nx.conj(prod)) + return prod ** (q / 2) + + +def _normalize_columns(A, nx, eps=1e-12): + """Normalize the columns of an array with a backend-aware norm. + + Parameters + ---------- + A: array-like, shape (d, n) + Input array whose columns are normalized. + nx: module + Backend (NumPy-compatible) used for math operations. + eps: float, optional + Minimum norm value to avoid division by zero, default 1e-12. + + Returns + ---------- + A_norm: array-like, shape (d, n) + Column-normalized array. + """ + nrm = nx.sqrt(nx.sum(A * nx.conj(A), axis=0, keepdims=True)) + nrm = nx.real(nrm) # norm is real; avoid complex dtype for maximum (e.g. torch) + nrm = nx.maximum(nrm, eps) + return A / nrm + + +def _delta_matrix_1d(Rs, Ls, Rt, Lt, nx=None, eps=1e-12): + """Compute the normalized inner-product delta matrix for eigenspaces. + + Parameters + ---------- + Rs: array-like, shape (L, n_s) + Source right eigenvectors. + Ls: array-like, shape (L, n_s) + Source left eigenvectors. + Rt: array-like, shape (L, n_t) + Target right eigenvectors. + Lt: array-like, shape (L, n_t) + Target left eigenvectors. + nx: module, optional + Backend (NumPy-compatible). If None, inferred from inputs. + eps: float, optional + Minimum norm value used in normalization, default 1e-12. + + Returns + ---------- + delta: array-like, shape (n_s, n_t) + Delta matrix with entries in [0, 1]. + """ + if nx is None: + nx = get_backend(Rs, Ls, Rt, Lt) + + Rsn = _normalize_columns(Rs, nx=nx, eps=eps) + Lsn = _normalize_columns(Ls, nx=nx, eps=eps) + Rtn = _normalize_columns(Rt, nx=nx, eps=eps) + Ltn = _normalize_columns(Lt, nx=nx, eps=eps) + + Cr = nx.dot(nx.conj(Rsn).T, Rtn) + Cl = nx.dot(nx.conj(Lsn).T, Ltn) + + delta = nx.abs(Cr * Cl) + delta = nx.clip(delta, 0.0, 1.0) + return delta + + +def _grassmann_distance_squared(delta, grassman_metric="chordal", nx=None, eps=1e-300): + """Compute squared Grassmannian distances from delta similarities. + + Parameters + ---------- + delta: array-like + Similarity values in [0, 1]. + grassman_metric: str, optional + Metric type: "geodesic", "chordal", "procrustes", or "martin". + nx: module, optional + Backend (NumPy-compatible). If None, inferred from inputs. + eps: float, optional + Minimum value used for numerical stability, default 1e-300. + + Returns + ---------- + dist2: array-like + Squared Grassmannian distance(s). + """ + if nx is None: + nx = get_backend(delta) + + delta = nx.clip(delta, 0.0, 1.0) + + if grassman_metric == "geodesic": + return nx.arccos(delta) ** 2 + if grassman_metric == "chordal": + return 1.0 - delta**2 + if grassman_metric == "procrustes": + return 2.0 * (1.0 - delta) + if grassman_metric == "martin": + # Martin-type Grassmann metric: -log(delta^2) with lower clamp at eps. + # We deliberately avoid any upper threshold to stay close to the + # information-geometric interpretation in Germain et al. (2025). + delta2 = nx.maximum(delta**2, eps) + return -nx.log(delta2) + raise ValueError(f"Unknown grassman_metric: {grassman_metric}") + + +##################################################################################################################################### +##################################################################################################################################### +### SPECTRAL-GRASSMANNIAN WASSERSTEIN METRIC ### +##################################################################################################################################### +##################################################################################################################################### +def sgot_cost_matrix( + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, + eta=0.5, + p=2, + q=1, + grassman_metric="chordal", + eigen_scaling=None, + nx=None, +): + r"""Compute the SGOT cost matrix between two spectral decompositions. + + This returns the discrete ground cost matrix used in the SGOT optimal transport + objective. Each spectral atom is :math:`z_i=(\lambda_i, V_i)` where + :math:`\lambda_i \in \mathbb{C}` is an eigenvalue and :math:`V_i` is the + associated (bi-orthogonal) eigenspace point. + + .. math:: + C_2(i,j) \;=\; \eta\,C_\lambda(i,j) \;+\; (1-\eta)\,C_G(i,j), + + with spectral term + + .. math:: + C_\lambda(i,j) \;=\; \big|\lambda_i - \lambda'_j\big|^{q}, + + and Grassmann term computed from a similarity score :math:`\delta_{ij}\in[0,1]` + built from left/right eigenvectors + + .. math:: + \delta_{ij} \;=\; \left|\langle r_i, r'_j\rangle\,\langle \ell_i, \ell'_j\rangle\right|. + + Depending on ``grassman_metric``, the Grassmann contribution is: + + - ``"chordal"``: + .. math:: + C_G(i,j) \;=\; 1 - \delta_{ij}^2 + - ``"geodesic"``: + .. math:: + C_G(i,j) \;=\; \arccos(\delta_{ij})^2 + - ``"procrustes"``: + .. math:: + C_G(i,j) \;=\; 2(1-\delta_{ij}) + - ``"martin"``: + .. math:: + C_G(i,j) \;=\; -\log\!\left(\max(\delta_{ij}^2,\varepsilon)\right) + + Finally, we return a matrix suited for a :math:`p`-Wasserstein objective by + treating :math:`C_2 \approx d^2` and outputting + + .. math:: + C(i,j) \;=\; \big(\operatorname{Re}(C_2(i,j))\big)^{p/2}. + + Parameters + ---------- + Ds: array-like, shape (n_s,) + Eigenvalues of operator T1. + Rs: array-like, shape (L, n_s) + Right eigenvectors of operator T1. + Ls: array-like, shape (L, n_s) + Left eigenvectors of operator T1. + Dt: array-like, shape (n_t,) + Eigenvalues of operator T2. + Rt: array-like, shape (L, n_t) + Right eigenvectors of operator T2. + Lt: array-like, shape (L, n_t) + Left eigenvectors of operator T2. + eta: float, optional + Weighting between spectral and Grassmann terms, default 0.5. + p: int, optional + Exponent defining the OT ground cost. The returned cost is :math:`d^p` with + :math:`d^2 \approx C_2`. Default is 2. + q: int, optional + Exponent applied to the eigenvalue distance in the spectral term. + Default is 1. + grassman_metric: str, optional + Metric type: "geodesic", "chordal", "procrustes", or "martin". + eigen_scaling: None or array-like of length 2, optional + Scaling ``(real_scale, imag_scale)`` applied to eigenvalues before computing + :math:`C_\lambda`. If provided, eigenvalues are transformed as + :math:`\lambda \mapsto \alpha\operatorname{Re}(\lambda) + i\,\beta\operatorname{Im}(\lambda)`. + If None, defaults to ``(1.0, 1.0)``. Accepts tuple/list or array/tensor with + two entries. + nx: module, optional + Backend (NumPy-compatible). If None, inferred from inputs. + + Returns + ---------- + C: array-like, shape (n_s, n_t) + SGOT cost matrix :math:`C = d^p`. + + References + ---------- + Germain et al., *Spectral-Grassmann Optimal Transport* (SGOT). + """ + if nx is None: + nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) + + if Ds.ndim != 1 or Dt.ndim != 1: + raise ValueError( + f"sgot_cost_matrix() expects Ds, Dt 1D; " + f"got Ds {getattr(Ds, 'shape', None)}, Dt {getattr(Dt, 'shape', None)}" + ) + + if Rs.shape != Ls.shape or Rt.shape != Lt.shape: + raise ValueError( + "Right/left eigenvector shapes must match; got " + f"(Rs,Ls)=({Rs.shape},{Ls.shape}), (Rt,Lt)=({Rt.shape},{Lt.shape})" + ) + + if Rs.shape[1] != Ds.shape[0] or Rt.shape[1] != Dt.shape[0]: + raise ValueError( + "Eigenvector columns must match eigenvalues: " + f"Rs {Rs.shape[1]} vs Ds {Ds.shape[0]}, " + f"Rt {Rt.shape[1]} vs Dt {Dt.shape[0]}" + ) + + C_lambda = eigenvalue_cost_matrix(Ds, Dt, q=q, eigen_scaling=eigen_scaling, nx=nx) + + delta = _delta_matrix_1d(Rs, Ls, Rt, Lt, nx=nx) + C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) + + C2 = eta * C_lambda + (1.0 - eta) * C_grass + C = nx.real(C2) ** (p / 2.0) + + return C + + +def _validate_sgot_metric_inputs(Ds, Dt): + """Validate that eigenvalue inputs for SGOT metric are 1D.""" + Ds_shape = getattr(Ds, "shape", None) + Dt_shape = getattr(Dt, "shape", None) + Ds_ndim = getattr(Ds, "ndim", None) + Dt_ndim = getattr(Dt, "ndim", None) + if Ds_ndim != 1 or Dt_ndim != 1: + raise ValueError( + "sgot_metric() expects Ds and Dt to be 1D (n,), " + f"got Ds shape {Ds_shape} and Dt shape {Dt_shape}" + ) + + +def sgot_metric( + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, + eta=0.5, + p=2, + q=1, + r=2, + grassman_metric="chordal", + eigen_scaling=None, + Ws=None, + Wt=None, + nx=None, +): + r"""Compute the SGOT metric between two spectral decompositions. + + This function computes a discrete optimal transport problem between two measures + over spectral atoms :math:`z_i=(\lambda_i, V_i)` and :math:`z'_j=(\lambda'_j, V'_j)`. + Using the ground cost matrix :math:`C=d^p` returned by :func:`sgot_cost_matrix`, + we solve: + + .. math:: + P^\star \in \arg\min_{P\in\Pi(W_s, W_t)} \langle C, P\rangle, + + and compute the associated :math:`p`-Wasserstein objective: + + .. math:: + \mathrm{obj} \;=\; \left(\sum_{i,j} C(i,j)\,P^\star_{ij}\right)^{1/p}. + + This implementation returns an additional outer root: + + .. math:: + \mathrm{SGOT} \;=\; \mathrm{obj}^{1/r}. + + Parameters + ---------- + Ds: array-like, shape (n_s,) + Eigenvalues of operator T1. + Rs: array-like, shape (L, n_s) + Right eigenvectors of operator T1. + Ls: array-like, shape (L, n_s) + Left eigenvectors of operator T1. + Dt: array-like, shape (n_t,) + Eigenvalues of operator T2. + Rt: array-like, shape (L, n_t) + Right eigenvectors of operator T2. + Lt: array-like, shape (L, n_t) + Left eigenvectors of operator T2. + eta: float, optional + Weighting between spectral and Grassmann terms, default 0.5. + p: int, optional + Exponent defining the OT ground cost and Wasserstein order. The cost matrix + is :math:`d^p` and the OT objective is raised to the power :math:`1/p`. + Default is 2. + q: int, optional + Exponent applied to the eigenvalue distance in the spectral term. + Default is 1. + r: int, optional + Outer root applied to the Wasserstein objective. Default is 2. + grassman_metric: str, optional + Metric type: "geodesic", "chordal", "procrustes", or "martin". + eigen_scaling: None or array-like of length 2, optional + Scaling ``(real_scale, imag_scale)`` applied to eigenvalues before computing + the spectral part of the cost. If None, defaults to ``(1.0, 1.0)``. + Ws: array-like, shape (n_s,), optional + Source distribution. If None, uses a uniform distribution. + Wt: array-like, shape (n_t,), optional + Target distribution. If None, uses a uniform distribution. + nx: module, optional + Backend (NumPy-compatible). If None, inferred from inputs. + + Returns + ---------- + dist: float + SGOT metric value. + + References + ---------- + Germain et al., *Spectral-Grassmann Optimal Transport* (SGOT). + """ + if nx is None: + nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) + + _validate_sgot_metric_inputs(Ds, Dt) + + C = sgot_cost_matrix( + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, + eta=eta, + p=p, + q=q, + grassman_metric=grassman_metric, + eigen_scaling=eigen_scaling, + nx=nx, + ) + + if Ws is None: + Ws = nx.ones((C.shape[0],), type_as=C) / float(C.shape[0]) + if Wt is None: + Wt = nx.ones((C.shape[1],), type_as=C) / float(C.shape[1]) + + Ws = Ws / nx.sum(Ws) + Wt = Wt / nx.sum(Wt) + + obj = ot.emd2(Ws, Wt, nx.real(C)) + obj = obj ** (1.0 / p) + return obj ** (1.0 / float(r)) diff --git a/test/test_backend.py b/test/test_backend.py index cd6a85762..fe6af9c67 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -339,6 +339,9 @@ def test_func_backends(nx): sp_col = np.array([0, 3, 1, 2, 2]) sp_data = np.array([4, 5, 7, 9, 0], dtype=np.float64) + M_complex = M + 1j * rnd.randn(10, 3) + v_acos = np.clip(v, -0.99, 0.99) + lst_tot = [] for nx in [ot.backend.NumpyBackend(), nx]: @@ -723,6 +726,24 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("atan2") + M_complex_b = nx.from_numpy(M_complex) + A = nx.real(M_complex_b) + lst_b.append(nx.to_numpy(A)) + lst_name.append("real") + + A = nx.imag(M_complex_b) + lst_b.append(nx.to_numpy(A)) + lst_name.append("imag") + + A = nx.conj(M_complex_b) + lst_b.append(nx.to_numpy(A)) + lst_name.append("conj") + + v_acos_b = nx.from_numpy(v_acos) + A = nx.arccos(v_acos_b) + lst_b.append(nx.to_numpy(A)) + lst_name.append("arccos") + A = nx.transpose(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append("transpose") diff --git a/test/test_sgot.py b/test/test_sgot.py new file mode 100644 index 000000000..af64e20fb --- /dev/null +++ b/test/test_sgot.py @@ -0,0 +1,264 @@ +"""Tests for ot.sgot module""" + +# Author: Sienna O'Shea +# Thibaut Germain +# License: MIT License + +import numpy as np +import pytest + +from ot.sgot import ( + eigenvalue_cost_matrix, + _delta_matrix_1d, + _grassmann_distance_squared, + sgot_cost_matrix, + sgot_metric, +) + + +def random_atoms(d=8, r=4, seed=42): + """Deterministic complex atoms for given d, r.""" + + def _rand_complex(shape, seed_): + rng = np.random.RandomState(seed_) + real = rng.randn(*shape) + imag = rng.randn(*shape) + return real + 1j * imag + + Ds = _rand_complex((r,), seed + 0) + Rs = _rand_complex((d, r), seed + 1) + Ls = _rand_complex((d, r), seed + 2) + Dt = _rand_complex((r,), seed + 3) + Rt = _rand_complex((d, r), seed + 4) + Lt = _rand_complex((d, r), seed + 5) + + return Ds, Rs, Ls, Dt, Rt, Lt + + +# --------------------------------------------------------------------- +# DATA / SAMPLING TESTS +# --------------------------------------------------------------------- + + +def test_random_d_r(nx): + """Sample d and r uniformly and run sgot_cost_matrix (and sgot_metric when available) with those shapes.""" + rng = np.random.RandomState(0) + d_min, d_max = 4, 12 + r_min, r_max = 2, 6 + for _ in range(5): + d = int(rng.randint(d_min, d_max + 1)) + r = int(rng.randint(r_min, r_max + 1)) + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms(d=d, r=r) + Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) + C = sgot_cost_matrix(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b) + C_np = nx.to_numpy(C) + np.testing.assert_allclose(C_np.shape, (r, r)) + assert np.all(np.isfinite(C_np)) and np.all(C_np >= 0) + try: + dist = sgot_metric(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b) + dist_np = nx.to_numpy(dist) + assert np.isfinite(dist_np) and dist_np >= 0 + except TypeError: + pytest.skip("sgot_metric() unavailable (emd_c signature mismatch)") + + +# --------------------------------------------------------------------- +# DELTA MATRIX TESTS +# --------------------------------------------------------------------- + + +def test_eigenvalue_cost_matrix_simple(): + Ds = np.array([0.0, 1.0]) + Dt = np.array([0.0, 2.0]) + C = eigenvalue_cost_matrix(Ds, Dt, q=2) + expected = np.array([[0.0, 4.0], [1.0, 1.0]]) + np.testing.assert_allclose(C, expected) + + +def test_delta_matrix_1d_identity(): + r = 4 + I = np.eye(r, dtype=complex) + delta = _delta_matrix_1d(I, I, I, I) + np.testing.assert_allclose(delta, np.eye(r), atol=1e-12) + + +def test_delta_matrix_1d_swap_invariance(): + d, r = 6, 3 + _, R, _, _, _, _ = random_atoms(d=d, r=r) + L = R.copy() + delta1 = _delta_matrix_1d(R, L, R, L) + delta2 = _delta_matrix_1d(L, R, L, R) + np.testing.assert_allclose(delta1, delta2, atol=1e-12) + + +# --------------------------------------------------------------------- +# GRASSMANN DISTANCE TESTS +# --------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "grassman_metric", ["geodesic", "chordal", "procrustes", "martin"] +) +def test_grassmann_zero_distance(grassman_metric, nx): + delta = nx.from_numpy(np.ones((3, 3))) + dist2 = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) + dist2_np = nx.to_numpy(dist2) + np.testing.assert_allclose(dist2_np, 0.0, atol=1e-12) + + +def test_grassmann_distance_invalid_name(): + delta = np.ones((2, 2)) + with pytest.raises(ValueError): + _grassmann_distance_squared(delta, grassman_metric="cordal") + + +# --------------------------------------------------------------------- +# COST TESTS +# --------------------------------------------------------------------- + + +def test_cost_self_zero(nx): + """(D_S R_S L_S D_S): diagonal of sgot_cost_matrix matrix (same atom to same atom) should be near zero.""" + Ds, Rs, Ls, _, _, _ = random_atoms() + Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2 = nx.from_numpy(Ds, Rs, Ls, Ds, Rs, Ls) + C = sgot_cost_matrix(Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2) + C_np = nx.to_numpy(C) + np.testing.assert_allclose(np.diag(C_np), np.zeros(C_np.shape[0]), atol=1e-10) + np.testing.assert_allclose(C_np, C_np.T, atol=1e-10) + + +def test_grassmann_cost_reference(nx): + """Cost with same inputs and HPs should be deterministic (np.testing.assert_allclose).""" + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) + eta, p, q = 0.5, 2, 1 + C1 = sgot_cost_matrix(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b, eta=eta, p=p, q=q) + C2 = sgot_cost_matrix(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b, eta=eta, p=p, q=q) + np.testing.assert_allclose(nx.to_numpy(C1), nx.to_numpy(C2), atol=1e-12) + + +@pytest.mark.parametrize( + "grassman_metric", ["geodesic", "chordal", "procrustes", "martin"] +) +def test_grassmann_cost_basic_properties(grassman_metric, nx): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) + C = sgot_cost_matrix( + Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b, grassman_metric=grassman_metric + ) + C_np = nx.to_numpy(C) + assert C_np.shape == (Ds.shape[0], Dt.shape[0]) + assert np.all(np.isfinite(C_np)) + assert np.all(C_np >= 0) + + +def test_sgot_cost_input_validation(): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + + with pytest.raises(ValueError): + sgot_cost_matrix(Ds.reshape(-1, 1), Rs, Ls, Dt, Rt, Lt) + + with pytest.raises(ValueError): + sgot_cost_matrix(Ds, Rs[:, :-1], Ls, Dt, Rt, Lt) + + +# --------------------------------------------------------------------- +# METRIC TESTS +# --------------------------------------------------------------------- + + +def test_sgot_metric_self_zero(nx): + Ds, Rs, Ls, _, _, _ = random_atoms() + Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2 = nx.from_numpy(Ds, Rs, Ls, Ds, Rs, Ls) + dist = sgot_metric(Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2, nx=nx) + dist_np = nx.to_numpy(dist) + assert np.isfinite(dist_np) + assert abs(float(dist_np)) < 5e-4 + + +def test_sgot_metric_symmetry(): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + d1 = sgot_metric(Ds, Rs, Ls, Dt, Rt, Lt) + d2 = sgot_metric(Dt, Rt, Lt, Ds, Rs, Ls) + np.testing.assert_allclose(d1, d2, atol=1e-8) + + +def test_sgot_metric_with_weights(): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + r = Ds.shape[0] + + rng = np.random.RandomState(1) + Ws = rng.rand(r) + Ws = Ws / np.sum(Ws) + + Wt = rng.rand(r) + Wt = Wt / np.sum(Wt) + + dist = sgot_metric(Ds, Rs, Ls, Dt, Rt, Lt, Ws=Ws, Wt=Wt) + assert np.isfinite(dist) + + +# --------------------------------------------------------------------- +# HYPERPARAMETER SWEEP TEST +# --------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "eta, p, q, grassman_metric", + [ + (0.5, 1, 1, "geodesic"), + (0.5, 2, 1, "chordal"), + (0.3, 2, 2, "procrustes"), + (0.7, 1, 2, "martin"), + ], +) +def test_hyperparameter_sweep_cost(nx, eta, p, q, grassman_metric): + """Sweep over a set of fixed HPs and run cost().""" + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) + + C = sgot_cost_matrix( + Ds_b, + Rs_b, + Ls_b, + Dt_b, + Rt_b, + Lt_b, + eta=eta, + p=p, + q=q, + grassman_metric=grassman_metric, + ) + C_np = nx.to_numpy(C) + assert C_np.shape == (Ds.shape[0], Dt.shape[0]) + assert np.all(np.isfinite(C_np)) + assert np.all(C_np >= 0) + + +@pytest.mark.parametrize( + "grassman_metric", ["geodesic", "chordal", "procrustes", "martin"] +) +def test_hyperparameter_sweep(grassman_metric): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + rng = np.random.RandomState(3) + eta = rng.uniform(0.0, 1.0) + p = rng.choice([1, 2]) + q = rng.choice([1, 2]) + r = rng.choice([1, 2]) + + dist = sgot_metric( + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, + eta=eta, + p=p, + q=q, + r=r, + grassman_metric=grassman_metric, + ) + + assert np.isfinite(dist) + assert dist >= 0