From fd7e2b4728aeb777259d20bf6f4e4bc7e8f036ce Mon Sep 17 00:00:00 2001
From: thibaut-germain
Date: Mon, 9 Feb 2026 14:29:44 +0100
Subject: [PATCH 01/10] Sienna & Thibaut works

---
 RELEASES.md | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/RELEASES.md b/RELEASES.md
index 1660913cd..33cc56992 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -1,5 +1,13 @@
 # Releases
+
+## Upcoming 0.9.7.post1
+
+#### New features
+The next release will add cost functions between linear operators following [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920).
+
+
+
 ## 0.9.7.dev0
 This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation.

From 9b49c768b1a0415ff6a020cbb9eeaf338056e536 Mon Sep 17 00:00:00 2001
From: thibaut-germain
Date: Mon, 9 Feb 2026 15:05:52 +0100
Subject: [PATCH 02/10] add sgot file

---
 ot/sgot.py | 8 ++++++++
 1 file changed, 8 insertions(+)
 create mode 100644 ot/sgot.py

diff --git a/ot/sgot.py b/ot/sgot.py
new file mode 100644
index 000000000..5811827da
--- /dev/null
+++ b/ot/sgot.py
@@ -0,0 +1,8 @@
+# -*- coding: utf-8 -*-
+"""
+Optimal transport for linear operators.
+"""
+
+# Author: Sienna O'Shea
+#         Thibaut Germain
+# License: MIT License

From ba3303287f4efc3def0395b5399e45c14928c070 Mon Sep 17 00:00:00 2001
From: Sienna O'Shea
Date: Sun, 15 Feb 2026 21:52:31 +0100
Subject: [PATCH 03/10] first draft of sgot.py

---
 ot/sgot.py | 911 ++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 910 insertions(+), 1 deletion(-)

diff --git a/ot/sgot.py b/ot/sgot.py
index 5811827da..929078f7e 100644
--- a/ot/sgot.py
+++ b/ot/sgot.py
@@ -1,8 +1,917 @@
 # -*- coding: utf-8 -*-
 """
-Optimal transport for linear operators.
+Spectral-Grassmann optimal transport for linear operators. + +This module implements the Spectral-Grassmann Wasserstein framework for +comparing dynamical systems via their learned operator representations. + +It provides tools to extract spectral "atoms" (eigenvalues and associated +eigenspaces) from linear operators and to compute an optimal transport metric +that combines a spectral term on eigenvalues with a Grassmannian term on +eigenspaces. """ # Author: Sienna O'Shea # Thibaut Germain # License: MIT License + +import numpy as np +import ot +from sklearn.utils.extmath import randomized_svd +from ot.backend import get_backend + +### +# Settings : (Ds,Rs,Ls) if primal, (Ds,Xs,prs,pls) if dual +### + +##################################################################################################################################### +##################################################################################################################################### +### PRINCIPAL ANGLE METRICS ### +##################################################################################################################################### +##################################################################################################################################### + + +def hs_metric(Ds, Rs, Ls, Dt, Rt, Lt, sampfreqs: int = 1, sampfreqt: int = 1): + """Compute the Hilbert-Schmidt (Frobenius) distance between two operators. + + Parameters + ---------- + Ds: array-like, shape (n_s,) + Source eigenvalues. + Rs: array-like, shape (L, n_s) + Source right eigenvectors. + Ls: array-like, shape (L, n_s) + Source left eigenvectors. + Dt: array-like, shape (n_t,) + Target eigenvalues. + Rt: array-like, shape (L, n_t) + Target right eigenvectors. + Lt: array-like, shape (L, n_t) + Target left eigenvectors. 
+ sampfreqs: int, optional, sampling frequency for the source operator with default 1 + sampfreqt: int, optional, sampling frequency for the target operator with default 1 + + Returns + ---------- + dist: float, Frobenius norm + """ + Ts = Rs @ (np.exp(Ds / sampfreqs).reshape(-1, 1) * Ls.conj().T) + Tt = Rt @ (np.exp(Dt / sampfreqt).reshape(-1, 1) * Lt.conj().T) + C = Ts - Tt + return np.linalg.norm(C, "fro") + + +def operator_metric( + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, + sampfreqs: int = 1, + sampfreqt: int = 1, + exact: bool = False, + n_iter: int = 5, + random_state: int = None, +): + """Compute the spectral norm distance between two reconstructed operators. + + Parameters + ---------- + Ds: array-like, shape (n_s,) + Source eigenvalues. + Rs: array-like, shape (L, n_s) + Source right eigenvectors. + Ls: array-like, shape (L, n_s) + Source left eigenvectors. + Dt: array-like, shape (n_t,) + Target eigenvalues. + Rt: array-like, shape (L, n_t) + Target right eigenvectors. + Lt: array-like, shape (L, n_t) + Target left eigenvectors. + sampfreqs: int, optional + Sampling frequency for the source operator with default 1 + sampfreqt: int, optional + Sampling frequency for the target operator with default 1 + exact: bool, optional + n_iter: int, optional + random_state: int or None, optional + + Returns + ---------- + dist: float + """ + Ts = Rs @ (np.exp(Ds / sampfreqs).reshape(-1, 1) * Ls.conj().T) + Tt = Rt @ (np.exp(Dt / sampfreqt).reshape(-1, 1) * Lt.conj().T) + C = Ts - Tt + if exact: + return np.linalg.norm(C, 2) + else: + _, S, _ = randomized_svd( + C.real, n_components=1, n_iter=n_iter, random_state=random_state + ) + return S[0] + + +def principal_angles_via_svd(A, B): + """Compute principal angles between two subspaces using SVD of QA^T QB. 
+ + Parameters + A: array-like, shape (d, p) whose columns span the first subspace + B: array-like, shape (d, q) whose columns span the second subspace + + Returns + angle: sorted principal angles (in radians), shape (min(p, q),) + """ + QA, _ = np.linalg.qr(A, mode="reduced") + QB, _ = np.linalg.qr(B, mode="reduced") + C = QA.T @ QB + # SVD of small matrix C + _, S, _ = np.linalg.svd(C, full_matrices=False) + S = np.clip(S, -1.0, 1.0) + angles = np.arccos(S) + return np.sort(angles) + + +def principal_angles_distance( + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, +): + """Compute a principal angles distance between two spectral decompositions. + + Parameters + ---------- + Ds: array-like, shape (n_s,) + Source eigenvalues. + Rs: array-like, shape (L, n_s) + Source right eigenvectors. + Ls: array-like, shape (L, n_s) + Source left eigenvectors. + Dt: array-like, shape (n_t,) + Target eigenvalues. + Rt: array-like, shape (L, n_t) + Target right eigenvectors. + Lt: array-like, shape (L, n_t) + Target left eigenvectors. + + Returns + ------- + dist: float + Principal-angles distance between the two decompositions. 
+ """ + ns = Rs.shape[1] + nt = Rt.shape[1] + Ms = np.vstack( + [(ls[:, None] * rs.conj()[None, :]).flatten() for ls, rs in zip(Ls.T, Rs.T)] + ).T + Mt = np.vstack( + [(lt[:, None] * rt.conj()[None, :]).flatten() for lt, rt in zip(Lt.T, Rt.T)] + ).T + angles = principal_angles_via_svd(Ms, Mt) + if angles.shape[0] != max(ns, nt): + angles = np.hstack([angles, np.pi / 2 * np.ones(max(ns, nt) - angles.shape[0])]) + return np.sqrt(np.sum(angles**2)) + + +##################################################################################################################################### +##################################################################################################################################### +### OT METRIC ### +##################################################################################################################################### +##################################################################################################################################### + + +def principal_grassman_matrix(Ps, Pt, eps: float = 1e-12): + """Compute the unitary Grassmann matrix for source and target domains. + + Parameters + ---------- + Ps : array-like, shape (l, n_ds) + Source domain data, with columns spanning the source subspace. + Pt : array-like, shape (l, n_dt) + Target domain data, with columns spanning the target subspace. + eps : float, optional + Minimum column norm used to avoid division by zero. Default is 1e-12. + + Returns + ------- + C : np.ndarray, shape (n_ds, n_dt) + Grassmann matrix between source and target subspaces. + """ + ns = np.linalg.norm(Ps, axis=0, keepdims=True) + nt = np.linalg.norm(Pt, axis=0, keepdims=True) + ns = np.maximum(ns, eps) + nt = np.maximum(nt, eps) + + Psn = Ps / ns + Ptn = Pt / nt + + C = Psn.conj().T @ Ptn + return C + + +def eigenvector_chordal_cost_matrix(Rs, Ls, Rt, Lt): + """Compute pairwise Grassmann matrices for source and target domains. 
+ + Parameters + ---------- + Rs: array-like, shape (L, n_s) + Source right eigenvectors. + Ls: array-like, shape (L, n_s) + Source left eigenvectors. + Rt: array-like, shape (L, n_t) + Target right eigenvectors. + Lt: array-like, shape (L, n_t) + Target left eigenvectors. + + Returns + ---------- + C: np.ndarray, shape (n_s, n_t) + Eigenvector chordal cost matrix. + """ + Cr = principal_grassman_matrix(Rs, Rt) + Cl = principal_grassman_matrix(Ls, Lt) + C = np.sqrt(1 - np.clip((Cr * Cl).real, a_min=0, a_max=1)) + return C + + +def eigenvalue_cost_matrix(Ds, Dt, real_scale: float = 1.0, imag_scale: float = 1.0): + """Compute pairwise eigenvalue distances for source and target domains. + + Parameters + ---------- + Ds: array-like, shape (n_s,) + Source eigenvalues. + Dt: array-like, shape (n_t,) + Target eigenvalues. + real_scale: float, optional + Scale factor for real parts, default 1.0. + imag_scale: float, optional + Scale factor for imaginary parts, default 1.0. + + Returns + ---------- + C: np.ndarray, shape (n_s, n_t) + Eigenvalue cost matrix. + """ + Dsn = Ds.real * real_scale + 1j * Ds.imag * imag_scale + Dtn = Dt.real * real_scale + 1j * Dt.imag * imag_scale + C = np.abs(Dsn[:, None] - Dtn[None, :]) + return C + + +def ChordalCostFunction( + real_scale: float = 1.0, imag_scale: float = 1.0, alpha: float = 0.5, p: int = 2 +): + """Generate the chordal cost function. + + Parameters + ---------- + real_scale: float, optional + Scale factor for real parts, default 1.0. + imag_scale: float, optional + Scale factor for imaginary parts, default 1.0. + alpha: float, optional + Weighting factor for the eigenvalue cost, default 0.5. + p: int, optional + Power for the chordal distance, default 2. + + Returns + ---------- + cost_function: Chordal cost function. + """ + + def cost_function(Ds, Rs, Ls, Dt, Rt, Lt) -> np.ndarray: + """Compute the chordal cost matrix between source and target spectral decompositions. 
+ + Parameters + ---------- + Ds: array-like, shape (n_s,) + Source eigenvalues. + Rs: array-like, shape (L, n_s) + Source right eigenvectors. + Ls: array-like, shape (L, n_s) + Source left eigenvectors. + Dt: array-like, shape (n_t,) + Target eigenvalues. + Rt: array-like, shape (L, n_t) + Target right eigenvectors. + Lt: array-like, shape (L, n_t) + Target left eigenvectors. + + Returns + ---------- + C: np.ndarray, shape (n_s, n_t) + Chordal cost matrix. + """ + CD = eigenvalue_cost_matrix( + Ds, Dt, real_scale=real_scale, imag_scale=imag_scale + ) + CC = eigenvector_chordal_cost_matrix(Rs, Ls, Rt, Lt) + C = alpha * CD + (1 - alpha) * CC + return C**p + + return cost_function + + +def ot_plan(C, Ws=None, Wt=None): + """Compute the optimal transport plan for a given cost matrix and marginals. + + Parameters + ---------- + C: array-like, shape (n, m) + Cost matrix. + Ws: array-like, shape (n,), optional + Source distribution. If None, uses a uniform distribution. + Wt: array-like, shape (m,), optional + Target distribution. If None, uses a uniform distribution. + + Returns + ---------- + P: np.ndarray, shape (n, m) + Optimal transport plan. + """ + if Ws is None: + Ws = np.ones(C.shape[0]) / C.shape[0] + if Wt is None: + Wt = np.ones(C.shape[1]) / C.shape[1] + return ot.emd(Ws, Wt, C) + + +def ot_score(C, P, p: int = 2) -> float: + """Compute the OT score (distance) given a cost matrix and a transport plan. + + Parameters + ---------- + C: array-like, shape (n, m) + Cost matrix. + P: array-like, shape (n, m) + Transport plan. + p: int, optional + Power for the OT score, default 2. + + Returns + ---------- + dist: float + OT score (distance). + """ + return np.sum(C * P) ** (1 / p) + + +def chordal_metric( + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, + real_scale: float = 1.0, + imag_scale: float = 1.0, + alpha: float = 0.5, + p: int = 2, +): + """Compute the chordal OT metric between two spectral decompositions. 
+ + Parameters + ---------- + Ds: array-like, shape (n_s,) + Source eigenvalues. + Rs: array-like, shape (L, n_s) + Source right eigenvectors. + Ls: array-like, shape (L, n_s) + Source left eigenvectors. + Dt: array-like, shape (n_t,) + Target eigenvalues. + Rt: array-like, shape (L, n_t) + Target right eigenvectors. + Lt: array-like, shape (L, n_t) + Target left eigenvectors. + real_scale: float, optional + Scale factor for real parts, default 1.0. + imag_scale: float, optional + Scale factor for imaginary parts, default 1.0. + alpha: float, optional + Weighting factor for the eigenvalue cost, default 0.5. + p: int, optional + Power for the chordal distance, default 2. + + Returns + ---------- + dist: float + Chordal OT metric value. + """ + cost_fn = ChordalCostFunction(real_scale, imag_scale, alpha, p) + C = cost_fn(Ds, Rs, Ls, Dt, Rt, Lt) + P = ot_plan(C) + return ot_score(C, P, p) + + +##################################################################################################################################### +##################################################################################################################################### +### NORMALISATION AND OPERATOR ATOMS ### +##################################################################################################################################### +##################################################################################################################################### + + +def _normalize_columns(A, nx, eps=1e-12): + """Normalize the columns of an array with a backend-aware norm. + + Parameters + ---------- + A: array-like, shape (d, n) + Input array whose columns are normalized. + nx: module + Backend (NumPy-compatible) used for math operations. + eps: float, optional + Minimum norm value to avoid division by zero, default 1e-12. + + Returns + ---------- + A_norm: array-like, shape (d, n) + Column-normalized array. 
+    """
+    nrm = nx.sqrt(nx.sum(A * nx.conj(A), axis=0, keepdims=True))
+    nrm = nx.maximum(nrm, eps)
+    return A / nrm
+
+
+def _delta_matrix_1d_hs(Rs, Ls, Rt, Lt, nx=None, eps=1e-12):
+    """Compute the normalized inner-product delta matrix for eigenspaces.
+
+    Parameters
+    ----------
+    Rs: array-like, shape (L, n_s)
+        Source right eigenvectors.
+    Ls: array-like, shape (L, n_s)
+        Source left eigenvectors.
+    Rt: array-like, shape (L, n_t)
+        Target right eigenvectors.
+    Lt: array-like, shape (L, n_t)
+        Target left eigenvectors.
+    nx: module, optional
+        Backend (NumPy-compatible). If None, inferred from inputs.
+    eps: float, optional
+        Minimum norm value used in normalization, default 1e-12.
+
+    Returns
+    ----------
+    delta: array-like, shape (n_s, n_t)
+        Delta matrix with entries in [0, 1].
+    """
+    if nx is None:
+        nx = get_backend(Rs, Ls, Rt, Lt)
+
+    Rs = nx.asarray(Rs)
+    Ls = nx.asarray(Ls)
+    Rt = nx.asarray(Rt)
+    Lt = nx.asarray(Lt)
+
+    Rsn = _normalize_columns(Rs, nx=nx, eps=eps)
+    Lsn = _normalize_columns(Ls, nx=nx, eps=eps)
+    Rtn = _normalize_columns(Rt, nx=nx, eps=eps)
+    Ltn = _normalize_columns(Lt, nx=nx, eps=eps)
+
+    Cr = nx.dot(nx.conj(Rsn).T, Rtn)
+    Cl = nx.dot(nx.conj(Lsn).T, Ltn)
+
+    delta = nx.abs(Cr * Cl)
+    delta = nx.minimum(nx.maximum(delta, 0.0), 1.0)
+    return delta
+
+
+def _atoms_from_operator(T, r=None, sort_mode="closest_to_1"):
+    """Extract dual eigen-atoms from a square operator.
+
+    Parameters
+    ----------
+    T: array-like, shape (d, d)
+        Input linear operator.
+    r: int, optional
+        Number of modes to keep. If None, keep all modes.
+    sort_mode: str, optional
+        Eigenvalue sorting mode: "closest_to_1", "closest_to_0", or "largest_mag".
+
+    Returns
+    ----------
+    D: np.ndarray, shape (r,)
+        Selected eigenvalues.
+    R: np.ndarray, shape (d, r)
+        Corresponding right eigenvectors.
+    L: np.ndarray, shape (d, r)
+        Dual left eigenvectors.
+ """ + T = np.asarray(T) + if T.ndim != 2 or T.shape[0] != T.shape[1]: + raise ValueError(f"T must be a square 2D array; got shape {T.shape}") + + d = T.shape[0] + if r is None: + r = d + r = int(r) + if not (1 <= r <= d): + raise ValueError(f"r must be an integer in [1, {d}], got r={r}") + + evals, evecs = np.linalg.eig(T) + + if sort_mode == "closest_to_1": + order = np.argsort(np.abs(evals - 1.0)) + elif sort_mode == "closest_to_0": + order = np.argsort(np.abs(evals)) + elif sort_mode == "largest_mag": + order = np.argsort(-np.abs(evals)) + else: + raise ValueError( + "sort_mode must be one of 'closest_to_1', 'closest_to_0', or 'largest_mag'" + ) + + idx = order[:r] + D = evals[idx] + R = evecs[:, idx] + + evalsL, evecsL = np.linalg.eig(T.conj().T) + + L = np.zeros((d, r), dtype=complex) + used = set() + + for i, lam in enumerate(D): + targets = np.abs(evalsL - np.conj(lam)) + for j in np.argsort(targets): + if j not in used: + used.add(j) + L[:, i] = evecsL[:, j] + break + + G = L.conj().T @ R + if np.linalg.matrix_rank(G) < r: + raise ValueError("Dual normalization failed: L^* R is singular.") + + L = L @ np.linalg.inv(G).conj().T + + return D, R, L + + +##################################################################################################################################### +##################################################################################################################################### +### GRASSMANNIAN METRIC ### +##################################################################################################################################### +##################################################################################################################################### + + +def _grassmann_distance_squared(delta, grassman_metric="chordal", nx=None, eps=1e-300): + """Compute squared Grassmannian distances from delta similarities. + + Parameters + ---------- + delta: array-like + Similarity values in [0, 1]. 
+ grassman_metric: str, optional + Metric type: "geodesic", "chordal", "procrustes", or "martin". + nx: module, optional + Backend (NumPy-compatible). If None, inferred from inputs. + eps: float, optional + Minimum value used for numerical stability, default 1e-300. + + Returns + ---------- + dist2: array-like + Squared Grassmannian distance(s). + """ + if nx is None: + nx = get_backend(delta) + + delta = nx.asarray(delta) + delta = nx.minimum(nx.maximum(delta, 0.0), 1.0) + + if grassman_metric == "geodesic": + return nx.arccos(delta) ** 2 + if grassman_metric == "chordal": + return 1.0 - delta**2 + if grassman_metric == "procrustes": + return 2.0 * (1.0 - delta) + if grassman_metric == "martin": + return -nx.log(nx.maximum(delta**2, eps)) + raise ValueError(f"Unknown grassman_metric: {grassman_metric}") + + +##################################################################################################################################### +##################################################################################################################################### +### SPECTRAL-GRASSMANNIAN WASSERSTEIN METRIC ### +##################################################################################################################################### +##################################################################################################################################### +def cost( + D1, + R1, + L1, + D2, + R2, + L2, + eta=0.5, + p=2, + grassman_metric="chordal", + real_scale=1.0, + imag_scale=1.0, +): + """Compute the SGOT cost matrix between two spectral decompositions. + + Parameters + ---------- + D1: array-like, shape (n_1,) or (n_1, n_1) + Eigenvalues of operator T1 (or diagonal matrix). + R1: array-like, shape (L, n_1) + Right eigenvectors of operator T1. + L1: array-like, shape (L, n_1) + Left eigenvectors of operator T1. + D2: array-like, shape (n_2,) or (n_2, n_2) + Eigenvalues of operator T2 (or diagonal matrix). 
+ R2: array-like, shape (L, n_2) + Right eigenvectors of operator T2. + L2: array-like, shape (L, n_2) + Left eigenvectors of operator T2. + eta: float, optional + Weighting between spectral and Grassmann terms, default 0.5. + p: int, optional + Power for the OT cost, default 2. + grassman_metric: str, optional + Metric type: "geodesic", "chordal", "procrustes", or "martin". + real_scale: float, optional + Scale factor for real parts, default 1.0. + imag_scale: float, optional + Scale factor for imaginary parts, default 1.0. + + Returns + ---------- + C: array-like, shape (n_1, n_2) + SGOT cost matrix. + """ + nx = get_backend(D1, R1, L1, D2, R2, L2) + + D1 = nx.asarray(D1) + D2 = nx.asarray(D2) + if len(D1.shape) == 2: + lam1 = nx.diag(D1) + else: + lam1 = D1.reshape((-1,)) + if len(D2.shape) == 2: + lam2 = nx.diag(D2) + else: + lam2 = D2.reshape((-1,)) + + lam1 = lam1.astype(complex) + lam2 = lam2.astype(complex) + + lam1s = nx.real(lam1) * real_scale + 1j * nx.imag(lam1) * imag_scale + lam2s = nx.real(lam2) * real_scale + 1j * nx.imag(lam2) * imag_scale + C_lambda = nx.abs(lam1s[:, None] - lam2s[None, :]) ** 2 + + delta = _delta_matrix_1d_hs(R1, L1, R2, L2, nx=nx) + C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) + + C2 = eta * C_lambda + (1.0 - eta) * C_grass + C = C2 ** (p / 2.0) + + return C + + +def metric( + D1, + R1, + L1, + D2, + R2, + L2, + eta=0.5, + p=2, + q=1, + grassman_metric="chordal", + real_scale=1.0, + imag_scale=1.0, + Ws=None, + Wt=None, +): + """Compute the SGOT metric between two spectral decompositions. + + Parameters + ---------- + D1: array-like, shape (n_1,) or (n_1, n_1) + Eigenvalues of operator T1 (or diagonal matrix). + R1: array-like, shape (L, n_1) + Right eigenvectors of operator T1. + L1: array-like, shape (L, n_1) + Left eigenvectors of operator T1. + D2: array-like, shape (n_2,) or (n_2, n_2) + Eigenvalues of operator T2 (or diagonal matrix). 
+ R2: array-like, shape (L, n_2) + Right eigenvectors of operator T2. + L2: array-like, shape (L, n_2) + Left eigenvectors of operator T2. + eta: float, optional + Weighting between spectral and Grassmann terms, default 0.5. + p: int, optional + Power for the OT cost, default 2. + q: int, optional + Outer root applied to the OT objective, default 1. + grassman_metric: str, optional + Metric type: "geodesic", "chordal", "procrustes", or "martin". + real_scale: float, optional + Scale factor for real parts, default 1.0. + imag_scale: float, optional + Scale factor for imaginary parts, default 1.0. + Ws: array-like, shape (n_1,), optional + Source distribution. If None, uses a uniform distribution. + Wt: array-like, shape (n_2,), optional + Target distribution. If None, uses a uniform distribution. + + Returns + ---------- + dist: float + SGOT metric value. + """ + C = cost( + D1, + R1, + L1, + D2, + R2, + L2, + eta=eta, + p=p, + grassman_metric=grassman_metric, + real_scale=real_scale, + imag_scale=imag_scale, + ) + + nx = get_backend(C) + n, m = C.shape + + if Ws is None: + Ws = nx.ones((n,), dtype=C.dtype) / float(n) + else: + Ws = nx.asarray(Ws) + + if Wt is None: + Wt = nx.ones((m,), dtype=C.dtype) / float(m) + else: + Wt = nx.asarray(Wt) + + Ws = Ws / nx.sum(Ws) + Wt = Wt / nx.sum(Wt) + + C_np = ot.backend.to_numpy(C) + Ws_np = ot.backend.to_numpy(Ws) + Wt_np = ot.backend.to_numpy(Wt) + + P = ot_plan(C_np, Ws=Ws_np, Wt=Wt_np) + obj = ot_score(C_np, P, p=p) + + return float(obj) ** (1.0 / float(q)) + + +def metric_from_operator( + T1, + T2, + r=None, + eta=0.5, + p=2, + q=1, + grassman_metric="chordal", + real_scale=1.0, + imag_scale=1.0, + Ws=None, + Wt=None, +): + """Compute the SGOT metric directly from two operators. + + Parameters + ---------- + T1: array-like, shape (d, d) + First operator. + T2: array-like, shape (d, d) + Second operator. + r: int, optional + Number of modes to keep. If None, keep all modes. 
+ eta: float, optional + Weighting between spectral and Grassmann terms, default 0.5. + p: int, optional + Power for the OT cost, default 2. + q: int, optional + Outer root applied to the OT objective, default 1. + grassman_metric: str, optional + Metric type: "geodesic", "chordal", "procrustes", or "martin". + real_scale: float, optional + Scale factor for real parts, default 1.0. + imag_scale: float, optional + Scale factor for imaginary parts, default 1.0. + Ws: array-like, shape (n_1,), optional + Source distribution. If None, uses a uniform distribution. + Wt: array-like, shape (n_2,), optional + Target distribution. If None, uses a uniform distribution. + + Returns + ---------- + dist: float + SGOT metric value. + """ + D1, R1, L1 = _atoms_from_operator(T1, r=r, sort_mode="closest_to_1") + D2, R2, L2 = _atoms_from_operator(T2, r=r, sort_mode="closest_to_1") + + return metric( + D1, + R1, + L1, + D2, + R2, + L2, + eta=eta, + p=p, + q=q, + grassman_metric=grassman_metric, + real_scale=real_scale, + imag_scale=imag_scale, + Ws=Ws, + Wt=Wt, + ) + + +def operator_estimator( + X, + Y=None, + r=None, + ref=1e-8, + force_complex=False, +): + """Estimate a linear operator from data. + + Parameters + ---------- + X: array-like, shape (n_samples, d) or (d, n_samples) + Input snapshot matrix. + Y: array-like, shape like X, optional + Output snapshot matrix. If None, uses a one-step shift of X. + r: int, optional + Rank for optional truncated SVD of the estimated operator. + ref: float, optional + Tikhonov regularization strength, default 1e-8. + force_complex: bool, optional + If True, cast inputs to complex dtype. + + Returns + ---------- + T_hat: np.ndarray, shape (d, d) + Estimated linear operator. 
+ """ + X = np.asarray(X) + + if Y is None: + if X.ndim != 2 or X.shape[0] < 2: + raise ValueError("If Y is None, X must be 2D with at least 2 samples/rows.") + X0 = X[:-1] + Y0 = X[1:] + else: + Y = np.asarray(Y) + if X.shape != Y.shape: + raise ValueError( + f"X and Y must have the same shape; got {X.shape} vs {Y.shape}" + ) + X0, Y0 = X, Y + + if X0.shape[0] >= 1 and X0.shape[0] != X0.shape[1]: + if X0.shape[0] >= X0.shape[1]: + Xc = X0.T + Yc = Y0.T + else: + Xc = X0 + Yc = Y0 + else: + Xc = X0 + Yc = Y0 + + if Xc.ndim != 2 or Yc.ndim != 2: + raise ValueError("X and Y must be 2D arrays after processing.") + + d, n = Xc.shape + if Yc.shape != (d, n): + raise ValueError( + f"After formatting, expected Y to have shape {(d, n)}, got {Yc.shape}" + ) + + if force_complex: + Xc = Xc.astype(complex) + Yc = Yc.astype(complex) + + XXH = Xc @ Xc.conj().T + YXH = Yc @ Xc.conj().T + A = XXH + ref * np.eye(d, dtype=XXH.dtype) + + T_hat = np.linalg.solve(A.T.conj(), YXH.T.conj()).T.conj() + + if r is not None: + if not (1 <= r <= d): + raise ValueError(f"r must be in [1, {d}], got r={r}") + U, S, Vh = np.linalg.svd(T_hat, full_matrices=False) + T_hat = (U[:, :r] * S[:r]) @ Vh[:r, :] + + return T_hat From 3f1011104c107d51ccdbbf900319514a02aea51d Mon Sep 17 00:00:00 2001 From: Sienna O'Shea Date: Fri, 20 Feb 2026 14:29:54 +0100 Subject: [PATCH 04/10] rewrite backend and refactor OT metric --- ot/sgot.py | 568 ++++++++++++++++++++++++----------------------------- 1 file changed, 255 insertions(+), 313 deletions(-) diff --git a/ot/sgot.py b/ot/sgot.py index 929078f7e..86d3a6036 100644 --- a/ot/sgot.py +++ b/ot/sgot.py @@ -17,167 +17,12 @@ import numpy as np import ot -from sklearn.utils.extmath import randomized_svd from ot.backend import get_backend ### # Settings : (Ds,Rs,Ls) if primal, (Ds,Xs,prs,pls) if dual ### -##################################################################################################################################### 
-##################################################################################################################################### -### PRINCIPAL ANGLE METRICS ### -##################################################################################################################################### -##################################################################################################################################### - - -def hs_metric(Ds, Rs, Ls, Dt, Rt, Lt, sampfreqs: int = 1, sampfreqt: int = 1): - """Compute the Hilbert-Schmidt (Frobenius) distance between two operators. - - Parameters - ---------- - Ds: array-like, shape (n_s,) - Source eigenvalues. - Rs: array-like, shape (L, n_s) - Source right eigenvectors. - Ls: array-like, shape (L, n_s) - Source left eigenvectors. - Dt: array-like, shape (n_t,) - Target eigenvalues. - Rt: array-like, shape (L, n_t) - Target right eigenvectors. - Lt: array-like, shape (L, n_t) - Target left eigenvectors. - sampfreqs: int, optional, sampling frequency for the source operator with default 1 - sampfreqt: int, optional, sampling frequency for the target operator with default 1 - - Returns - ---------- - dist: float, Frobenius norm - """ - Ts = Rs @ (np.exp(Ds / sampfreqs).reshape(-1, 1) * Ls.conj().T) - Tt = Rt @ (np.exp(Dt / sampfreqt).reshape(-1, 1) * Lt.conj().T) - C = Ts - Tt - return np.linalg.norm(C, "fro") - - -def operator_metric( - Ds, - Rs, - Ls, - Dt, - Rt, - Lt, - sampfreqs: int = 1, - sampfreqt: int = 1, - exact: bool = False, - n_iter: int = 5, - random_state: int = None, -): - """Compute the spectral norm distance between two reconstructed operators. - - Parameters - ---------- - Ds: array-like, shape (n_s,) - Source eigenvalues. - Rs: array-like, shape (L, n_s) - Source right eigenvectors. - Ls: array-like, shape (L, n_s) - Source left eigenvectors. - Dt: array-like, shape (n_t,) - Target eigenvalues. - Rt: array-like, shape (L, n_t) - Target right eigenvectors. 
- Lt: array-like, shape (L, n_t) - Target left eigenvectors. - sampfreqs: int, optional - Sampling frequency for the source operator with default 1 - sampfreqt: int, optional - Sampling frequency for the target operator with default 1 - exact: bool, optional - n_iter: int, optional - random_state: int or None, optional - - Returns - ---------- - dist: float - """ - Ts = Rs @ (np.exp(Ds / sampfreqs).reshape(-1, 1) * Ls.conj().T) - Tt = Rt @ (np.exp(Dt / sampfreqt).reshape(-1, 1) * Lt.conj().T) - C = Ts - Tt - if exact: - return np.linalg.norm(C, 2) - else: - _, S, _ = randomized_svd( - C.real, n_components=1, n_iter=n_iter, random_state=random_state - ) - return S[0] - - -def principal_angles_via_svd(A, B): - """Compute principal angles between two subspaces using SVD of QA^T QB. - - Parameters - A: array-like, shape (d, p) whose columns span the first subspace - B: array-like, shape (d, q) whose columns span the second subspace - - Returns - angle: sorted principal angles (in radians), shape (min(p, q),) - """ - QA, _ = np.linalg.qr(A, mode="reduced") - QB, _ = np.linalg.qr(B, mode="reduced") - C = QA.T @ QB - # SVD of small matrix C - _, S, _ = np.linalg.svd(C, full_matrices=False) - S = np.clip(S, -1.0, 1.0) - angles = np.arccos(S) - return np.sort(angles) - - -def principal_angles_distance( - Ds, - Rs, - Ls, - Dt, - Rt, - Lt, -): - """Compute a principal angles distance between two spectral decompositions. - - Parameters - ---------- - Ds: array-like, shape (n_s,) - Source eigenvalues. - Rs: array-like, shape (L, n_s) - Source right eigenvectors. - Ls: array-like, shape (L, n_s) - Source left eigenvectors. - Dt: array-like, shape (n_t,) - Target eigenvalues. - Rt: array-like, shape (L, n_t) - Target right eigenvectors. - Lt: array-like, shape (L, n_t) - Target left eigenvectors. - - Returns - ------- - dist: float - Principal-angles distance between the two decompositions. 
- """ - ns = Rs.shape[1] - nt = Rt.shape[1] - Ms = np.vstack( - [(ls[:, None] * rs.conj()[None, :]).flatten() for ls, rs in zip(Ls.T, Rs.T)] - ).T - Mt = np.vstack( - [(lt[:, None] * rt.conj()[None, :]).flatten() for lt, rt in zip(Lt.T, Rt.T)] - ).T - angles = principal_angles_via_svd(Ms, Mt) - if angles.shape[0] != max(ns, nt): - angles = np.hstack([angles, np.pi / 2 * np.ones(max(ns, nt) - angles.shape[0])]) - return np.sqrt(np.sum(angles**2)) - - ##################################################################################################################################### ##################################################################################################################################### ### OT METRIC ### @@ -185,7 +30,7 @@ def principal_angles_distance( ##################################################################################################################################### -def principal_grassman_matrix(Ps, Pt, eps: float = 1e-12): +def principal_grassman_matrix(Ps, Pt, eps: float = 1e-12, nx=None): """Compute the unitary Grassmann matrix for source and target domains. Parameters @@ -202,19 +47,25 @@ def principal_grassman_matrix(Ps, Pt, eps: float = 1e-12): C : np.ndarray, shape (n_ds, n_dt) Grassmann matrix between source and target subspaces. 
""" - ns = np.linalg.norm(Ps, axis=0, keepdims=True) - nt = np.linalg.norm(Pt, axis=0, keepdims=True) - ns = np.maximum(ns, eps) - nt = np.maximum(nt, eps) + if nx is None: + nx = get_backend(Ps, Pt) + + Ps = nx.asarray(Ps) + Pt = nx.asarray(Pt) + + ns = nx.sqrt(nx.sum(Ps * nx.conj(Ps), axis=0, keepdims=True)) + nt = nx.sqrt(nx.sum(Pt * nx.conj(Pt), axis=0, keepdims=True)) + + ns = nx.clip(ns, eps, 1e300) + nt = nx.clip(nt, eps, 1e300) Psn = Ps / ns Ptn = Pt / nt - C = Psn.conj().T @ Ptn - return C + return nx.dot(nx.conj(Psn).T, Ptn) -def eigenvector_chordal_cost_matrix(Rs, Ls, Rt, Lt): +def eigenvector_chordal_cost_matrix(Rs, Ls, Rt, Lt, nx=None): """Compute pairwise Grassmann matrices for source and target domains. Parameters @@ -233,13 +84,20 @@ def eigenvector_chordal_cost_matrix(Rs, Ls, Rt, Lt): C: np.ndarray, shape (n_s, n_t) Eigenvector chordal cost matrix. """ - Cr = principal_grassman_matrix(Rs, Rt) - Cl = principal_grassman_matrix(Ls, Lt) - C = np.sqrt(1 - np.clip((Cr * Cl).real, a_min=0, a_max=1)) - return C + if nx is None: + nx = get_backend(Rs, Ls, Rt, Lt) + + Cr = principal_grassman_matrix(Rs, Rt, nx=nx) + Cl = principal_grassman_matrix(Ls, Lt, nx=nx) + prod = nx.real(Cr * Cl) + prod = nx.clip(prod, 0.0, 1.0) + return nx.sqrt(1.0 - prod) -def eigenvalue_cost_matrix(Ds, Dt, real_scale: float = 1.0, imag_scale: float = 1.0): + +def eigenvalue_cost_matrix( + Ds, Dt, real_scale: float = 1.0, imag_scale: float = 1.0, nx=None +): """Compute pairwise eigenvalue distances for source and target domains. Parameters @@ -258,19 +116,35 @@ def eigenvalue_cost_matrix(Ds, Dt, real_scale: float = 1.0, imag_scale: float = C: np.ndarray, shape (n_s, n_t) Eigenvalue cost matrix. 
""" - Dsn = Ds.real * real_scale + 1j * Ds.imag * imag_scale - Dtn = Dt.real * real_scale + 1j * Dt.imag * imag_scale - C = np.abs(Dsn[:, None] - Dtn[None, :]) - return C + if nx is None: + nx = get_backend(Ds, Dt) + + Ds = nx.asarray(Ds) + Dt = nx.asarray(Dt) + Dsn = nx.real(Ds) * real_scale + 1j * nx.imag(Ds) * imag_scale + Dtn = nx.real(Dt) * real_scale + 1j * nx.imag(Dt) * imag_scale + return nx.abs(Dsn[:, None] - Dtn[None, :]) -def ChordalCostFunction( - real_scale: float = 1.0, imag_scale: float = 1.0, alpha: float = 0.5, p: int = 2 +def chordal_cost_matrix( + Ds, Rs, Ls, Dt, Rt, Lt, real_scale=1.0, imag_scale=1.0, alpha=0.5, p=2, nx=None ): - """Generate the chordal cost function. + """Compute the chordal cost matrix between source and target spectral decompositions. Parameters ---------- + Ds: array-like, shape (n_s,) + Source eigenvalues. + Rs: array-like, shape (L, n_s) + Source right eigenvectors. + Ls: array-like, shape (L, n_s) + Source left eigenvectors. + Dt: array-like, shape (n_t,) + Target eigenvalues. + Rt: array-like, shape (L, n_t) + Target right eigenvectors. + Lt: array-like, shape (L, n_t) + Target left eigenvectors. real_scale: float, optional Scale factor for real parts, default 1.0. imag_scale: float, optional @@ -282,43 +156,20 @@ def ChordalCostFunction( Returns ---------- - cost_function: Chordal cost function. + C: np.ndarray, shape (n_s, n_t) + Chordal cost matrix. """ - - def cost_function(Ds, Rs, Ls, Dt, Rt, Lt) -> np.ndarray: - """Compute the chordal cost matrix between source and target spectral decompositions. - - Parameters - ---------- - Ds: array-like, shape (n_s,) - Source eigenvalues. - Rs: array-like, shape (L, n_s) - Source right eigenvectors. - Ls: array-like, shape (L, n_s) - Source left eigenvectors. - Dt: array-like, shape (n_t,) - Target eigenvalues. - Rt: array-like, shape (L, n_t) - Target right eigenvectors. - Lt: array-like, shape (L, n_t) - Target left eigenvectors. 
- - Returns - ---------- - C: np.ndarray, shape (n_s, n_t) - Chordal cost matrix. - """ - CD = eigenvalue_cost_matrix( - Ds, Dt, real_scale=real_scale, imag_scale=imag_scale - ) - CC = eigenvector_chordal_cost_matrix(Rs, Ls, Rt, Lt) - C = alpha * CD + (1 - alpha) * CC - return C**p - - return cost_function + if nx is None: + nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) + CD = eigenvalue_cost_matrix( + Ds, Dt, real_scale=real_scale, imag_scale=imag_scale, nx=nx + ) + CC = eigenvector_chordal_cost_matrix(Rs, Ls, Rt, Lt, nx=nx) + C = alpha * CD + (1.0 - alpha) * CC + return C**p -def ot_plan(C, Ws=None, Wt=None): +def ot_plan(C, Ws=None, Wt=None, nx=None): """Compute the optimal transport plan for a given cost matrix and marginals. Parameters @@ -335,14 +186,35 @@ def ot_plan(C, Ws=None, Wt=None): P: np.ndarray, shape (n, m) Optimal transport plan. """ + if nx is None: + nx = get_backend(C) + + C = nx.asarray(C) + n, m = C.shape + if Ws is None: - Ws = np.ones(C.shape[0]) / C.shape[0] + Ws = nx.ones((n,), dtype=C.dtype) / float(n) + else: + Ws = nx.asarray(Ws) + if Wt is None: - Wt = np.ones(C.shape[1]) / C.shape[1] - return ot.emd(Ws, Wt, C) + Wt = nx.ones((m,), dtype=C.dtype) / float(m) + else: + Wt = nx.asarray(Wt) + + Ws = Ws / nx.sum(Ws) + Wt = Wt / nx.sum(Wt) + + C_real = nx.real(C) + + C_np = ot.backend.to_numpy(C_real) + Ws_np = ot.backend.to_numpy(Ws) + Wt_np = ot.backend.to_numpy(Wt) + + return ot.emd(Ws_np, Wt_np, C_np) -def ot_score(C, P, p: int = 2) -> float: +def ot_score(C, P, p: int = 2, nx=None): """Compute the OT score (distance) given a cost matrix and a transport plan. Parameters @@ -359,7 +231,11 @@ def ot_score(C, P, p: int = 2) -> float: dist: float OT score (distance). 
""" - return np.sum(C * P) ** (1 / p) + if nx is None: + nx = get_backend(C) + C = nx.asarray(C) + P = nx.asarray(P) + return float(nx.sum(C * P) ** (1.0 / p)) def chordal_metric( @@ -373,6 +249,7 @@ def chordal_metric( imag_scale: float = 1.0, alpha: float = 0.5, p: int = 2, + nx=None, ): """Compute the chordal OT metric between two spectral decompositions. @@ -404,10 +281,24 @@ def chordal_metric( dist: float Chordal OT metric value. """ - cost_fn = ChordalCostFunction(real_scale, imag_scale, alpha, p) - C = cost_fn(Ds, Rs, Ls, Dt, Rt, Lt) - P = ot_plan(C) - return ot_score(C, P, p) + if nx is None: + nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) + + C = chordal_cost_matrix( + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, + real_scale=real_scale, + imag_scale=imag_scale, + alpha=alpha, + p=p, + nx=nx, + ) + P = ot_plan(C, nx=nx) + return ot_score(C, P, p=p, nx=nx) ##################################################################################################################################### @@ -435,7 +326,7 @@ def _normalize_columns(A, nx, eps=1e-12): Column-normalized array. """ nrm = nx.sqrt(nx.sum(A * nx.conj(A), axis=0, keepdims=True)) - nrm = nx.maximum(nrm, eps) + nrm = nx.clip(nrm, eps, 1e300) return A / nrm @@ -479,7 +370,7 @@ def _delta_matrix_1d_hs(Rs, Ls, Rt, Lt, nx=None, eps=1e-12): Cl = nx.dot(nx.conj(Lsn).T, Ltn) delta = nx.abs(Cr * Cl) - delta = nx.minimum(nx.maximum(delta, 0.0), 1.0) + delta = nx.clip(delta, 0.0, 1.0) return delta @@ -504,52 +395,72 @@ def _atoms_from_operator(T, r=None, sort_mode="closest_to_1"): L: np.ndarray, shape (d, r) Dual left eigenvectors. 
""" - T = np.asarray(T) + nx = get_backend(T) + T = nx.asarray(T) + if T.ndim != 2 or T.shape[0] != T.shape[1]: raise ValueError(f"T must be a square 2D array; got shape {T.shape}") - d = T.shape[0] + d = int(T.shape[0]) if r is None: r = d r = int(r) if not (1 <= r <= d): raise ValueError(f"r must be an integer in [1, {d}], got r={r}") - evals, evecs = np.linalg.eig(T) + T_np = ot.backend.to_numpy(T) + evals_np, evecs_np = np.linalg.eig(T_np) if sort_mode == "closest_to_1": - order = np.argsort(np.abs(evals - 1.0)) + order = np.argsort(np.abs(evals_np - 1.0)) elif sort_mode == "closest_to_0": - order = np.argsort(np.abs(evals)) + order = np.argsort(np.abs(evals_np)) elif sort_mode == "largest_mag": - order = np.argsort(-np.abs(evals)) + order = np.argsort(-np.abs(evals_np)) else: raise ValueError( "sort_mode must be one of 'closest_to_1', 'closest_to_0', or 'largest_mag'" ) idx = order[:r] - D = evals[idx] - R = evecs[:, idx] + D_np = evals_np[idx] + R_np = evecs_np[:, idx] - evalsL, evecsL = np.linalg.eig(T.conj().T) + evalsL_np, evecsL_np = np.linalg.eig(T_np.conj().T) - L = np.zeros((d, r), dtype=complex) + L_np = np.zeros((d, r), dtype=np.complex128) used = set() - for i, lam in enumerate(D): - targets = np.abs(evalsL - np.conj(lam)) + for i, lam in enumerate(D_np): + targets = np.abs(evalsL_np - np.conj(lam)) for j in np.argsort(targets): if j not in used: used.add(j) - L[:, i] = evecsL[:, j] + L_np[:, i] = evecsL_np[:, j] break - G = L.conj().T @ R - if np.linalg.matrix_rank(G) < r: + if hasattr(nx, "from_numpy"): + D = nx.from_numpy(D_np, type_as=T) + R = nx.from_numpy(R_np, type_as=T) + L = nx.from_numpy(L_np, type_as=T) + else: + D = nx.asarray(D_np) + R = nx.asarray(R_np) + L = nx.asarray(L_np) + + G = nx.dot(nx.conj(L).T, R) + + G_np = ot.backend.to_numpy(G) + if np.linalg.matrix_rank(G_np) < r: raise ValueError("Dual normalization failed: L^* R is singular.") - L = L @ np.linalg.inv(G).conj().T + invG_H_np = np.linalg.inv(G_np).conj().T + if 
hasattr(nx, "from_numpy"): + invG_H = nx.from_numpy(invG_H_np, type_as=T) + else: + invG_H = nx.asarray(invG_H_np) + + L = nx.dot(L, invG_H) return D, R, L @@ -584,7 +495,7 @@ def _grassmann_distance_squared(delta, grassman_metric="chordal", nx=None, eps=1 nx = get_backend(delta) delta = nx.asarray(delta) - delta = nx.minimum(nx.maximum(delta, 0.0), 1.0) + delta = nx.clip(delta, 0.0, 1.0) if grassman_metric == "geodesic": return nx.arccos(delta) ** 2 @@ -593,7 +504,7 @@ def _grassmann_distance_squared(delta, grassman_metric="chordal", nx=None, eps=1 if grassman_metric == "procrustes": return 2.0 * (1.0 - delta) if grassman_metric == "martin": - return -nx.log(nx.maximum(delta**2, eps)) + return -nx.log(nx.clip(delta**2, eps, 1e300)) raise ValueError(f"Unknown grassman_metric: {grassman_metric}") @@ -603,33 +514,34 @@ def _grassmann_distance_squared(delta, grassman_metric="chordal", nx=None, eps=1 ##################################################################################################################################### ##################################################################################################################################### def cost( - D1, - R1, - L1, - D2, - R2, - L2, + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, eta=0.5, p=2, grassman_metric="chordal", real_scale=1.0, imag_scale=1.0, + nx=None, ): """Compute the SGOT cost matrix between two spectral decompositions. Parameters ---------- - D1: array-like, shape (n_1,) or (n_1, n_1) + Ds: array-like, shape (n_s,) or (n_s, n_s) Eigenvalues of operator T1 (or diagonal matrix). - R1: array-like, shape (L, n_1) + Rs: array-like, shape (L, n_s) Right eigenvectors of operator T1. - L1: array-like, shape (L, n_1) + Ls: array-like, shape (L, n_s) Left eigenvectors of operator T1. - D2: array-like, shape (n_2,) or (n_2, n_2) + Dt: array-like, shape (n_t,) or (n_t, n_t) Eigenvalues of operator T2 (or diagonal matrix). 
- R2: array-like, shape (L, n_2) + Rt: array-like, shape (L, n_t) Right eigenvectors of operator T2. - L2: array-like, shape (L, n_2) + Lt: array-like, shape (L, n_t) Left eigenvectors of operator T2. eta: float, optional Weighting between spectral and Grassmann terms, default 0.5. @@ -644,21 +556,22 @@ def cost( Returns ---------- - C: array-like, shape (n_1, n_2) + C: array-like, shape (n_s, n_t) SGOT cost matrix. """ - nx = get_backend(D1, R1, L1, D2, R2, L2) + if nx is None: + nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) - D1 = nx.asarray(D1) - D2 = nx.asarray(D2) - if len(D1.shape) == 2: - lam1 = nx.diag(D1) + Ds = nx.asarray(Ds) + Dt = nx.asarray(Dt) + if len(Ds.shape) == 2: + lam1 = nx.diag(Ds) else: - lam1 = D1.reshape((-1,)) - if len(D2.shape) == 2: - lam2 = nx.diag(D2) + lam1 = Ds.reshape((-1,)) + if len(Dt.shape) == 2: + lam2 = nx.diag(Dt) else: - lam2 = D2.reshape((-1,)) + lam2 = Dt.reshape((-1,)) lam1 = lam1.astype(complex) lam2 = lam2.astype(complex) @@ -667,7 +580,7 @@ def cost( lam2s = nx.real(lam2) * real_scale + 1j * nx.imag(lam2) * imag_scale C_lambda = nx.abs(lam1s[:, None] - lam2s[None, :]) ** 2 - delta = _delta_matrix_1d_hs(R1, L1, R2, L2, nx=nx) + delta = _delta_matrix_1d_hs(Rs, Ls, Rt, Lt, nx=nx) C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) C2 = eta * C_lambda + (1.0 - eta) * C_grass @@ -677,12 +590,12 @@ def cost( def metric( - D1, - R1, - L1, - D2, - R2, - L2, + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, eta=0.5, p=2, q=1, @@ -691,22 +604,23 @@ def metric( imag_scale=1.0, Ws=None, Wt=None, + nx=None, ): """Compute the SGOT metric between two spectral decompositions. Parameters ---------- - D1: array-like, shape (n_1,) or (n_1, n_1) + Ds: array-like, shape (n_s,) or (n_s, n_s) Eigenvalues of operator T1 (or diagonal matrix). - R1: array-like, shape (L, n_1) + Rs: array-like, shape (L, n_s) Right eigenvectors of operator T1. 
- L1: array-like, shape (L, n_1) + Ls: array-like, shape (L, n_s) Left eigenvectors of operator T1. - D2: array-like, shape (n_2,) or (n_2, n_2) + Dt: array-like, shape (n_t,) or (n_t, n_t) Eigenvalues of operator T2 (or diagonal matrix). - R2: array-like, shape (L, n_2) + Rt: array-like, shape (L, n_t) Right eigenvectors of operator T2. - L2: array-like, shape (L, n_2) + Lt: array-like, shape (L, n_t) Left eigenvectors of operator T2. eta: float, optional Weighting between spectral and Grassmann terms, default 0.5. @@ -720,9 +634,9 @@ def metric( Scale factor for real parts, default 1.0. imag_scale: float, optional Scale factor for imaginary parts, default 1.0. - Ws: array-like, shape (n_1,), optional + Ws: array-like, shape (n_s,), optional Source distribution. If None, uses a uniform distribution. - Wt: array-like, shape (n_2,), optional + Wt: array-like, shape (n_t,), optional Target distribution. If None, uses a uniform distribution. Returns @@ -730,21 +644,24 @@ def metric( dist: float SGOT metric value. """ + if nx is None: + nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) + C = cost( - D1, - R1, - L1, - D2, - R2, - L2, + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, eta=eta, p=p, grassman_metric=grassman_metric, real_scale=real_scale, imag_scale=imag_scale, + nx=nx, ) - nx = get_backend(C) n, m = C.shape if Ws is None: @@ -760,13 +677,8 @@ def metric( Ws = Ws / nx.sum(Ws) Wt = Wt / nx.sum(Wt) - C_np = ot.backend.to_numpy(C) - Ws_np = ot.backend.to_numpy(Ws) - Wt_np = ot.backend.to_numpy(Wt) - - P = ot_plan(C_np, Ws=Ws_np, Wt=Wt_np) - obj = ot_score(C_np, P, p=p) - + P = ot_plan(C, Ws=Ws, Wt=Wt, nx=nx) + obj = ot_score(C, P, p=p, nx=nx) return float(obj) ** (1.0 / float(q)) @@ -805,9 +717,9 @@ def metric_from_operator( Scale factor for real parts, default 1.0. imag_scale: float, optional Scale factor for imaginary parts, default 1.0. - Ws: array-like, shape (n_1,), optional + Ws: array-like, shape (n_s,), optional Source distribution. If None, uses a uniform distribution. 
- Wt: array-like, shape (n_2,), optional + Wt: array-like, shape (n_t,), optional Target distribution. If None, uses a uniform distribution. Returns @@ -815,16 +727,16 @@ def metric_from_operator( dist: float SGOT metric value. """ - D1, R1, L1 = _atoms_from_operator(T1, r=r, sort_mode="closest_to_1") - D2, R2, L2 = _atoms_from_operator(T2, r=r, sort_mode="closest_to_1") + Ds, Rs, Ls = _atoms_from_operator(T1, r=r, sort_mode="closest_to_1") + Dt, Rt, Lt = _atoms_from_operator(T2, r=r, sort_mode="closest_to_1") return metric( - D1, - R1, - L1, - D2, - R2, - L2, + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, eta=eta, p=p, q=q, @@ -863,16 +775,18 @@ def operator_estimator( T_hat: np.ndarray, shape (d, d) Estimated linear operator. """ - X = np.asarray(X) + nx = get_backend(X, Y) if Y is not None else get_backend(X) + + X = nx.asarray(X) if Y is None: - if X.ndim != 2 or X.shape[0] < 2: + if X.ndim != 2 or int(X.shape[0]) < 2: raise ValueError("If Y is None, X must be 2D with at least 2 samples/rows.") X0 = X[:-1] Y0 = X[1:] else: - Y = np.asarray(Y) - if X.shape != Y.shape: + Y = nx.asarray(Y) + if tuple(X.shape) != tuple(Y.shape): raise ValueError( f"X and Y must have the same shape; got {X.shape} vs {Y.shape}" ) @@ -892,26 +806,54 @@ def operator_estimator( if Xc.ndim != 2 or Yc.ndim != 2: raise ValueError("X and Y must be 2D arrays after processing.") - d, n = Xc.shape - if Yc.shape != (d, n): + d, n = int(Xc.shape[0]), int(Xc.shape[1]) + if tuple(Yc.shape) != (d, n): raise ValueError( f"After formatting, expected Y to have shape {(d, n)}, got {Yc.shape}" ) if force_complex: - Xc = Xc.astype(complex) - Yc = Yc.astype(complex) + Xc_np = ot.backend.to_numpy(Xc) # explicit backend->NumPy copy + Yc_np = ot.backend.to_numpy(Yc) + Xc_np = Xc_np.astype(np.complex128, copy=False) + Yc_np = Yc_np.astype(np.complex128, copy=False) + if hasattr(nx, "from_numpy"): + Xc = nx.from_numpy(Xc_np, type_as=Xc) + Yc = nx.from_numpy(Yc_np, type_as=Yc) + else: + Xc = nx.asarray(Xc_np) + Yc = 
nx.asarray(Yc_np) + + XXH = nx.dot(Xc, nx.conj(Xc).T) + YXH = nx.dot(Yc, nx.conj(Xc).T) + A = XXH + ref * nx.eye(d, type_as=XXH) - XXH = Xc @ Xc.conj().T - YXH = Yc @ Xc.conj().T - A = XXH + ref * np.eye(d, dtype=XXH.dtype) + AH = nx.conj(A).T + BH = nx.conj(YXH).T - T_hat = np.linalg.solve(A.T.conj(), YXH.T.conj()).T.conj() + AH_np = ot.backend.to_numpy(AH) # explicit backend->NumPy copy + BH_np = ot.backend.to_numpy(BH) + Xsol_np = np.linalg.solve(AH_np, BH_np) + + if hasattr(nx, "from_numpy"): + Xsol = nx.from_numpy(Xsol_np, type_as=YXH) + else: + Xsol = nx.asarray(Xsol_np) + + T_hat = nx.conj(Xsol).T if r is not None: + r = int(r) if not (1 <= r <= d): raise ValueError(f"r must be in [1, {d}], got r={r}") - U, S, Vh = np.linalg.svd(T_hat, full_matrices=False) - T_hat = (U[:, :r] * S[:r]) @ Vh[:r, :] + + T_np = ot.backend.to_numpy(T_hat) # explicit backend->NumPy copy + U, S, Vh = np.linalg.svd(T_np, full_matrices=False) + T_np = (U[:, :r] * S[:r]) @ Vh[:r, :] + + if hasattr(nx, "from_numpy"): + T_hat = nx.from_numpy(T_np, type_as=T_hat) + else: + T_hat = nx.asarray(T_np) return T_hat From d5fef5b6ab5a366b5bb029e13fa8c2fc868ffdd4 Mon Sep 17 00:00:00 2001 From: Sienna O'Shea Date: Sun, 22 Feb 2026 22:37:14 +0100 Subject: [PATCH 05/10] refactor sgot, and implement tests for sgot and backend --- ot/backend.py | 134 +++++++++ ot/sgot.py | 626 +++++-------------------------------------- test/test_backend.py | 24 ++ test/test_sgot.py | 333 +++++++++++++++++++++++ 4 files changed, 564 insertions(+), 553 deletions(-) create mode 100644 test/test_sgot.py diff --git a/ot/backend.py b/ot/backend.py index 3f158d166..b2dfc4024 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -603,6 +603,55 @@ def clip(self, a, a_min, a_max): """ raise NotImplementedError() + def real(self, a): + """ + Return the real part of the tensor element-wise. 
+ + This function follows the api from :any:`numpy.real` + + See: https://numpy.org/doc/stable/reference/generated/numpy.real.html + """ + raise NotImplementedError() + + def imag(self, a): + """ + Return the imaginary part of the tensor element-wise. + + This function follows the api from :any:`numpy.imag` + + See: https://numpy.org/doc/stable/reference/generated/numpy.imag.html + """ + raise NotImplementedError() + + def conj(self, a): + """ + Return the complex conjugate, element-wise. + + This function follows the api from :any:`numpy.conj` + + See: https://numpy.org/doc/stable/reference/generated/numpy.conj.html + """ + raise NotImplementedError() + + def arccos(self, a): + """ + Trigonometric inverse cosine, element-wise. + + This function follows the api from :any:`numpy.arccos` + + See: https://numpy.org/doc/stable/reference/generated/numpy.arccos.html + """ + raise NotImplementedError() + + def astype(self, a, dtype): + """ + Cast tensor to a given dtype. + + dtype can be a string (e.g. "complex128", "float64") or backend-specific + dtype. Backend converts to the corresponding type. + """ + raise NotImplementedError() + def repeat(self, a, repeats, axis=None): r""" Repeats elements of a tensor. 
@@ -1294,6 +1343,23 @@ def outer(self, a, b): def clip(self, a, a_min, a_max): return np.clip(a, a_min, a_max) + def real(self, a): + return np.real(a) + + def imag(self, a): + return np.imag(a) + + def conj(self, a): + return np.conj(a) + + def arccos(self, a): + return np.arccos(a) + + def astype(self, a, dtype): + if isinstance(dtype, str): + dtype = getattr(np, dtype, None) or np.dtype(dtype) + return np.asarray(a, dtype=dtype) + def repeat(self, a, repeats, axis=None): return np.repeat(a, repeats, axis) @@ -1711,6 +1777,23 @@ def outer(self, a, b): def clip(self, a, a_min, a_max): return jnp.clip(a, a_min, a_max) + def real(self, a): + return jnp.real(a) + + def imag(self, a): + return jnp.imag(a) + + def conj(self, a): + return jnp.conj(a) + + def arccos(self, a): + return jnp.arccos(a) + + def astype(self, a, dtype): + if isinstance(dtype, str): + dtype = getattr(jnp, dtype, None) or jnp.dtype(dtype) + return jnp.asarray(a, dtype=dtype) + def repeat(self, a, repeats, axis=None): return jnp.repeat(a, repeats, axis) @@ -2208,6 +2291,23 @@ def outer(self, a, b): def clip(self, a, a_min, a_max): return torch.clamp(a, a_min, a_max) + def real(self, a): + return torch.real(a) + + def imag(self, a): + return torch.imag(a) + + def conj(self, a): + return torch.conj(a) + + def arccos(self, a): + return torch.acos(a) + + def astype(self, a, dtype): + if isinstance(dtype, str): + dtype = getattr(torch, dtype, None) + return a.to(dtype=dtype) + def repeat(self, a, repeats, axis=None): return torch.repeat_interleave(a, repeats, dim=axis) @@ -2709,6 +2809,23 @@ def outer(self, a, b): def clip(self, a, a_min, a_max): return cp.clip(a, a_min, a_max) + def real(self, a): + return cp.real(a) + + def imag(self, a): + return cp.imag(a) + + def conj(self, a): + return cp.conj(a) + + def arccos(self, a): + return cp.arccos(a) + + def astype(self, a, dtype): + if isinstance(dtype, str): + dtype = getattr(cp, dtype, None) or cp.dtype(dtype) + return cp.asarray(a, dtype=dtype) + def 
repeat(self, a, repeats, axis=None): return cp.repeat(a, repeats, axis) @@ -3143,6 +3260,23 @@ def outer(self, a, b): def clip(self, a, a_min, a_max): return tnp.clip(a, a_min, a_max) + def real(self, a): + return tnp.real(a) + + def imag(self, a): + return tnp.imag(a) + + def conj(self, a): + return tnp.conj(a) + + def arccos(self, a): + return tnp.arccos(a) + + def astype(self, a, dtype): + if isinstance(dtype, str): + dtype = getattr(tnp, dtype, None) or tnp.dtype(dtype) + return tnp.array(a, dtype=dtype) + def repeat(self, a, repeats, axis=None): return tnp.repeat(a, repeats, axis) diff --git a/ot/sgot.py b/ot/sgot.py index 86d3a6036..08620b0d2 100644 --- a/ot/sgot.py +++ b/ot/sgot.py @@ -15,88 +15,18 @@ # Thibaut Germain # License: MIT License -import numpy as np import ot from ot.backend import get_backend -### -# Settings : (Ds,Rs,Ls) if primal, (Ds,Xs,prs,pls) if dual -### - ##################################################################################################################################### ##################################################################################################################################### -### OT METRIC ### +### NORMALISATION AND OPERATOR ATOMS ### ##################################################################################################################################### ##################################################################################################################################### -def principal_grassman_matrix(Ps, Pt, eps: float = 1e-12, nx=None): - """Compute the unitary Grassmann matrix for source and target domains. - - Parameters - ---------- - Ps : array-like, shape (l, n_ds) - Source domain data, with columns spanning the source subspace. - Pt : array-like, shape (l, n_dt) - Target domain data, with columns spanning the target subspace. - eps : float, optional - Minimum column norm used to avoid division by zero. Default is 1e-12. 
- - Returns - ------- - C : np.ndarray, shape (n_ds, n_dt) - Grassmann matrix between source and target subspaces. - """ - if nx is None: - nx = get_backend(Ps, Pt) - - Ps = nx.asarray(Ps) - Pt = nx.asarray(Pt) - - ns = nx.sqrt(nx.sum(Ps * nx.conj(Ps), axis=0, keepdims=True)) - nt = nx.sqrt(nx.sum(Pt * nx.conj(Pt), axis=0, keepdims=True)) - - ns = nx.clip(ns, eps, 1e300) - nt = nx.clip(nt, eps, 1e300) - - Psn = Ps / ns - Ptn = Pt / nt - - return nx.dot(nx.conj(Psn).T, Ptn) - - -def eigenvector_chordal_cost_matrix(Rs, Ls, Rt, Lt, nx=None): - """Compute pairwise Grassmann matrices for source and target domains. - - Parameters - ---------- - Rs: array-like, shape (L, n_s) - Source right eigenvectors. - Ls: array-like, shape (L, n_s) - Source left eigenvectors. - Rt: array-like, shape (L, n_t) - Target right eigenvectors. - Lt: array-like, shape (L, n_t) - Target left eigenvectors. - - Returns - ---------- - C: np.ndarray, shape (n_s, n_t) - Eigenvector chordal cost matrix. - """ - if nx is None: - nx = get_backend(Rs, Ls, Rt, Lt) - - Cr = principal_grassman_matrix(Rs, Rt, nx=nx) - Cl = principal_grassman_matrix(Ls, Lt, nx=nx) - - prod = nx.real(Cr * Cl) - prod = nx.clip(prod, 0.0, 1.0) - return nx.sqrt(1.0 - prod) - - def eigenvalue_cost_matrix( - Ds, Dt, real_scale: float = 1.0, imag_scale: float = 1.0, nx=None + Ds, Dt, q=1, real_scale: float = 1.0, imag_scale: float = 1.0, nx=None ): """Compute pairwise eigenvalue distances for source and target domains. @@ -119,54 +49,11 @@ def eigenvalue_cost_matrix( if nx is None: nx = get_backend(Ds, Dt) - Ds = nx.asarray(Ds) - Dt = nx.asarray(Dt) Dsn = nx.real(Ds) * real_scale + 1j * nx.imag(Ds) * imag_scale Dtn = nx.real(Dt) * real_scale + 1j * nx.imag(Dt) * imag_scale - return nx.abs(Dsn[:, None] - Dtn[None, :]) - - -def chordal_cost_matrix( - Ds, Rs, Ls, Dt, Rt, Lt, real_scale=1.0, imag_scale=1.0, alpha=0.5, p=2, nx=None -): - """Compute the chordal cost matrix between source and target spectral decompositions. 
- - Parameters - ---------- - Ds: array-like, shape (n_s,) - Source eigenvalues. - Rs: array-like, shape (L, n_s) - Source right eigenvectors. - Ls: array-like, shape (L, n_s) - Source left eigenvectors. - Dt: array-like, shape (n_t,) - Target eigenvalues. - Rt: array-like, shape (L, n_t) - Target right eigenvectors. - Lt: array-like, shape (L, n_t) - Target left eigenvectors. - real_scale: float, optional - Scale factor for real parts, default 1.0. - imag_scale: float, optional - Scale factor for imaginary parts, default 1.0. - alpha: float, optional - Weighting factor for the eigenvalue cost, default 0.5. - p: int, optional - Power for the chordal distance, default 2. - - Returns - ---------- - C: np.ndarray, shape (n_s, n_t) - Chordal cost matrix. - """ - if nx is None: - nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) - CD = eigenvalue_cost_matrix( - Ds, Dt, real_scale=real_scale, imag_scale=imag_scale, nx=nx - ) - CC = eigenvector_chordal_cost_matrix(Rs, Ls, Rt, Lt, nx=nx) - C = alpha * CD + (1.0 - alpha) * CC - return C**p + prod = Dsn[:, None] - Dtn[None, :] + prod = nx.real(prod * nx.conj(prod)) + return prod ** (q / 2) def ot_plan(C, Ws=None, Wt=None, nx=None): @@ -189,123 +76,20 @@ def ot_plan(C, Ws=None, Wt=None, nx=None): if nx is None: nx = get_backend(C) - C = nx.asarray(C) n, m = C.shape if Ws is None: - Ws = nx.ones((n,), dtype=C.dtype) / float(n) - else: - Ws = nx.asarray(Ws) + Ws = nx.ones((n,), type_as=C) / float(n) if Wt is None: - Wt = nx.ones((m,), dtype=C.dtype) / float(m) - else: - Wt = nx.asarray(Wt) + Wt = nx.ones((m,), type_as=C) / float(m) Ws = Ws / nx.sum(Ws) Wt = Wt / nx.sum(Wt) C_real = nx.real(C) - C_np = ot.backend.to_numpy(C_real) - Ws_np = ot.backend.to_numpy(Ws) - Wt_np = ot.backend.to_numpy(Wt) - - return ot.emd(Ws_np, Wt_np, C_np) - - -def ot_score(C, P, p: int = 2, nx=None): - """Compute the OT score (distance) given a cost matrix and a transport plan. - - Parameters - ---------- - C: array-like, shape (n, m) - Cost matrix. 
- P: array-like, shape (n, m) - Transport plan. - p: int, optional - Power for the OT score, default 2. - - Returns - ---------- - dist: float - OT score (distance). - """ - if nx is None: - nx = get_backend(C) - C = nx.asarray(C) - P = nx.asarray(P) - return float(nx.sum(C * P) ** (1.0 / p)) - - -def chordal_metric( - Ds, - Rs, - Ls, - Dt, - Rt, - Lt, - real_scale: float = 1.0, - imag_scale: float = 1.0, - alpha: float = 0.5, - p: int = 2, - nx=None, -): - """Compute the chordal OT metric between two spectral decompositions. - - Parameters - ---------- - Ds: array-like, shape (n_s,) - Source eigenvalues. - Rs: array-like, shape (L, n_s) - Source right eigenvectors. - Ls: array-like, shape (L, n_s) - Source left eigenvectors. - Dt: array-like, shape (n_t,) - Target eigenvalues. - Rt: array-like, shape (L, n_t) - Target right eigenvectors. - Lt: array-like, shape (L, n_t) - Target left eigenvectors. - real_scale: float, optional - Scale factor for real parts, default 1.0. - imag_scale: float, optional - Scale factor for imaginary parts, default 1.0. - alpha: float, optional - Weighting factor for the eigenvalue cost, default 0.5. - p: int, optional - Power for the chordal distance, default 2. - - Returns - ---------- - dist: float - Chordal OT metric value. 
- """ - if nx is None: - nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) - - C = chordal_cost_matrix( - Ds, - Rs, - Ls, - Dt, - Rt, - Lt, - real_scale=real_scale, - imag_scale=imag_scale, - alpha=alpha, - p=p, - nx=nx, - ) - P = ot_plan(C, nx=nx) - return ot_score(C, P, p=p, nx=nx) - - -##################################################################################################################################### -##################################################################################################################################### -### NORMALISATION AND OPERATOR ATOMS ### -##################################################################################################################################### -##################################################################################################################################### + return ot.emd(Ws, Wt, C_real) def _normalize_columns(A, nx, eps=1e-12): @@ -326,11 +110,12 @@ def _normalize_columns(A, nx, eps=1e-12): Column-normalized array. """ nrm = nx.sqrt(nx.sum(A * nx.conj(A), axis=0, keepdims=True)) - nrm = nx.clip(nrm, eps, 1e300) + nrm = nx.real(nrm) # norm is real; avoid complex dtype for maximum (e.g. torch) + nrm = nx.maximum(nrm, eps) return A / nrm -def _delta_matrix_1d_hs(Rs, Ls, Rt, Lt, nx=None, eps=1e-12): +def _delta_matrix_1d(Rs, Ls, Rt, Lt, nx=None, eps=1e-12): """Compute the normalized inner-product delta matrix for eigenspaces. 
Parameters @@ -356,11 +141,6 @@ def _delta_matrix_1d_hs(Rs, Ls, Rt, Lt, nx=None, eps=1e-12): if nx is None: nx = get_backend(Rs, Ls, Rt, Lt) - Rs = nx.asarray(Rs) - Ls = nx.asarray(Ls) - Rt = nx.asarray(Rt) - Lt = nx.asarray(Lt) - Rsn = _normalize_columns(Rs, nx=nx, eps=eps) Lsn = _normalize_columns(Ls, nx=nx, eps=eps) Rtn = _normalize_columns(Rt, nx=nx, eps=eps) @@ -374,104 +154,6 @@ def _delta_matrix_1d_hs(Rs, Ls, Rt, Lt, nx=None, eps=1e-12): return delta -def _atoms_from_operator(T, r=None, sort_mode="closest_to_1"): - """Extract dua; eigen-atoms from a square operator. - - Parameters - ---------- - T: array-like, shape (d, d) - Input linear operator. - r: int, optional - Number of modes to keep. If None, keep all modes. - sort_mode: str, optional - Eigenvalue sorting mode: "closest_to_1", "closest_to_0", or "largest_mag". - - Returns - ---------- - D: np.ndarray, shape (r,) - Selected eigenvalues. - R: np.ndarray, shape (d, r) - Corresponding right eigenvectors. - L: np.ndarray, shape (d, r) - Dual left eigenvectors. 
- """ - nx = get_backend(T) - T = nx.asarray(T) - - if T.ndim != 2 or T.shape[0] != T.shape[1]: - raise ValueError(f"T must be a square 2D array; got shape {T.shape}") - - d = int(T.shape[0]) - if r is None: - r = d - r = int(r) - if not (1 <= r <= d): - raise ValueError(f"r must be an integer in [1, {d}], got r={r}") - - T_np = ot.backend.to_numpy(T) - evals_np, evecs_np = np.linalg.eig(T_np) - - if sort_mode == "closest_to_1": - order = np.argsort(np.abs(evals_np - 1.0)) - elif sort_mode == "closest_to_0": - order = np.argsort(np.abs(evals_np)) - elif sort_mode == "largest_mag": - order = np.argsort(-np.abs(evals_np)) - else: - raise ValueError( - "sort_mode must be one of 'closest_to_1', 'closest_to_0', or 'largest_mag'" - ) - - idx = order[:r] - D_np = evals_np[idx] - R_np = evecs_np[:, idx] - - evalsL_np, evecsL_np = np.linalg.eig(T_np.conj().T) - - L_np = np.zeros((d, r), dtype=np.complex128) - used = set() - - for i, lam in enumerate(D_np): - targets = np.abs(evalsL_np - np.conj(lam)) - for j in np.argsort(targets): - if j not in used: - used.add(j) - L_np[:, i] = evecsL_np[:, j] - break - - if hasattr(nx, "from_numpy"): - D = nx.from_numpy(D_np, type_as=T) - R = nx.from_numpy(R_np, type_as=T) - L = nx.from_numpy(L_np, type_as=T) - else: - D = nx.asarray(D_np) - R = nx.asarray(R_np) - L = nx.asarray(L_np) - - G = nx.dot(nx.conj(L).T, R) - - G_np = ot.backend.to_numpy(G) - if np.linalg.matrix_rank(G_np) < r: - raise ValueError("Dual normalization failed: L^* R is singular.") - - invG_H_np = np.linalg.inv(G_np).conj().T - if hasattr(nx, "from_numpy"): - invG_H = nx.from_numpy(invG_H_np, type_as=T) - else: - invG_H = nx.asarray(invG_H_np) - - L = nx.dot(L, invG_H) - - return D, R, L - - -##################################################################################################################################### -##################################################################################################################################### -### 
GRASSMANNIAN METRIC ### -##################################################################################################################################### -##################################################################################################################################### - - def _grassmann_distance_squared(delta, grassman_metric="chordal", nx=None, eps=1e-300): """Compute squared Grassmannian distances from delta similarities. @@ -494,7 +176,6 @@ def _grassmann_distance_squared(delta, grassman_metric="chordal", nx=None, eps=1 if nx is None: nx = get_backend(delta) - delta = nx.asarray(delta) delta = nx.clip(delta, 0.0, 1.0) if grassman_metric == "geodesic": @@ -522,6 +203,7 @@ def cost( Lt, eta=0.5, p=2, + q=1, grassman_metric="chordal", real_scale=1.0, imag_scale=1.0, @@ -531,14 +213,14 @@ def cost( Parameters ---------- - Ds: array-like, shape (n_s,) or (n_s, n_s) - Eigenvalues of operator T1 (or diagonal matrix). + Ds: array-like, shape (n_s,) + Eigenvalues of operator T1. Rs: array-like, shape (L, n_s) Right eigenvectors of operator T1. Ls: array-like, shape (L, n_s) Left eigenvectors of operator T1. - Dt: array-like, shape (n_t,) or (n_t, n_t) - Eigenvalues of operator T2 (or diagonal matrix). + Dt: array-like, shape (n_t,) + Eigenvalues of operator T2. Rt: array-like, shape (L, n_t) Right eigenvectors of operator T2. 
Lt: array-like, shape (L, n_t) @@ -562,25 +244,44 @@ def cost( if nx is None: nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) - Ds = nx.asarray(Ds) - Dt = nx.asarray(Dt) - if len(Ds.shape) == 2: - lam1 = nx.diag(Ds) - else: - lam1 = Ds.reshape((-1,)) - if len(Dt.shape) == 2: - lam2 = nx.diag(Dt) - else: - lam2 = Dt.reshape((-1,)) - - lam1 = lam1.astype(complex) - lam2 = lam2.astype(complex) - - lam1s = nx.real(lam1) * real_scale + 1j * nx.imag(lam1) * imag_scale - lam2s = nx.real(lam2) * real_scale + 1j * nx.imag(lam2) * imag_scale - C_lambda = nx.abs(lam1s[:, None] - lam2s[None, :]) ** 2 - - delta = _delta_matrix_1d_hs(Rs, Ls, Rt, Lt, nx=nx) + if Ds.ndim != 1: + raise ValueError(f"cost() expects Ds to be 1D (n,), got shape {Ds.shape}") + lam1 = Ds + + if Dt.ndim != 1: + raise ValueError(f"cost() expects Dt to be 1D (n,), got shape {Dt.shape}") + lam2 = Dt + + lam1 = nx.astype(lam1, "complex128") + lam2 = nx.astype(lam2, "complex128") + + if Rs.shape != Ls.shape: + raise ValueError( + f"Rs and Ls must have the same shape, got {Rs.shape} and {Ls.shape}" + ) + + if Rt.shape != Lt.shape: + raise ValueError( + f"Rt and Lt must have the same shape, got {Rt.shape} and {Lt.shape}" + ) + + if Rs.shape[1] != lam1.shape[0]: + raise ValueError( + f"Number of source eigenvectors ({Rs.shape[1]}) must match " + f"number of source eigenvalues ({lam1.shape[0]})" + ) + + if Rt.shape[1] != lam2.shape[0]: + raise ValueError( + f"Number of target eigenvectors ({Rt.shape[1]}) must match " + f"number of target eigenvalues ({lam2.shape[0]})" + ) + + C_lambda = eigenvalue_cost_matrix( + lam1, lam2, q=q, real_scale=real_scale, imag_scale=imag_scale, nx=nx + ) + + delta = _delta_matrix_1d(Rs, Ls, Rt, Lt, nx=nx) C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) C2 = eta * C_lambda + (1.0 - eta) * C_grass @@ -599,6 +300,7 @@ def metric( eta=0.5, p=2, q=1, + r=2, grassman_metric="chordal", real_scale=1.0, imag_scale=1.0, @@ -610,14 +312,14 @@ def metric( Parameters 
---------- - Ds: array-like, shape (n_s,) or (n_s, n_s) - Eigenvalues of operator T1 (or diagonal matrix). + Ds: array-like, shape (n_s,) + Eigenvalues of operator T1. Rs: array-like, shape (L, n_s) Right eigenvectors of operator T1. Ls: array-like, shape (L, n_s) Left eigenvectors of operator T1. - Dt: array-like, shape (n_t,) or (n_t, n_t) - Eigenvalues of operator T2 (or diagonal matrix). + Dt: array-like, shape (n_t,) + Eigenvalues of operator T2. Rt: array-like, shape (L, n_t) Right eigenvectors of operator T2. Lt: array-like, shape (L, n_t) @@ -625,9 +327,13 @@ def metric( eta: float, optional Weighting between spectral and Grassmann terms, default 0.5. p: int, optional - Power for the OT cost, default 2. + Exponent defining the OT ground cost and Wasserstein order. The cost matrix is raised to the power p/2 and the OT objective + is raised to the power 1/p. Default is 2. q: int, optional - Outer root applied to the OT objective, default 1. + Exponent applied to the eigenvalue distance in the spectral term. Controls the geometry of the eigenvalue cost matrix. + Default is 1. + r: int, optional + Outer root applied to the Wasserstein objective. Default is 2. grassman_metric: str, optional Metric type: "geodesic", "chordal", "procrustes", or "martin". 
real_scale: float, optional @@ -647,6 +353,11 @@ def metric( if nx is None: nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) + if Ds.ndim != 1: + raise ValueError(f"metric() expects Ds to be 1D (n,), got shape {Ds.shape}") + if Dt.ndim != 1: + raise ValueError(f"metric() expects Dt to be 1D (n,), got shape {Dt.shape}") + C = cost( Ds, Rs, @@ -656,204 +367,13 @@ def metric( Lt, eta=eta, p=p, + q=q, grassman_metric=grassman_metric, real_scale=real_scale, imag_scale=imag_scale, nx=nx, ) - n, m = C.shape - - if Ws is None: - Ws = nx.ones((n,), dtype=C.dtype) / float(n) - else: - Ws = nx.asarray(Ws) - - if Wt is None: - Wt = nx.ones((m,), dtype=C.dtype) / float(m) - else: - Wt = nx.asarray(Wt) - - Ws = Ws / nx.sum(Ws) - Wt = Wt / nx.sum(Wt) - P = ot_plan(C, Ws=Ws, Wt=Wt, nx=nx) - obj = ot_score(C, P, p=p, nx=nx) - return float(obj) ** (1.0 / float(q)) - - -def metric_from_operator( - T1, - T2, - r=None, - eta=0.5, - p=2, - q=1, - grassman_metric="chordal", - real_scale=1.0, - imag_scale=1.0, - Ws=None, - Wt=None, -): - """Compute the SGOT metric directly from two operators. - - Parameters - ---------- - T1: array-like, shape (d, d) - First operator. - T2: array-like, shape (d, d) - Second operator. - r: int, optional - Number of modes to keep. If None, keep all modes. - eta: float, optional - Weighting between spectral and Grassmann terms, default 0.5. - p: int, optional - Power for the OT cost, default 2. - q: int, optional - Outer root applied to the OT objective, default 1. - grassman_metric: str, optional - Metric type: "geodesic", "chordal", "procrustes", or "martin". - real_scale: float, optional - Scale factor for real parts, default 1.0. - imag_scale: float, optional - Scale factor for imaginary parts, default 1.0. - Ws: array-like, shape (n_s,), optional - Source distribution. If None, uses a uniform distribution. - Wt: array-like, shape (n_t,), optional - Target distribution. If None, uses a uniform distribution. 
- - Returns - ---------- - dist: float - SGOT metric value. - """ - Ds, Rs, Ls = _atoms_from_operator(T1, r=r, sort_mode="closest_to_1") - Dt, Rt, Lt = _atoms_from_operator(T2, r=r, sort_mode="closest_to_1") - - return metric( - Ds, - Rs, - Ls, - Dt, - Rt, - Lt, - eta=eta, - p=p, - q=q, - grassman_metric=grassman_metric, - real_scale=real_scale, - imag_scale=imag_scale, - Ws=Ws, - Wt=Wt, - ) - - -def operator_estimator( - X, - Y=None, - r=None, - ref=1e-8, - force_complex=False, -): - """Estimate a linear operator from data. - - Parameters - ---------- - X: array-like, shape (n_samples, d) or (d, n_samples) - Input snapshot matrix. - Y: array-like, shape like X, optional - Output snapshot matrix. If None, uses a one-step shift of X. - r: int, optional - Rank for optional truncated SVD of the estimated operator. - ref: float, optional - Tikhonov regularization strength, default 1e-8. - force_complex: bool, optional - If True, cast inputs to complex dtype. - - Returns - ---------- - T_hat: np.ndarray, shape (d, d) - Estimated linear operator. 
- """ - nx = get_backend(X, Y) if Y is not None else get_backend(X) - - X = nx.asarray(X) - - if Y is None: - if X.ndim != 2 or int(X.shape[0]) < 2: - raise ValueError("If Y is None, X must be 2D with at least 2 samples/rows.") - X0 = X[:-1] - Y0 = X[1:] - else: - Y = nx.asarray(Y) - if tuple(X.shape) != tuple(Y.shape): - raise ValueError( - f"X and Y must have the same shape; got {X.shape} vs {Y.shape}" - ) - X0, Y0 = X, Y - - if X0.shape[0] >= 1 and X0.shape[0] != X0.shape[1]: - if X0.shape[0] >= X0.shape[1]: - Xc = X0.T - Yc = Y0.T - else: - Xc = X0 - Yc = Y0 - else: - Xc = X0 - Yc = Y0 - - if Xc.ndim != 2 or Yc.ndim != 2: - raise ValueError("X and Y must be 2D arrays after processing.") - - d, n = int(Xc.shape[0]), int(Xc.shape[1]) - if tuple(Yc.shape) != (d, n): - raise ValueError( - f"After formatting, expected Y to have shape {(d, n)}, got {Yc.shape}" - ) - - if force_complex: - Xc_np = ot.backend.to_numpy(Xc) # explicit backend->NumPy copy - Yc_np = ot.backend.to_numpy(Yc) - Xc_np = Xc_np.astype(np.complex128, copy=False) - Yc_np = Yc_np.astype(np.complex128, copy=False) - if hasattr(nx, "from_numpy"): - Xc = nx.from_numpy(Xc_np, type_as=Xc) - Yc = nx.from_numpy(Yc_np, type_as=Yc) - else: - Xc = nx.asarray(Xc_np) - Yc = nx.asarray(Yc_np) - - XXH = nx.dot(Xc, nx.conj(Xc).T) - YXH = nx.dot(Yc, nx.conj(Xc).T) - A = XXH + ref * nx.eye(d, type_as=XXH) - - AH = nx.conj(A).T - BH = nx.conj(YXH).T - - AH_np = ot.backend.to_numpy(AH) # explicit backend->NumPy copy - BH_np = ot.backend.to_numpy(BH) - Xsol_np = np.linalg.solve(AH_np, BH_np) - - if hasattr(nx, "from_numpy"): - Xsol = nx.from_numpy(Xsol_np, type_as=YXH) - else: - Xsol = nx.asarray(Xsol_np) - - T_hat = nx.conj(Xsol).T - - if r is not None: - r = int(r) - if not (1 <= r <= d): - raise ValueError(f"r must be in [1, {d}], got r={r}") - - T_np = ot.backend.to_numpy(T_hat) # explicit backend->NumPy copy - U, S, Vh = np.linalg.svd(T_np, full_matrices=False) - T_np = (U[:, :r] * S[:r]) @ Vh[:r, :] - - if 
hasattr(nx, "from_numpy"): - T_hat = nx.from_numpy(T_np, type_as=T_hat) - else: - T_hat = nx.asarray(T_np) - - return T_hat + obj = float(nx.sum(C * P) ** (1.0 / p)) + return float(obj) ** (1.0 / float(r)) diff --git a/test/test_backend.py b/test/test_backend.py index efd696ef0..e322e5d35 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -338,6 +338,9 @@ def test_func_backends(nx): sp_col = np.array([0, 3, 1, 2, 2]) sp_data = np.array([4, 5, 7, 9, 0], dtype=np.float64) + M_complex = M + 1j * rnd.randn(10, 3) + v_acos = np.clip(v, -0.99, 0.99) + lst_tot = [] for nx in [ot.backend.NumpyBackend(), nx]: @@ -722,6 +725,27 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("atan2") + M_complex_b = nx.from_numpy(M_complex) + A = nx.real(M_complex_b) + lst_b.append(nx.to_numpy(A)) + lst_name.append("real") + A = nx.imag(M_complex_b) + lst_b.append(nx.to_numpy(A)) + lst_name.append("imag") + A = nx.conj(M_complex_b) + lst_b.append(nx.to_numpy(A)) + lst_name.append("conj") + v_acos_b = nx.from_numpy(v_acos) + A = nx.arccos(v_acos_b) + lst_b.append(nx.to_numpy(A)) + lst_name.append("arccos") + A = nx.astype(Mb, "float64") + lst_b.append(nx.to_numpy(A)) + lst_name.append("astype float64") + A = nx.astype(vb, "complex128") + lst_b.append(nx.to_numpy(A)) + lst_name.append("astype complex128") + A = nx.transpose(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append("transpose") diff --git a/test/test_sgot.py b/test/test_sgot.py new file mode 100644 index 000000000..c1330696b --- /dev/null +++ b/test/test_sgot.py @@ -0,0 +1,333 @@ +"""Tests for ot.sgot module""" + +# Author: Sienna O'Shea +# Thibaut Germain +# License: MIT License + +import numpy as np +import pytest + +from ot.backend import get_backend + +try: + import torch +except ImportError: + torch = None + +try: + import jax + import jax.numpy as jnp +except ImportError: + jax = None + +from ot.sgot import ( + eigenvalue_cost_matrix, + _delta_matrix_1d, + _grassmann_distance_squared, + 
cost, + metric, +) + +rng = np.random.RandomState(0) + + +def rand_complex(shape): + real = rng.randn(*shape) + imag = rng.randn(*shape) + return real + 1j * imag + + +def random_atoms(d=8, r=4): + Ds = rand_complex((r,)) + Rs = rand_complex((d, r)) + Ls = rand_complex((d, r)) + Dt = rand_complex((r,)) + Rt = rand_complex((d, r)) + Lt = rand_complex((d, r)) + return Ds, Rs, Ls, Dt, Rt, Lt + + +# --------------------------------------------------------------------- +# DATA / SAMPLING TESTS +# --------------------------------------------------------------------- + + +def test_atoms_are_complex(): + """Confirm sampled atoms are complex (Gaussian real + 1j*imag).""" + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + for name, arr in [ + ("Ds", Ds), + ("Rs", Rs), + ("Ls", Ls), + ("Dt", Dt), + ("Rt", Rt), + ("Lt", Lt), + ]: + assert np.iscomplexobj(arr), f"{name} should be complex" + assert np.any(np.imag(arr) != 0), f"{name} should have non-zero imaginary part" + + +def test_random_d_r(): + """Sample d and r uniformly and run cost (and metric when available) with those shapes.""" + d_min, d_max = 4, 12 + r_min, r_max = 2, 6 + for _ in range(5): + d = int(rng.randint(d_min, d_max + 1)) + r = int(rng.randint(r_min, r_max + 1)) + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms(d=d, r=r) + C = cost(Ds, Rs, Ls, Dt, Rt, Lt) + np.testing.assert_allclose(C.shape, (r, r)) + assert np.all(np.isfinite(C)) and np.all(C >= 0) + try: + dist = metric(Ds, Rs, Ls, Dt, Rt, Lt) + assert np.isfinite(dist) and dist >= 0 + except TypeError: + pytest.skip("metric() unavailable (emd_c signature mismatch)") + + +# --------------------------------------------------------------------- +# BACKEND CONSISTENCY TESTS +# --------------------------------------------------------------------- + + +def test_backend_return(): + """Confirm get_backend returns the correct backend for numpy/torch/jax arrays.""" + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) + assert nx is not None + 
assert nx.__name__ == "numpy" + + if torch is not None: + Ds_t = torch.from_numpy(Ds) + nx_t = get_backend(Ds_t) + assert nx_t is not None + assert nx_t.__name__ == "torch" + + if jax is not None: + Ds_j = jnp.array(Ds) + nx_j = get_backend(Ds_j) + assert nx_j is not None + assert nx_j.__name__ == "jax" + + +@pytest.mark.parametrize("backend_name", ["numpy", "torch", "jax"]) +def test_cost_backend_consistency(backend_name): + if backend_name == "torch" and torch is None: + pytest.skip("Torch not available") + if backend_name == "jax" and jax is None: + pytest.skip("JAX not available") + + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + + C_np = cost(Ds, Rs, Ls, Dt, Rt, Lt) + + if backend_name == "numpy": + C_backend = C_np + + elif backend_name == "torch": + Ds_b = torch.from_numpy(Ds) + Rs_b = torch.from_numpy(Rs) + Ls_b = torch.from_numpy(Ls) + Dt_b = torch.from_numpy(Dt) + Rt_b = torch.from_numpy(Rt) + Lt_b = torch.from_numpy(Lt) + C_backend = cost(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b) + C_backend = C_backend.detach().cpu().numpy() + + elif backend_name == "jax": + Ds_b = jnp.array(Ds) + Rs_b = jnp.array(Rs) + Ls_b = jnp.array(Ls) + Dt_b = jnp.array(Dt) + Rt_b = jnp.array(Rt) + Lt_b = jnp.array(Lt) + C_backend = cost(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b) + C_backend = np.array(C_backend) + + np.testing.assert_allclose(C_backend, C_np, atol=1e-6) + + +# --------------------------------------------------------------------- +# DELTA MATRIX TESTS +# --------------------------------------------------------------------- + + +def test_delta_identity(): + r = 4 + I = np.eye(r, dtype=complex) + delta = _delta_matrix_1d(I, I, I, I) + np.testing.assert_allclose(delta, np.eye(r), atol=1e-12) + + +def test_delta_swap_invariance(): + d, r = 6, 3 + R = rand_complex((d, r)) + L = R.copy() + delta1 = _delta_matrix_1d(R, L, R, L) + delta2 = _delta_matrix_1d(L, R, L, R) + np.testing.assert_allclose(delta1, delta2, atol=1e-12) + + +# 
--------------------------------------------------------------------- +# GRASSMANN DISTANCE TESTS +# --------------------------------------------------------------------- + + +@pytest.mark.parametrize("metric_name", ["geodesic", "chordal", "procrustes", "martin"]) +def test_grassmann_zero_distance(metric_name): + delta = np.ones((3, 3)) + dist2 = _grassmann_distance_squared(delta, grassman_metric=metric_name) + np.testing.assert_allclose(dist2, 0.0, atol=1e-12) + + +def test_grassmann_invalid_name(): + delta = np.ones((2, 2)) + with pytest.raises(ValueError): + _grassmann_distance_squared(delta, grassman_metric="cordal") + + +# --------------------------------------------------------------------- +# COST TESTS +# --------------------------------------------------------------------- + + +def test_cost_self_zero(nx): + """(D_S R_S L_S D_S): diagonal of cost matrix (same atom to same atom) should be near zero.""" + Ds, Rs, Ls, _, _, _ = random_atoms() + Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2 = nx.from_numpy(Ds, Rs, Ls, Ds, Rs, Ls) + C = cost(Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2) + C_np = nx.to_numpy(C) + np.testing.assert_allclose(np.diag(C_np), np.zeros(C_np.shape[0]), atol=1e-10) + np.testing.assert_allclose(C_np, C_np.T, atol=1e-10) + + +def test_cost_reference(nx): + """Cost with same inputs and HPs should be deterministic (np.testing.assert_allclose).""" + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) + eta, p, q = 0.5, 2, 1 + C1 = cost(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b, eta=eta, p=p, q=q) + C2 = cost(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b, eta=eta, p=p, q=q) + np.testing.assert_allclose(nx.to_numpy(C1), nx.to_numpy(C2), atol=1e-12) + + +@pytest.mark.parametrize( + "grassman_metric", ["geodesic", "chordal", "procrustes", "martin"] +) +def test_cost_basic(grassman_metric, nx): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) + C 
= cost(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b, grassman_metric=grassman_metric) + C_np = nx.to_numpy(C) + assert C_np.shape == (Ds.shape[0], Dt.shape[0]) + assert np.all(np.isfinite(C_np)) + assert np.all(C_np >= 0) + + +def test_cost_validation(): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + + with pytest.raises(ValueError): + cost(Ds.reshape(-1, 1), Rs, Ls, Dt, Rt, Lt) + + with pytest.raises(ValueError): + cost(Ds, Rs[:, :-1], Ls, Dt, Rt, Lt) + + +# --------------------------------------------------------------------- +# METRIC TESTS +# --------------------------------------------------------------------- + + +def test_metric_self_zero(): + Ds, Rs, Ls, _, _, _ = random_atoms() + dist = metric(Ds, Rs, Ls, Ds, Rs, Ls) + assert np.isfinite(dist) + assert abs(dist) < 2e-4 + + +def test_metric_symmetry(): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + d1 = metric(Ds, Rs, Ls, Dt, Rt, Lt) + d2 = metric(Dt, Rt, Lt, Ds, Rs, Ls) + np.testing.assert_allclose(d1, d2, atol=1e-8) + + +def test_metric_with_weights(): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + r = Ds.shape[0] + + logits_s = rng.randn(r) + logits_t = rng.randn(r) + + Ws = np.exp(logits_s) + Ws = Ws / np.sum(Ws) + + Wt = np.exp(logits_t) + Wt = Wt / np.sum(Wt) + + dist = metric(Ds, Rs, Ls, Dt, Rt, Lt, Ws=Ws, Wt=Wt) + assert np.isfinite(dist) + + +# --------------------------------------------------------------------- +# HYPERPARAMETER SWEEP TEST +# --------------------------------------------------------------------- + + +def test_hyperparameter_sweep_cost(nx): + """Create test_cost for each trial: sweep over HPs and run cost().""" + grassmann_types = ["geodesic", "chordal", "procrustes", "martin"] + n_trials = 10 + for _ in range(n_trials): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) + eta = rng.uniform(0.0, 1.0) + p = rng.choice([1, 2]) + q = rng.choice([1, 2]) + gm = rng.choice(grassmann_types) + C = cost( + Ds_b, + Rs_b, + Ls_b, + Dt_b, + 
Rt_b, + Lt_b, + eta=eta, + p=p, + q=q, + grassman_metric=gm, + ) + C_np = nx.to_numpy(C) + assert C_np.shape == (Ds.shape[0], Dt.shape[0]) + assert np.all(np.isfinite(C_np)) + assert np.all(C_np >= 0) + + +def test_hyperparameter_sweep(): + grassmann_types = ["geodesic", "chordal", "procrustes", "martin"] + + for _ in range(10): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + eta = rng.uniform(0.0, 1.0) + p = rng.choice([1, 2]) + q = rng.choice([1, 2]) + r = rng.choice([1, 2]) + gm = rng.choice(grassmann_types) + + dist = metric( + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, + eta=eta, + p=p, + q=q, + r=r, + grassman_metric=gm, + ) + + assert np.isfinite(dist) + assert dist >= 0 From 630e359c59b8aedebeeac95bd13e327e18245ce3 Mon Sep 17 00:00:00 2001 From: Sienna O'Shea Date: Mon, 23 Feb 2026 12:41:58 +0100 Subject: [PATCH 06/10] fix astype in backend --- ot/backend.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/ot/backend.py b/ot/backend.py index b2dfc4024..1def519a9 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -2305,7 +2305,25 @@ def arccos(self, a): def astype(self, a, dtype): if isinstance(dtype, str): - dtype = getattr(torch, dtype, None) + # Map common numpy-style string dtypes to torch dtypes explicitly. + # This makes backend.astype robust across torch versions and aliases. + mapping = { + "float32": torch.float32, + "float64": torch.float64, + "float": torch.float32, + "double": torch.float64, + "complex64": getattr(torch, "complex64", None), + "complex128": getattr(torch, "complex128", None), + } + torch_dtype = mapping.get(dtype) + if torch_dtype is None: + # Fallback: try direct attribute lookup (e.g. 
torch.float16) + torch_dtype = getattr(torch, dtype, None) + if torch_dtype is None: + raise ValueError( + f"Unsupported dtype for TorchBackend.astype: {dtype!r}" + ) + dtype = torch_dtype return a.to(dtype=dtype) def repeat(self, a, repeats, axis=None): From 57c7a0ca6d0f6841d9b9eeb2d0234712ca7579ab Mon Sep 17 00:00:00 2001 From: Sienna O'Shea Date: Wed, 25 Feb 2026 17:09:32 +0100 Subject: [PATCH 07/10] edits as per PR #792 --- RELEASES.md | 12 +- ot/backend.py | 62 +---------- ot/sgot.py | 250 +++++++++++++++++++++++++----------------- test/test_backend.py | 9 +- test/test_sgot.py | 255 +++++++++++++++---------------------------- 5 files changed, 247 insertions(+), 341 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index e2e16f145..c734202c0 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,13 +1,6 @@ # Releases -## Upcomming 0.9.7.post1 - -#### New features -The next release will add cost functions between linear operators following [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920). - - - ## 0.9.7.dev0 This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation. 
@@ -20,8 +13,10 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
 - Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782)
 - Geomloss function now handles both scalar and slice indices for i and j (PR #785)
 - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
 - Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765)
 - Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765)
+- Add cost functions between linear operators following
+  [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920) (PR #792)

 #### Closed issues

diff --git a/ot/backend.py b/ot/backend.py
index 59fbb3082..0568f2e2f 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -662,15 +662,6 @@ def arccos(self, a):
         """
         raise NotImplementedError()
 
-    def astype(self, a, dtype):
-        """
-        Cast tensor to a given dtype.
-
-        dtype can be a string (e.g. "complex128", "float64") or backend-specific
-        dtype. Backend converts to the corresponding type.
-        """
-        raise NotImplementedError()
-
     def repeat(self, a, repeats, axis=None):
         r"""
         Repeats elements of a tensor. 
@@ -1242,7 +1233,7 @@ def _from_numpy(self, a, type_as=None): elif isinstance(a, float): return a else: - return a.astype(type_as.dtype) + return np.asarray(a, dtype=type_as.dtype) def set_gradients(self, val, inputs, grads): # No gradients for numpy @@ -1374,11 +1365,6 @@ def conj(self, a): def arccos(self, a): return np.arccos(a) - def astype(self, a, dtype): - if isinstance(dtype, str): - dtype = getattr(np, dtype, None) or np.dtype(dtype) - return np.asarray(a, dtype=dtype) - def repeat(self, a, repeats, axis=None): return np.repeat(a, repeats, axis) @@ -1670,7 +1656,7 @@ def _from_numpy(self, a, type_as=None): if type_as is None: return jnp.array(a) else: - return self._change_device(jnp.array(a).astype(type_as.dtype), type_as) + return self._change_device(jnp.asarray(a, dtype=type_as.dtype), type_as) def set_gradients(self, val, inputs, grads): from jax.flatten_util import ravel_pytree @@ -1808,11 +1794,6 @@ def conj(self, a): def arccos(self, a): return jnp.arccos(a) - def astype(self, a, dtype): - if isinstance(dtype, str): - dtype = getattr(jnp, dtype, None) or jnp.dtype(dtype) - return jnp.asarray(a, dtype=dtype) - def repeat(self, a, repeats, axis=None): return jnp.repeat(a, repeats, axis) @@ -1886,7 +1867,9 @@ def randperm(self, size, type_as=None): if not isinstance(size, int): raise ValueError("size must be an integer") if type_as is not None: - return jax.random.permutation(subkey, size).astype(type_as.dtype) + return jnp.asarray( + jax.random.permutation(subkey, size), dtype=type_as.dtype + ) else: return jax.random.permutation(subkey, size) @@ -2322,29 +2305,6 @@ def conj(self, a): def arccos(self, a): return torch.acos(a) - def astype(self, a, dtype): - if isinstance(dtype, str): - # Map common numpy-style string dtypes to torch dtypes explicitly. - # This makes backend.astype robust across torch versions and aliases. 
- mapping = { - "float32": torch.float32, - "float64": torch.float64, - "float": torch.float32, - "double": torch.float64, - "complex64": getattr(torch, "complex64", None), - "complex128": getattr(torch, "complex128", None), - } - torch_dtype = mapping.get(dtype) - if torch_dtype is None: - # Fallback: try direct attribute lookup (e.g. torch.float16) - torch_dtype = getattr(torch, dtype, None) - if torch_dtype is None: - raise ValueError( - f"Unsupported dtype for TorchBackend.astype: {dtype!r}" - ) - dtype = torch_dtype - return a.to(dtype=dtype) - def repeat(self, a, repeats, axis=None): return torch.repeat_interleave(a, repeats, dim=axis) @@ -2858,11 +2818,6 @@ def conj(self, a): def arccos(self, a): return cp.arccos(a) - def astype(self, a, dtype): - if isinstance(dtype, str): - dtype = getattr(cp, dtype, None) or cp.dtype(dtype) - return cp.asarray(a, dtype=dtype) - def repeat(self, a, repeats, axis=None): return cp.repeat(a, repeats, axis) @@ -2954,7 +2909,7 @@ def randperm(self, size, type_as=None): return self.rng_.permutation(size) else: with cp.cuda.Device(type_as.device): - return self.rng_.permutation(size).astype(type_as.dtype) + return cp.asarray(self.rng_.permutation(size), dtype=type_as.dtype) def coo_matrix(self, data, rows, cols, shape=None, type_as=None): data = self.from_numpy(data) @@ -3309,11 +3264,6 @@ def conj(self, a): def arccos(self, a): return tnp.arccos(a) - def astype(self, a, dtype): - if isinstance(dtype, str): - dtype = getattr(tnp, dtype, None) or tnp.dtype(dtype) - return tnp.array(a, dtype=dtype) - def repeat(self, a, repeats, axis=None): return tnp.repeat(a, repeats, axis) diff --git a/ot/sgot.py b/ot/sgot.py index 08620b0d2..d443c3490 100644 --- a/ot/sgot.py +++ b/ot/sgot.py @@ -25,9 +25,7 @@ ##################################################################################################################################### -def eigenvalue_cost_matrix( - Ds, Dt, q=1, real_scale: float = 1.0, imag_scale: float = 1.0, nx=None -): 
+def eigenvalue_cost_matrix(Ds, Dt, q=1, eigen_scaling=None, nx=None): """Compute pairwise eigenvalue distances for source and target domains. Parameters @@ -36,10 +34,10 @@ def eigenvalue_cost_matrix( Source eigenvalues. Dt: array-like, shape (n_t,) Target eigenvalues. - real_scale: float, optional - Scale factor for real parts, default 1.0. - imag_scale: float, optional - Scale factor for imaginary parts, default 1.0. + eigen_scaling: None or array-like of length 2, optional + Scaling (real_scale, imag_scale) applied to eigenvalues before computing + distances. If None, defaults to (1.0, 1.0). Accepts tuple/list or + array/tensor with two entries. Returns ---------- @@ -49,6 +47,14 @@ def eigenvalue_cost_matrix( if nx is None: nx = get_backend(Ds, Dt) + if eigen_scaling is None: + real_scale, imag_scale = 1.0, 1.0 + else: + if isinstance(eigen_scaling, (tuple, list)): + real_scale, imag_scale = eigen_scaling + else: + real_scale, imag_scale = eigen_scaling[0], eigen_scaling[1] + Dsn = nx.real(Ds) * real_scale + 1j * nx.imag(Ds) * imag_scale Dtn = nx.real(Dt) * real_scale + 1j * nx.imag(Dt) * imag_scale prod = Dsn[:, None] - Dtn[None, :] @@ -56,42 +62,6 @@ def eigenvalue_cost_matrix( return prod ** (q / 2) -def ot_plan(C, Ws=None, Wt=None, nx=None): - """Compute the optimal transport plan for a given cost matrix and marginals. - - Parameters - ---------- - C: array-like, shape (n, m) - Cost matrix. - Ws: array-like, shape (n,), optional - Source distribution. If None, uses a uniform distribution. - Wt: array-like, shape (m,), optional - Target distribution. If None, uses a uniform distribution. - - Returns - ---------- - P: np.ndarray, shape (n, m) - Optimal transport plan. 
- """ - if nx is None: - nx = get_backend(C) - - n, m = C.shape - - if Ws is None: - Ws = nx.ones((n,), type_as=C) / float(n) - - if Wt is None: - Wt = nx.ones((m,), type_as=C) / float(m) - - Ws = Ws / nx.sum(Ws) - Wt = Wt / nx.sum(Wt) - - C_real = nx.real(C) - - return ot.emd(Ws, Wt, C_real) - - def _normalize_columns(A, nx, eps=1e-12): """Normalize the columns of an array with a backend-aware norm. @@ -185,7 +155,11 @@ def _grassmann_distance_squared(delta, grassman_metric="chordal", nx=None, eps=1 if grassman_metric == "procrustes": return 2.0 * (1.0 - delta) if grassman_metric == "martin": - return -nx.log(nx.clip(delta**2, eps, 1e300)) + # Martin-type Grassmann metric: -log(delta^2) with lower clamp at eps. + # We deliberately avoid any upper threshold to stay close to the + # information-geometric interpretation in Germain et al. (2025). + delta2 = nx.maximum(delta**2, eps) + return -nx.log(delta2) raise ValueError(f"Unknown grassman_metric: {grassman_metric}") @@ -194,7 +168,7 @@ def _grassmann_distance_squared(delta, grassman_metric="chordal", nx=None, eps=1 ### SPECTRAL-GRASSMANNIAN WASSERSTEIN METRIC ### ##################################################################################################################################### ##################################################################################################################################### -def cost( +def sgot_cost_matrix( Ds, Rs, Ls, @@ -205,11 +179,50 @@ def cost( p=2, q=1, grassman_metric="chordal", - real_scale=1.0, - imag_scale=1.0, + eigen_scaling=None, nx=None, ): - """Compute the SGOT cost matrix between two spectral decompositions. + r"""Compute the SGOT cost matrix between two spectral decompositions. + + This returns the discrete ground cost matrix used in the SGOT optimal transport + objective. 
Each spectral atom is :math:`z_i=(\lambda_i, V_i)` where + :math:`\lambda_i \in \mathbb{C}` is an eigenvalue and :math:`V_i` is the + associated (bi-orthogonal) eigenspace point. + + .. math:: + C_2(i,j) \;=\; \eta\,C_\lambda(i,j) \;+\; (1-\eta)\,C_G(i,j), + + with spectral term + + .. math:: + C_\lambda(i,j) \;=\; \big|\lambda_i - \lambda'_j\big|^{q}, + + and Grassmann term computed from a similarity score :math:`\delta_{ij}\in[0,1]` + built from left/right eigenvectors + + .. math:: + \delta_{ij} \;=\; \left|\langle r_i, r'_j\rangle\,\langle \ell_i, \ell'_j\rangle\right|. + + Depending on ``grassman_metric``, the Grassmann contribution is: + + - ``"chordal"``: + .. math:: + C_G(i,j) \;=\; 1 - \delta_{ij}^2 + - ``"geodesic"``: + .. math:: + C_G(i,j) \;=\; \arccos(\delta_{ij})^2 + - ``"procrustes"``: + .. math:: + C_G(i,j) \;=\; 2(1-\delta_{ij}) + - ``"martin"``: + .. math:: + C_G(i,j) \;=\; -\log\!\left(\max(\delta_{ij}^2,\varepsilon)\right) + + Finally, we return a matrix suited for a :math:`p`-Wasserstein objective by + treating :math:`C_2 \approx d^2` and outputting + + .. math:: + C(i,j) \;=\; \big(\operatorname{Re}(C_2(i,j))\big)^{p/2}. Parameters ---------- @@ -228,69 +241,75 @@ def cost( eta: float, optional Weighting between spectral and Grassmann terms, default 0.5. p: int, optional - Power for the OT cost, default 2. + Exponent defining the OT ground cost. The returned cost is :math:`d^p` with + :math:`d^2 \approx C_2`. Default is 2. + q: int, optional + Exponent applied to the eigenvalue distance in the spectral term. + Default is 1. grassman_metric: str, optional Metric type: "geodesic", "chordal", "procrustes", or "martin". - real_scale: float, optional - Scale factor for real parts, default 1.0. - imag_scale: float, optional - Scale factor for imaginary parts, default 1.0. + eigen_scaling: None or array-like of length 2, optional + Scaling ``(real_scale, imag_scale)`` applied to eigenvalues before computing + :math:`C_\lambda`. 
If provided, eigenvalues are transformed as + :math:`\lambda \mapsto \alpha\operatorname{Re}(\lambda) + i\,\beta\operatorname{Im}(\lambda)`. + If None, defaults to ``(1.0, 1.0)``. Accepts tuple/list or array/tensor with + two entries. + nx: module, optional + Backend (NumPy-compatible). If None, inferred from inputs. Returns ---------- C: array-like, shape (n_s, n_t) - SGOT cost matrix. + SGOT cost matrix :math:`C = d^p`. + + References + ---------- + Germain et al., *Spectral-Grassmann Optimal Transport* (SGOT). """ if nx is None: nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) - if Ds.ndim != 1: - raise ValueError(f"cost() expects Ds to be 1D (n,), got shape {Ds.shape}") - lam1 = Ds - - if Dt.ndim != 1: - raise ValueError(f"cost() expects Dt to be 1D (n,), got shape {Dt.shape}") - lam2 = Dt - - lam1 = nx.astype(lam1, "complex128") - lam2 = nx.astype(lam2, "complex128") - - if Rs.shape != Ls.shape: + if Ds.ndim != 1 or Dt.ndim != 1: raise ValueError( - f"Rs and Ls must have the same shape, got {Rs.shape} and {Ls.shape}" + f"sgot_cost_matrix() expects Ds, Dt 1D; got Ds {getattr(Ds,'shape',None)}, Dt {getattr(Dt,'shape',None)}" ) - if Rt.shape != Lt.shape: + if Rs.shape != Ls.shape or Rt.shape != Lt.shape: raise ValueError( - f"Rt and Lt must have the same shape, got {Rt.shape} and {Lt.shape}" + f"Right/left eigenvector shapes must match; got (Rs,Ls)=({Rs.shape},{Ls.shape}), (Rt,Lt)=({Rt.shape},{Lt.shape})" ) - if Rs.shape[1] != lam1.shape[0]: + if Rs.shape[1] != Ds.shape[0] or Rt.shape[1] != Dt.shape[0]: raise ValueError( - f"Number of source eigenvectors ({Rs.shape[1]}) must match " - f"number of source eigenvalues ({lam1.shape[0]})" + f"Eigenvectors columns must match eigenvalues: Rs {Rs.shape[1]} vs Ds {Ds.shape[0]}, " + f"Rt {Rt.shape[1]} vs Dt {Dt.shape[0]}" ) - if Rt.shape[1] != lam2.shape[0]: - raise ValueError( - f"Number of target eigenvectors ({Rt.shape[1]}) must match " - f"number of target eigenvalues ({lam2.shape[0]})" - ) - - C_lambda = 
eigenvalue_cost_matrix( - lam1, lam2, q=q, real_scale=real_scale, imag_scale=imag_scale, nx=nx - ) + C_lambda = eigenvalue_cost_matrix(Ds, Dt, q=q, eigen_scaling=eigen_scaling, nx=nx) delta = _delta_matrix_1d(Rs, Ls, Rt, Lt, nx=nx) C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) C2 = eta * C_lambda + (1.0 - eta) * C_grass - C = C2 ** (p / 2.0) + C = nx.real(C2) ** (p / 2.0) return C -def metric( +def _validate_sgot_metric_inputs(Ds, Dt): + """Validate that eigenvalue inputs for SGOT metric are 1D.""" + Ds_shape = getattr(Ds, "shape", None) + Dt_shape = getattr(Dt, "shape", None) + Ds_ndim = getattr(Ds, "ndim", None) + Dt_ndim = getattr(Dt, "ndim", None) + if Ds_ndim != 1 or Dt_ndim != 1: + raise ValueError( + "sgot_metric() expects Ds and Dt to be 1D (n,), " + f"got Ds shape {Ds_shape} and Dt shape {Dt_shape}" + ) + + +def sgot_metric( Ds, Rs, Ls, @@ -302,13 +321,30 @@ def metric( q=1, r=2, grassman_metric="chordal", - real_scale=1.0, - imag_scale=1.0, + eigen_scaling=None, Ws=None, Wt=None, nx=None, ): - """Compute the SGOT metric between two spectral decompositions. + r"""Compute the SGOT metric between two spectral decompositions. + + This function computes a discrete optimal transport problem between two measures + over spectral atoms :math:`z_i=(\lambda_i, V_i)` and :math:`z'_j=(\lambda'_j, V'_j)`. + Using the ground cost matrix :math:`C=d^p` returned by :func:`sgot_cost_matrix`, + we solve: + + .. math:: + P^\star \in \arg\min_{P\in\Pi(W_s, W_t)} \langle C, P\rangle, + + and compute the associated :math:`p`-Wasserstein objective: + + .. math:: + \mathrm{obj} \;=\; \left(\sum_{i,j} C(i,j)\,P^\star_{ij}\right)^{1/p}. + + This implementation returns an additional outer root: + + .. math:: + \mathrm{SGOT} \;=\; \mathrm{obj}^{1/r}. Parameters ---------- @@ -327,38 +363,41 @@ def metric( eta: float, optional Weighting between spectral and Grassmann terms, default 0.5. 
p: int, optional - Exponent defining the OT ground cost and Wasserstein order. The cost matrix is raised to the power p/2 and the OT objective - is raised to the power 1/p. Default is 2. + Exponent defining the OT ground cost and Wasserstein order. The cost matrix + is :math:`d^p` and the OT objective is raised to the power :math:`1/p`. + Default is 2. q: int, optional - Exponent applied to the eigenvalue distance in the spectral term. Controls the geometry of the eigenvalue cost matrix. + Exponent applied to the eigenvalue distance in the spectral term. Default is 1. r: int, optional Outer root applied to the Wasserstein objective. Default is 2. grassman_metric: str, optional Metric type: "geodesic", "chordal", "procrustes", or "martin". - real_scale: float, optional - Scale factor for real parts, default 1.0. - imag_scale: float, optional - Scale factor for imaginary parts, default 1.0. + eigen_scaling: None or array-like of length 2, optional + Scaling ``(real_scale, imag_scale)`` applied to eigenvalues before computing + the spectral part of the cost. If None, defaults to ``(1.0, 1.0)``. Ws: array-like, shape (n_s,), optional Source distribution. If None, uses a uniform distribution. Wt: array-like, shape (n_t,), optional Target distribution. If None, uses a uniform distribution. + nx: module, optional + Backend (NumPy-compatible). If None, inferred from inputs. Returns ---------- dist: float SGOT metric value. + + References + ---------- + Germain et al., *Spectral-Grassmann Optimal Transport* (SGOT). 
""" if nx is None: nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) - if Ds.ndim != 1: - raise ValueError(f"metric() expects Ds to be 1D (n,), got shape {Ds.shape}") - if Dt.ndim != 1: - raise ValueError(f"metric() expects Dt to be 1D (n,), got shape {Dt.shape}") + _validate_sgot_metric_inputs(Ds, Dt) - C = cost( + C = sgot_cost_matrix( Ds, Rs, Ls, @@ -369,11 +408,18 @@ def metric( p=p, q=q, grassman_metric=grassman_metric, - real_scale=real_scale, - imag_scale=imag_scale, + eigen_scaling=eigen_scaling, nx=nx, ) - P = ot_plan(C, Ws=Ws, Wt=Wt, nx=nx) + if Ws is None: + Ws = nx.ones((C.shape[0],), type_as=C) / float(C.shape[0]) + if Wt is None: + Wt = nx.ones((C.shape[1],), type_as=C) / float(C.shape[1]) + + Ws = Ws / nx.sum(Ws) + Wt = Wt / nx.sum(Wt) + + P = ot.emd2(Ws, Wt, nx.real(C)) obj = float(nx.sum(C * P) ** (1.0 / p)) - return float(obj) ** (1.0 / float(r)) + return obj ** (1.0 / float(r)) diff --git a/test/test_backend.py b/test/test_backend.py index e25297bfa..fe6af9c67 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -730,22 +730,19 @@ def test_func_backends(nx): A = nx.real(M_complex_b) lst_b.append(nx.to_numpy(A)) lst_name.append("real") + A = nx.imag(M_complex_b) lst_b.append(nx.to_numpy(A)) lst_name.append("imag") + A = nx.conj(M_complex_b) lst_b.append(nx.to_numpy(A)) lst_name.append("conj") + v_acos_b = nx.from_numpy(v_acos) A = nx.arccos(v_acos_b) lst_b.append(nx.to_numpy(A)) lst_name.append("arccos") - A = nx.astype(Mb, "float64") - lst_b.append(nx.to_numpy(A)) - lst_name.append("astype float64") - A = nx.astype(vb, "complex128") - lst_b.append(nx.to_numpy(A)) - lst_name.append("astype complex128") A = nx.transpose(Mb) lst_b.append(nx.to_numpy(A)) diff --git a/test/test_sgot.py b/test/test_sgot.py index c1330696b..4ecfd2e40 100644 --- a/test/test_sgot.py +++ b/test/test_sgot.py @@ -7,19 +7,6 @@ import numpy as np import pytest -from ot.backend import get_backend - -try: - import torch -except ImportError: - torch = None - -try: - import jax 
- import jax.numpy as jnp -except ImportError: - jax = None - from ot.sgot import ( eigenvalue_cost_matrix, _delta_matrix_1d, @@ -28,22 +15,23 @@ metric, ) -rng = np.random.RandomState(0) +def random_atoms(d=8, r=4, seed=42): + """Deterministic complex atoms for given d, r.""" -def rand_complex(shape): - real = rng.randn(*shape) - imag = rng.randn(*shape) - return real + 1j * imag + def _rand_complex(shape, seed_): + rng = np.random.RandomState(seed_) + real = rng.randn(*shape) + imag = rng.randn(*shape) + return real + 1j * imag + Ds = _rand_complex((r,), seed + 0) + Rs = _rand_complex((d, r), seed + 1) + Ls = _rand_complex((d, r), seed + 2) + Dt = _rand_complex((r,), seed + 3) + Rt = _rand_complex((d, r), seed + 4) + Lt = _rand_complex((d, r), seed + 5) -def random_atoms(d=8, r=4): - Ds = rand_complex((r,)) - Rs = rand_complex((d, r)) - Ls = rand_complex((d, r)) - Dt = rand_complex((r,)) - Rt = rand_complex((d, r)) - Lt = rand_complex((d, r)) return Ds, Rs, Ls, Dt, Rt, Lt @@ -52,101 +40,28 @@ def random_atoms(d=8, r=4): # --------------------------------------------------------------------- -def test_atoms_are_complex(): - """Confirm sampled atoms are complex (Gaussian real + 1j*imag).""" - Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() - for name, arr in [ - ("Ds", Ds), - ("Rs", Rs), - ("Ls", Ls), - ("Dt", Dt), - ("Rt", Rt), - ("Lt", Lt), - ]: - assert np.iscomplexobj(arr), f"{name} should be complex" - assert np.any(np.imag(arr) != 0), f"{name} should have non-zero imaginary part" - - -def test_random_d_r(): +def test_random_d_r(nx): """Sample d and r uniformly and run cost (and metric when available) with those shapes.""" + rng = np.random.RandomState(0) d_min, d_max = 4, 12 r_min, r_max = 2, 6 for _ in range(5): d = int(rng.randint(d_min, d_max + 1)) r = int(rng.randint(r_min, r_max + 1)) Ds, Rs, Ls, Dt, Rt, Lt = random_atoms(d=d, r=r) - C = cost(Ds, Rs, Ls, Dt, Rt, Lt) - np.testing.assert_allclose(C.shape, (r, r)) - assert np.all(np.isfinite(C)) and np.all(C >= 0) 
+ Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) + C = cost(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b) + C_np = nx.to_numpy(C) + np.testing.assert_allclose(C_np.shape, (r, r)) + assert np.all(np.isfinite(C_np)) and np.all(C_np >= 0) try: - dist = metric(Ds, Rs, Ls, Dt, Rt, Lt) - assert np.isfinite(dist) and dist >= 0 + dist = metric(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b) + dist_np = nx.to_numpy(dist) + assert np.isfinite(dist_np) and dist_np >= 0 except TypeError: pytest.skip("metric() unavailable (emd_c signature mismatch)") -# --------------------------------------------------------------------- -# BACKEND CONSISTENCY TESTS -# --------------------------------------------------------------------- - - -def test_backend_return(): - """Confirm get_backend returns the correct backend for numpy/torch/jax arrays.""" - Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() - nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) - assert nx is not None - assert nx.__name__ == "numpy" - - if torch is not None: - Ds_t = torch.from_numpy(Ds) - nx_t = get_backend(Ds_t) - assert nx_t is not None - assert nx_t.__name__ == "torch" - - if jax is not None: - Ds_j = jnp.array(Ds) - nx_j = get_backend(Ds_j) - assert nx_j is not None - assert nx_j.__name__ == "jax" - - -@pytest.mark.parametrize("backend_name", ["numpy", "torch", "jax"]) -def test_cost_backend_consistency(backend_name): - if backend_name == "torch" and torch is None: - pytest.skip("Torch not available") - if backend_name == "jax" and jax is None: - pytest.skip("JAX not available") - - Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() - - C_np = cost(Ds, Rs, Ls, Dt, Rt, Lt) - - if backend_name == "numpy": - C_backend = C_np - - elif backend_name == "torch": - Ds_b = torch.from_numpy(Ds) - Rs_b = torch.from_numpy(Rs) - Ls_b = torch.from_numpy(Ls) - Dt_b = torch.from_numpy(Dt) - Rt_b = torch.from_numpy(Rt) - Lt_b = torch.from_numpy(Lt) - C_backend = cost(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b) - C_backend = C_backend.detach().cpu().numpy() - - 
elif backend_name == "jax": - Ds_b = jnp.array(Ds) - Rs_b = jnp.array(Rs) - Ls_b = jnp.array(Ls) - Dt_b = jnp.array(Dt) - Rt_b = jnp.array(Rt) - Lt_b = jnp.array(Lt) - C_backend = cost(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b) - C_backend = np.array(C_backend) - - np.testing.assert_allclose(C_backend, C_np, atol=1e-6) - - # --------------------------------------------------------------------- # DELTA MATRIX TESTS # --------------------------------------------------------------------- @@ -161,7 +76,7 @@ def test_delta_identity(): def test_delta_swap_invariance(): d, r = 6, 3 - R = rand_complex((d, r)) + _, R, _, _, _, _ = random_atoms(d=d, r=r) L = R.copy() delta1 = _delta_matrix_1d(R, L, R, L) delta2 = _delta_matrix_1d(L, R, L, R) @@ -173,11 +88,14 @@ def test_delta_swap_invariance(): # --------------------------------------------------------------------- -@pytest.mark.parametrize("metric_name", ["geodesic", "chordal", "procrustes", "martin"]) -def test_grassmann_zero_distance(metric_name): - delta = np.ones((3, 3)) - dist2 = _grassmann_distance_squared(delta, grassman_metric=metric_name) - np.testing.assert_allclose(dist2, 0.0, atol=1e-12) +@pytest.mark.parametrize( + "grassman_metric", ["geodesic", "chordal", "procrustes", "martin"] +) +def test_grassmann_zero_distance(grassman_metric, nx): + delta = nx.from_numpy(np.ones((3, 3))) + dist2 = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) + dist2_np = nx.to_numpy(dist2) + np.testing.assert_allclose(dist2_np, 0.0, atol=1e-12) def test_grassmann_invalid_name(): @@ -257,13 +175,11 @@ def test_metric_with_weights(): Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() r = Ds.shape[0] - logits_s = rng.randn(r) - logits_t = rng.randn(r) - - Ws = np.exp(logits_s) + rng = np.random.RandomState(1) + Ws = rng.rand(r) Ws = Ws / np.sum(Ws) - Wt = np.exp(logits_t) + Wt = rng.rand(r) Wt = Wt / np.sum(Wt) dist = metric(Ds, Rs, Ls, Dt, Rt, Lt, Ws=Ws, Wt=Wt) @@ -276,58 +192,57 @@ def test_metric_with_weights(): def 
test_hyperparameter_sweep_cost(nx): - """Create test_cost for each trial: sweep over HPs and run cost().""" + """Sweep over a random set of HPs and run cost().""" grassmann_types = ["geodesic", "chordal", "procrustes", "martin"] - n_trials = 10 - for _ in range(n_trials): - Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() - Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) - eta = rng.uniform(0.0, 1.0) - p = rng.choice([1, 2]) - q = rng.choice([1, 2]) - gm = rng.choice(grassmann_types) - C = cost( - Ds_b, - Rs_b, - Ls_b, - Dt_b, - Rt_b, - Lt_b, - eta=eta, - p=p, - q=q, - grassman_metric=gm, - ) - C_np = nx.to_numpy(C) - assert C_np.shape == (Ds.shape[0], Dt.shape[0]) - assert np.all(np.isfinite(C_np)) - assert np.all(C_np >= 0) + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) + rng = np.random.RandomState(2) + eta = rng.uniform(0.0, 1.0) + p = rng.choice([1, 2]) + q = rng.choice([1, 2]) + gm = rng.choice(grassmann_types) + C = cost( + Ds_b, + Rs_b, + Ls_b, + Dt_b, + Rt_b, + Lt_b, + eta=eta, + p=p, + q=q, + grassman_metric=gm, + ) + C_np = nx.to_numpy(C) + assert C_np.shape == (Ds.shape[0], Dt.shape[0]) + assert np.all(np.isfinite(C_np)) + assert np.all(C_np >= 0) -def test_hyperparameter_sweep(): - grassmann_types = ["geodesic", "chordal", "procrustes", "martin"] +@pytest.mark.parametrize( + "grassman_metric", ["geodesic", "chordal", "procrustes", "martin"] +) +def test_hyperparameter_sweep(grassman_metric): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + rng = np.random.RandomState(3) + eta = rng.uniform(0.0, 1.0) + p = rng.choice([1, 2]) + q = rng.choice([1, 2]) + r = rng.choice([1, 2]) + + dist = metric( + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, + eta=eta, + p=p, + q=q, + r=r, + grassman_metric=grassman_metric, + ) - for _ in range(10): - Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() - eta = rng.uniform(0.0, 1.0) - p = rng.choice([1, 2]) - q = rng.choice([1, 2]) - r = rng.choice([1, 2]) - gm = 
rng.choice(grassmann_types) - - dist = metric( - Ds, - Rs, - Ls, - Dt, - Rt, - Lt, - eta=eta, - p=p, - q=q, - r=r, - grassman_metric=gm, - ) - - assert np.isfinite(dist) - assert dist >= 0 + assert np.isfinite(dist) + assert dist >= 0 From a0e74beca4ae9979334742063534c3dfee1dd14d Mon Sep 17 00:00:00 2001 From: Sienna O'Shea Date: Fri, 27 Feb 2026 14:03:31 +0100 Subject: [PATCH 08/10] cost & metric fixed in test_sgot --- test/test_sgot.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/test/test_sgot.py b/test/test_sgot.py index 4ecfd2e40..10c9cb70e 100644 --- a/test/test_sgot.py +++ b/test/test_sgot.py @@ -11,8 +11,8 @@ eigenvalue_cost_matrix, _delta_matrix_1d, _grassmann_distance_squared, - cost, - metric, + sgot_cost_matrix, + sgot_metric, ) @@ -50,16 +50,16 @@ def test_random_d_r(nx): r = int(rng.randint(r_min, r_max + 1)) Ds, Rs, Ls, Dt, Rt, Lt = random_atoms(d=d, r=r) Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) - C = cost(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b) + C = sgot_cost_matrix(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b) C_np = nx.to_numpy(C) np.testing.assert_allclose(C_np.shape, (r, r)) assert np.all(np.isfinite(C_np)) and np.all(C_np >= 0) try: - dist = metric(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b) + dist = sgot_metric(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b) dist_np = nx.to_numpy(dist) assert np.isfinite(dist_np) and dist_np >= 0 except TypeError: - pytest.skip("metric() unavailable (emd_c signature mismatch)") + pytest.skip("sgot_metric() unavailable (emd_c signature mismatch)") # --------------------------------------------------------------------- @@ -113,7 +113,7 @@ def test_cost_self_zero(nx): """(D_S R_S L_S D_S): diagonal of cost matrix (same atom to same atom) should be near zero.""" Ds, Rs, Ls, _, _, _ = random_atoms() Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2 = nx.from_numpy(Ds, Rs, Ls, Ds, Rs, Ls) - C = cost(Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2) + C = sgot_cost_matrix(Ds_b, 
Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2) C_np = nx.to_numpy(C) np.testing.assert_allclose(np.diag(C_np), np.zeros(C_np.shape[0]), atol=1e-10) np.testing.assert_allclose(C_np, C_np.T, atol=1e-10) @@ -124,8 +124,8 @@ def test_cost_reference(nx): Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) eta, p, q = 0.5, 2, 1 - C1 = cost(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b, eta=eta, p=p, q=q) - C2 = cost(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b, eta=eta, p=p, q=q) + C1 = sgot_cost_matrix(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b, eta=eta, p=p, q=q) + C2 = sgot_cost_matrix(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b, eta=eta, p=p, q=q) np.testing.assert_allclose(nx.to_numpy(C1), nx.to_numpy(C2), atol=1e-12) @@ -135,7 +135,9 @@ def test_cost_reference(nx): def test_cost_basic(grassman_metric, nx): Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) - C = cost(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b, grassman_metric=grassman_metric) + C = sgot_cost_matrix( + Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b, grassman_metric=grassman_metric + ) C_np = nx.to_numpy(C) assert C_np.shape == (Ds.shape[0], Dt.shape[0]) assert np.all(np.isfinite(C_np)) @@ -146,10 +148,10 @@ def test_cost_validation(): Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() with pytest.raises(ValueError): - cost(Ds.reshape(-1, 1), Rs, Ls, Dt, Rt, Lt) + sgot_cost_matrix(Ds.reshape(-1, 1), Rs, Ls, Dt, Rt, Lt) with pytest.raises(ValueError): - cost(Ds, Rs[:, :-1], Ls, Dt, Rt, Lt) + sgot_cost_matrix(Ds, Rs[:, :-1], Ls, Dt, Rt, Lt) # --------------------------------------------------------------------- @@ -159,15 +161,15 @@ def test_cost_validation(): def test_metric_self_zero(): Ds, Rs, Ls, _, _, _ = random_atoms() - dist = metric(Ds, Rs, Ls, Ds, Rs, Ls) + dist = sgot_metric(Ds, Rs, Ls, Ds, Rs, Ls) assert np.isfinite(dist) - assert abs(dist) < 2e-4 + assert abs(dist) < 5e-4 def test_metric_symmetry(): Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() - d1 = 
metric(Ds, Rs, Ls, Dt, Rt, Lt) - d2 = metric(Dt, Rt, Lt, Ds, Rs, Ls) + d1 = sgot_metric(Ds, Rs, Ls, Dt, Rt, Lt) + d2 = sgot_metric(Dt, Rt, Lt, Ds, Rs, Ls) np.testing.assert_allclose(d1, d2, atol=1e-8) @@ -182,7 +184,7 @@ def test_metric_with_weights(): Wt = rng.rand(r) Wt = Wt / np.sum(Wt) - dist = metric(Ds, Rs, Ls, Dt, Rt, Lt, Ws=Ws, Wt=Wt) + dist = sgot_metric(Ds, Rs, Ls, Dt, Rt, Lt, Ws=Ws, Wt=Wt) assert np.isfinite(dist) @@ -201,7 +203,7 @@ def test_hyperparameter_sweep_cost(nx): p = rng.choice([1, 2]) q = rng.choice([1, 2]) gm = rng.choice(grassmann_types) - C = cost( + C = sgot_cost_matrix( Ds_b, Rs_b, Ls_b, @@ -230,7 +232,7 @@ def test_hyperparameter_sweep(grassman_metric): q = rng.choice([1, 2]) r = rng.choice([1, 2]) - dist = metric( + dist = sgot_metric( Ds, Rs, Ls, From e994303f7f0a08e4b6b04cd72d5b51d453512056 Mon Sep 17 00:00:00 2001 From: Sienna O'Shea Date: Fri, 27 Feb 2026 14:21:09 +0100 Subject: [PATCH 09/10] correct issues on test_sgot --- test/test_sgot.py | 52 +++++++++++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/test/test_sgot.py b/test/test_sgot.py index 10c9cb70e..267d43a1c 100644 --- a/test/test_sgot.py +++ b/test/test_sgot.py @@ -41,7 +41,7 @@ def _rand_complex(shape, seed_): def test_random_d_r(nx): - """Sample d and r uniformly and run cost (and metric when available) with those shapes.""" + """Sample d and r uniformly and run sgot_cost_matrix (and sgot_metric when available) with those shapes.""" rng = np.random.RandomState(0) d_min, d_max = 4, 12 r_min, r_max = 2, 6 @@ -67,14 +67,22 @@ def test_random_d_r(nx): # --------------------------------------------------------------------- -def test_delta_identity(): +def test_eigenvalue_cost_matrix_simple(): + Ds = np.array([0.0, 1.0]) + Dt = np.array([0.0, 2.0]) + C = eigenvalue_cost_matrix(Ds, Dt, q=2) + expected = np.array([[0.0, 4.0], [1.0, 1.0]]) + np.testing.assert_allclose(C, expected) + + +def test_delta_matrix_1d_identity(): r = 4 
I = np.eye(r, dtype=complex)
     delta = _delta_matrix_1d(I, I, I, I)
     np.testing.assert_allclose(delta, np.eye(r), atol=1e-12)
 
 
-def test_delta_swap_invariance():
+def test_delta_matrix_1d_swap_invariance():
     d, r = 6, 3
     _, R, _, _, _, _ = random_atoms(d=d, r=r)
     L = R.copy()
@@ -98,7 +106,7 @@ def test_grassmann_zero_distance(grassman_metric, nx):
     np.testing.assert_allclose(dist2_np, 0.0, atol=1e-12)
 
 
-def test_grassmann_invalid_name():
+def test_grassmann_distance_invalid_name():
     delta = np.ones((2, 2))
     with pytest.raises(ValueError):
         _grassmann_distance_squared(delta, grassman_metric="cordal")
@@ -110,7 +118,7 @@
 
 
 def test_cost_self_zero(nx):
-    """(D_S R_S L_S D_S): diagonal of cost matrix (same atom to same atom) should be near zero."""
+    """(D_S R_S L_S D_S): diagonal of the SGOT cost matrix (same atom to same atom) should be near zero."""
     Ds, Rs, Ls, _, _, _ = random_atoms()
     Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2 = nx.from_numpy(Ds, Rs, Ls, Ds, Rs, Ls)
     C = sgot_cost_matrix(Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2)
@@ -119,7 +127,7 @@
     np.testing.assert_allclose(C_np, C_np.T, atol=1e-10)
 
 
-def test_cost_reference(nx):
+def test_grassmann_cost_reference(nx):
     """Cost with same inputs and HPs should be deterministic (np.testing.assert_allclose)."""
     Ds, Rs, Ls, Dt, Rt, Lt = random_atoms()
     Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt)
@@ -132,7 +140,7 @@
 @pytest.mark.parametrize(
     "grassman_metric", ["geodesic", "chordal", "procrustes", "martin"]
 )
-def test_cost_basic(grassman_metric, nx):
+def test_grassmann_cost_basic_properties(grassman_metric, nx):
     Ds, Rs, Ls, Dt, Rt, Lt = random_atoms()
     Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt)
     C = sgot_cost_matrix(
@@ -144,7 +152,7 @@
     assert np.all(C_np >= 0)
 
 
-def test_cost_validation():
+def test_sgot_cost_input_validation():
    Ds, Rs, Ls, Dt, 
Rt, Lt = random_atoms() with pytest.raises(ValueError): @@ -159,21 +167,21 @@ def test_cost_validation(): # --------------------------------------------------------------------- -def test_metric_self_zero(): +def test_sgot_metric_self_zero(): Ds, Rs, Ls, _, _, _ = random_atoms() dist = sgot_metric(Ds, Rs, Ls, Ds, Rs, Ls) assert np.isfinite(dist) assert abs(dist) < 5e-4 -def test_metric_symmetry(): +def test_sgot_metric_symmetry(): Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() d1 = sgot_metric(Ds, Rs, Ls, Dt, Rt, Lt) d2 = sgot_metric(Dt, Rt, Lt, Ds, Rs, Ls) np.testing.assert_allclose(d1, d2, atol=1e-8) -def test_metric_with_weights(): +def test_sgot_metric_with_weights(): Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() r = Ds.shape[0] @@ -193,16 +201,20 @@ def test_metric_with_weights(): # --------------------------------------------------------------------- -def test_hyperparameter_sweep_cost(nx): - """Sweep over a random set of HPs and run cost().""" - grassmann_types = ["geodesic", "chordal", "procrustes", "martin"] +@pytest.mark.parametrize( + "eta, p, q, grassman_metric", + [ + (0.5, 1, 1, "geodesic"), + (0.5, 2, 1, "chordal"), + (0.3, 2, 2, "procrustes"), + (0.7, 1, 2, "martin"), + ], +) +def test_hyperparameter_sweep_cost(nx, eta, p, q, grassman_metric): + """Sweep over a set of fixed HPs and run cost().""" Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) - rng = np.random.RandomState(2) - eta = rng.uniform(0.0, 1.0) - p = rng.choice([1, 2]) - q = rng.choice([1, 2]) - gm = rng.choice(grassmann_types) + C = sgot_cost_matrix( Ds_b, Rs_b, @@ -213,7 +225,7 @@ def test_hyperparameter_sweep_cost(nx): eta=eta, p=p, q=q, - grassman_metric=gm, + grassman_metric=grassman_metric, ) C_np = nx.to_numpy(C) assert C_np.shape == (Ds.shape[0], Dt.shape[0]) From e9f8be7abcf06fb3ff1f367bb944387fe3012c7a Mon Sep 17 00:00:00 2001 From: Sienna O'Shea Date: Fri, 27 Feb 2026 14:50:29 +0100 Subject: [PATCH 10/10] fixing test 
failures --- ot/sgot.py | 13 ++++++++----- test/test_sgot.py | 10 ++++++---- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/ot/sgot.py b/ot/sgot.py index d443c3490..c8ee5b91c 100644 --- a/ot/sgot.py +++ b/ot/sgot.py @@ -271,17 +271,20 @@ def sgot_cost_matrix( if Ds.ndim != 1 or Dt.ndim != 1: raise ValueError( - f"sgot_cost_matrix() expects Ds, Dt 1D; got Ds {getattr(Ds,'shape',None)}, Dt {getattr(Dt,'shape',None)}" + f"sgot_cost_matrix() expects Ds, Dt 1D; " + f"got Ds {getattr(Ds, 'shape', None)}, Dt {getattr(Dt, 'shape', None)}" ) if Rs.shape != Ls.shape or Rt.shape != Lt.shape: raise ValueError( - f"Right/left eigenvector shapes must match; got (Rs,Ls)=({Rs.shape},{Ls.shape}), (Rt,Lt)=({Rt.shape},{Lt.shape})" + "Right/left eigenvector shapes must match; got " + f"(Rs,Ls)=({Rs.shape},{Ls.shape}), (Rt,Lt)=({Rt.shape},{Lt.shape})" ) if Rs.shape[1] != Ds.shape[0] or Rt.shape[1] != Dt.shape[0]: raise ValueError( - f"Eigenvectors columns must match eigenvalues: Rs {Rs.shape[1]} vs Ds {Ds.shape[0]}, " + "Eigenvector columns must match eigenvalues: " + f"Rs {Rs.shape[1]} vs Ds {Ds.shape[0]}, " f"Rt {Rt.shape[1]} vs Dt {Dt.shape[0]}" ) @@ -420,6 +423,6 @@ def sgot_metric( Ws = Ws / nx.sum(Ws) Wt = Wt / nx.sum(Wt) - P = ot.emd2(Ws, Wt, nx.real(C)) - obj = float(nx.sum(C * P) ** (1.0 / p)) + obj = ot.emd2(Ws, Wt, nx.real(C)) + obj = obj ** (1.0 / p) return obj ** (1.0 / float(r)) diff --git a/test/test_sgot.py b/test/test_sgot.py index 267d43a1c..af64e20fb 100644 --- a/test/test_sgot.py +++ b/test/test_sgot.py @@ -167,11 +167,13 @@ def test_sgot_cost_input_validation(): # --------------------------------------------------------------------- -def test_sgot_metric_self_zero(): +def test_sgot_metric_self_zero(nx): Ds, Rs, Ls, _, _, _ = random_atoms() - dist = sgot_metric(Ds, Rs, Ls, Ds, Rs, Ls) - assert np.isfinite(dist) - assert abs(dist) < 5e-4 + Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2 = nx.from_numpy(Ds, Rs, Ls, Ds, Rs, Ls) + dist = sgot_metric(Ds_b, Rs_b, 
Ls_b, Ds_b2, Rs_b2, Ls_b2, nx=nx) + dist_np = nx.to_numpy(dist) + assert np.isfinite(dist_np) + assert abs(float(dist_np)) < 5e-4 def test_sgot_metric_symmetry():