diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..f320c7e --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,76 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [3.0.0] - UNRELEASED + +### Breaking + +- **`sc-sam` removed as a dependency.** The SAM algorithm is now vendored + under `samap.sam`. All internal imports route through `samap.sam` — no + external SAM package is installed or required. If you were importing + `samalg` directly, switch to `samap.sam`. +- `_smart_expand` default switched from matrix-power to BFS. Produces + slightly different marginal neighbours (~1% edge difference on the + golden-suite data) — the matpow path wasted one budget slot per cell on + a self-loop artefact. Pass `legacy=True` for bit-exact 2.x reproduction. + +### Added + +- **GPU backend** via `SAMAP(backend="auto"|"cpu"|"cuda")`. Dispatches + numpy/scipy ↔ cupy/cupyx, hnswlib ↔ FAISS for kNN, and scanpy ↔ + rapids-singlecell for Leiden/UMAP. Install with `pip install sc-samap[gpu]` + (see `docs/performance.md` for conda details). `"auto"` picks CUDA if + available, else CPU. +- **N² → N-linear memory rewrites** (see `docs/performance.md` for the full + model): + - *Precomposed feature translation* — projection precomposes + `G · diag(W/σ) · PCs` so the cells × genes `Xtr` intermediate is never + materialised. Iteration-invariant state (`XᵀX`, means, own-species + projection) is computed once. ~2× wall and ~2× memory on the benchmark + suite; gains grow with N. + - *Streaming mutual-NN* — coarsening streams per-species-pair blocks + directly into a CSR builder instead of materialising dense N × N products. 
+ - *Batched correlation refinement* — streams gene-pair batches + (default `batch_size=512`); computes only the columns of the smoothed + expression matrix referenced per batch. Peak memory drops from + O(N × G_active) to O(N × 1024). ~4× less memory; ~3-5× slower on small + data where the full matrix fits — pass `batch_size=None` to opt out. + - *BFS neighbourhood expansion* — numba BFS kernel replaces matrix-power + `_smart_expand`. ~5× faster at 3k cells, memory-bounded. +- **Randomized SVD with implicit centering** for sparse PCA — available + via `svd_solver="randomized"` on `samap.sam.pca._pca_with_sparse`. Faster + on GPU and at high PC counts; slightly different numerics. Default remains + ARPACK. +- **Phase-level benchmark suite** — `benchmarks/bench_samap.py` compares + legacy vs optimized paths for each rewritten phase. +- `docs/performance.md` — memory model, backend selection, tuning, scaling + estimates. + +### Fixed + +- Dead random-walk computation in `_mapper` (result written then immediately + discarded; preserved only the binarization side effect). +- `thr` → `align_thr` kwarg misroute in `analysis.enrichment` (was falling + through to an unrelated p-value threshold). +- Deprecated `.A` matrix attribute → `np.asarray()` in several hot paths. +- Stale root `setup.py` removed (pyproject.toml is authoritative). +- Broken `SAMGUI` import and dead `gui()` method removed. +- Duplicated `_q` helper consolidated into `samap.utils.q`. +- Dead `mdata['xsim']` store removed. +- `__version__` is now dynamic via `importlib.metadata`. + +### Changed + +- `src/samap/core/mapping.py` split into focused modules: `homology.py`, + `correlation.py`, `projection.py`, `coarsening.py`, `expand.py`. The + `SAMAP` class remains in `mapping.py`; all existing imports work unchanged. +- `_refine_corr` / `_refine_corr_parallel` default `batch_size` changed + from `None` (materialized) to `512` (streaming). 
+- `_smart_expand` default `legacy` changed from `True` (matpow) to + `False` (BFS). +- Golden regression fixture regenerated to reflect the BFS and streaming + defaults. diff --git a/benchmarks/bench_samap.py b/benchmarks/bench_samap.py new file mode 100644 index 0000000..a001950 --- /dev/null +++ b/benchmarks/bench_samap.py @@ -0,0 +1,644 @@ +#!/usr/bin/env python +"""SAMap optimization benchmark — Phase 3 wins, measured. + +Compares legacy vs optimized code paths for the three SAMap iteration +phases that were rewritten in Phase 3: + +* **expand** — matrix-power `_smart_expand` vs BFS `_smart_expand` +* **projection** — per-iter `_mapping_window` (rebuilds precompute) vs + precompute-once + `_mapping_window_fast` +* **correlation** — materialized `_compute_pair_corrs(batch_size=None)` + vs streaming `batch_size=int` + +Each benchmark is phase-level (direct function call, no full SAM mock). +The toggles aren't plumbed through `_mapper` yet, so phase-level is the +cleanest way to attribute speedup to each optimization. + +Usage +----- +:: + + python benchmarks/bench_samap.py --max-cells 3000 + python benchmarks/bench_samap.py --max-cells 30000 --phases expand,projection + python benchmarks/plot_bench.py benchmarks/results/bench_.csv + +Memory measurement +------------------ +Uses ``tracemalloc`` peak. This tracks numpy/scipy sparse allocations +(the bulk of what the legacy paths materialize) but misses allocations +inside numba-nopython kernels. For the configs benchmarked here, the +dominant memory cost is sparse-matrix intermediates on the legacy paths +— visible to tracemalloc. 
+""" + +from __future__ import annotations + +import argparse +import csv +import logging +import sys +import time +import tracemalloc +from collections.abc import Callable, Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +import numpy as np +import pandas as pd +import scipy.sparse as spp +from sklearn.neighbors import kneighbors_graph + +# --------------------------------------------------------------------------- +# SAMap imports — the functions under test. +# --------------------------------------------------------------------------- +from samap.core.correlation import _compute_pair_corrs +from samap.core.expand import _smart_expand +from samap.core.projection import ( + _mapping_window, + _mapping_window_fast, + _projection_precompute, +) + +# Silence the INFO logging from projection.py — distracts from timing output. +logging.getLogger("samap").setLevel(logging.WARNING) + +# --------------------------------------------------------------------------- +# Measurement infrastructure +# --------------------------------------------------------------------------- + + +@dataclass(slots=True) +class Result: + """One (phase, config, scale) measurement.""" + + n_cells: int + phase: str + config: str + wall_time_s: float + peak_mem_mb: float + n_iters: int # how many iterations of the measured step + + +@contextmanager +def measure() -> Iterator[dict[str, float]]: + """Time + peak-memory context. + + Starts a fresh tracemalloc window. On exit, the yielded dict holds + ``wall_time_s`` and ``peak_mem_mb`` (tracemalloc peak in MiB). + """ + # Garbage-collect first so we measure the benchmark, not leftovers. 
+ import gc + + gc.collect() + tracemalloc.start() + t0 = time.perf_counter() + box: dict[str, float] = {} + try: + yield box + finally: + wall = time.perf_counter() - t0 + _, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + box["wall_time_s"] = wall + box["peak_mem_mb"] = peak / (1024 * 1024) + + +def run_warmup(fn: Callable[[], Any]) -> None: + """Run ``fn`` once and discard. Forces numba JIT so it's excluded from timings.""" + fn() + + +# --------------------------------------------------------------------------- +# Synthetic data — per phase +# --------------------------------------------------------------------------- + + +def synth_knn_graph( + n_cells: int, + k: int, + n_clusters: int, + rng: np.random.Generator, +) -> tuple[spp.csr_matrix, np.ndarray]: + """Weighted kNN graph over blob-structured points. + + Mimics a scanpy ``connectivities`` matrix: symmetric, weights in + (0, 1], diagonal absent, cluster-local neighbourhood structure. + Returns the graph and an integer cluster-label array. + """ + centres = ( + np.stack( + [ + np.cos(2 * np.pi * np.arange(n_clusters) / n_clusters), + np.sin(2 * np.pi * np.arange(n_clusters) / n_clusters), + ], + axis=1, + ) + * 10.0 + ) + labels = rng.integers(0, n_clusters, size=n_cells) + pts = centres[labels] + rng.normal(scale=1.0, size=(n_cells, 2)) + + A = kneighbors_graph(pts, n_neighbors=k, mode="distance", include_self=False) + A.data = np.exp(-A.data / A.data.mean()) + A = A.maximum(A.T).tocsr() + A.setdiag(0.0) + A.eliminate_zeros() + return A, labels + + +class _MockAdata: + """Minimal adata stand-in for projection benchmarks. + + Provides only the fields ``_projection_precompute`` reads: ``X``, + ``var_names``, ``var['weights']``, ``varm['PCs_SAMap']``, plus gene-name + column slicing via ``__getitem__``. 
+ """ + + __slots__ = ("X", "_name_to_ix", "var", "var_names", "varm") + + def __init__( + self, + X: spp.csr_matrix, + var_names: np.ndarray, + weights: np.ndarray, + PCs: np.ndarray, + ) -> None: + self.X = X + self.var_names = var_names + self.var = pd.DataFrame({"weights": weights}, index=var_names) + self.varm = {"PCs_SAMap": PCs} + self._name_to_ix = {n: i for i, n in enumerate(var_names)} + + def __getitem__(self, key: tuple) -> _MockAdata: + _, cols = key + ix = np.array([self._name_to_ix[c] for c in cols]) + return _MockAdata( + X=self.X[:, ix], + var_names=self.var_names[ix], + weights=self.var["weights"].values[ix], + PCs=self.varm["PCs_SAMap"][ix], + ) + + +class _MockSAM: + __slots__ = ("adata",) + + def __init__(self, adata: _MockAdata) -> None: + self.adata = adata + + +def synth_species_pair( + n_cells_per_species: int, + n_genes_per_species: int, + n_homology_edges_per_gene: int, + density: float, + npcs: int, + rng: np.random.Generator, +) -> dict[str, Any]: + """Build two mock species + a cross-species homology graph. + + Returns the pieces needed by projection and correlation benchmarks: + per-species SAM mocks, concatenated gene names, and the sparse + block-off-diagonal homology matrix. + """ + sids = ["aa", "bb"] + sams: dict[str, _MockSAM] = {} + gns_list: list[np.ndarray] = [] + + for sid in sids: + var_names = np.array([f"{sid}_gene{i:05d}" for i in range(n_genes_per_species)]) + X = spp.random( + n_cells_per_species, + n_genes_per_species, + density=density, + format="csr", + random_state=rng.integers(1 << 31), + dtype=np.float32, + ) + X.data *= 10 # roughly count-scale + weights = rng.uniform(0.1, 1.0, n_genes_per_species) + PCs = rng.standard_normal((n_genes_per_species, npcs), dtype=np.float64) + sams[sid] = _MockSAM(_MockAdata(X, var_names, weights, PCs)) + gns_list.append(var_names) + + gns = np.concatenate(gns_list) + G = gns.size + + # Block-off-diagonal homology: species-a genes → species-b genes only. 
+ # ~n_homology_edges_per_gene outgoing edges per gene, random targets, + # strictly-positive weights. Symmetrise by max. + g_a = n_genes_per_species + n_edges = g_a * n_homology_edges_per_gene + src = rng.integers(0, g_a, size=n_edges) + dst = rng.integers(g_a, G, size=n_edges) + w = rng.uniform(0.01, 1.0, size=n_edges) + gnnm = spp.csr_matrix((w, (src, dst)), shape=(G, G)) + gnnm = gnnm.maximum(gnnm.T) + + return { + "sams": sams, + "gns": gns, + "gns_list": gns_list, + "gnnm": gnnm, + "sids": sids, + } + + +def synth_correlation_inputs( + n_cells_per_species: int, + n_genes_per_species: int, + n_pairs: int, + knn_k: int, + density: float, + rng: np.random.Generator, +) -> dict[str, Any]: + """Build inputs for ``_compute_pair_corrs``. + + Two-species block-diagonal expression, row-normalised averaging + operator, and a set of cross-species gene pairs. Layout matches what + ``_refine_corr_parallel`` feeds into the kernel. + """ + n_a, n_b = n_cells_per_species, n_cells_per_species + g_a, g_b = n_genes_per_species, n_genes_per_species + N = n_a + n_b + G = g_a + g_b + + # Averaging operator: random sparse kNN, self-loops guaranteed, + # row-normalised. Keeping it simple (no cluster structure — doesn't + # affect timing, only numerical output). + rows = np.repeat(np.arange(N), knn_k) + cols = rng.integers(0, N, size=N * knn_k) + knn = spp.csr_matrix((np.ones(N * knn_k), (rows, cols)), shape=(N, N)) + knn.setdiag(1.0) + knn.sum_duplicates() + rs = np.asarray(knn.sum(1)).flatten() + nnms = knn.multiply(1.0 / rs[:, None]).tocsr() + + # Block-diagonal expression. Use float32 to match production. + Xa = spp.random( + n_a, + g_a, + density=density, + format="csr", + random_state=rng.integers(1 << 31), + dtype=np.float32, + ) + Xb = spp.random( + n_b, + g_b, + density=density, + format="csr", + random_state=rng.integers(1 << 31), + dtype=np.float32, + ) + Xs = spp.block_diag([Xa, Xb]).tocsc() + + # Cross-species gene pairs. 
+ p1 = rng.integers(0, g_a, size=n_pairs) + p2 = rng.integers(g_a, G, size=n_pairs) + p = np.column_stack((p1, p2)).astype(np.int64) + ps_int = np.column_stack((np.zeros(n_pairs, dtype=np.int64), np.ones(n_pairs, dtype=np.int64))) + + return { + "nnms": nnms, + "Xs": Xs, + "p": p, + "ps_int": ps_int, + "sp_starts": np.array([0, n_a], dtype=np.int64), + "sp_lens": np.array([n_a, n_b], dtype=np.int64), + "n": N, + } + + +# --------------------------------------------------------------------------- +# Phase benchmarks +# --------------------------------------------------------------------------- + + +def bench_expand( + n_cells: int, + *, + rng: np.random.Generator, + n_iters: int = 3, +) -> list[Result]: + """Legacy matrix-power vs BFS neighbourhood expansion.""" + print(f" [expand] building kNN graph: n={n_cells}", file=sys.stderr) + nnm, labels = synth_knn_graph(n_cells, k=20, n_clusters=max(8, n_cells // 200), rng=rng) + _, ix, counts = np.unique(labels, return_inverse=True, return_counts=True) + K = counts[ix].astype(np.int64) + + # JIT warmup — both paths have numba kernels / scipy compile paths. 
+ run_warmup(lambda: _smart_expand(nnm, K.copy(), NH=3, legacy=True)) + run_warmup(lambda: _smart_expand(nnm, K.copy(), NH=3, legacy=False)) + + results = [] + + # --- Legacy: matrix powers -------------------------------------------- + with measure() as m: + for _ in range(n_iters): + _smart_expand(nnm, K.copy(), NH=3, legacy=True) + results.append( + Result( + n_cells=n_cells, + phase="expand", + config="legacy", + wall_time_s=m["wall_time_s"], + peak_mem_mb=m["peak_mem_mb"], + n_iters=n_iters, + ) + ) + + # --- Optimized: BFS --------------------------------------------------- + with measure() as m: + for _ in range(n_iters): + _smart_expand(nnm, K.copy(), NH=3, legacy=False) + results.append( + Result( + n_cells=n_cells, + phase="expand", + config="optimized", + wall_time_s=m["wall_time_s"], + peak_mem_mb=m["peak_mem_mb"], + n_iters=n_iters, + ) + ) + + return results + + +def bench_projection( + n_cells: int, + *, + rng: np.random.Generator, + n_iters: int = 3, +) -> list[Result]: + """Per-iter precompute rebuild vs precompute-once + fast path. + + The legacy path calls ``_mapping_window`` which internally rebuilds the + projection precompute every call. The optimized path builds the + precompute once (outside the timed loop — iteration-invariant) and + calls ``_mapping_window_fast`` per iteration. + + This benchmark measures *per-iteration* cost — the one-time precompute + cost for the optimized path is not charged to any iteration since + SAMap runs 3+ iterations and amortizes it. + """ + print( + f" [projection] building species pair: n={n_cells}/species, genes=5000/species", + file=sys.stderr, + ) + synth = synth_species_pair( + n_cells_per_species=n_cells, + n_genes_per_species=5000, + n_homology_edges_per_gene=5, + density=0.08, + npcs=50, + rng=rng, + ) + sams, gns, gnnm = synth["sams"], synth["gns"], synth["gnnm"] + + # JIT / cache warmup. 
+ run_warmup(lambda: _mapping_window(sams, gnnm, gns, K=20)) + + results = [] + + # --- Legacy: _mapping_window rebuilds precompute each call ------------ + with measure() as m: + for _ in range(n_iters): + _mapping_window(sams, gnnm, gns, K=20) + results.append( + Result( + n_cells=n_cells, + phase="projection", + config="legacy", + wall_time_s=m["wall_time_s"], + peak_mem_mb=m["peak_mem_mb"], + n_iters=n_iters, + ) + ) + + # --- Optimized: precompute once, fast path per iter ------------------- + # Precompute outside the measurement window — it's iteration-invariant. + pre = _projection_precompute(sams, gns) + with measure() as m: + for _ in range(n_iters): + _mapping_window_fast(gnnm, pre, K=20) + results.append( + Result( + n_cells=n_cells, + phase="projection", + config="optimized", + wall_time_s=m["wall_time_s"], + peak_mem_mb=m["peak_mem_mb"], + n_iters=n_iters, + ) + ) + + return results + + +def bench_correlation( + n_cells: int, + *, + rng: np.random.Generator, + n_iters: int = 1, +) -> list[Result]: + """Materialised Xavg vs streaming batched correlation. + + Legacy: ``batch_size=None`` materialises the full N × G_active smoothed + matrix. Optimized: ``batch_size`` streams pair-batches and computes only + the columns needed per batch. + """ + print( + f" [correlation] building inputs: n={n_cells}/species, 2000 genes/species, 8000 pairs", + file=sys.stderr, + ) + inp = synth_correlation_inputs( + n_cells_per_species=n_cells, + n_genes_per_species=2000, + n_pairs=8000, + knn_k=15, + density=0.08, + rng=rng, + ) + + args = ( + inp["nnms"], + inp["Xs"], + inp["p"], + inp["ps_int"], + inp["sp_starts"], + inp["sp_lens"], + inp["n"], + ) + + # Warmup (both paths share the numba _corr_kernel). 
+ run_warmup(lambda: _compute_pair_corrs(*args, corr_mode="pearson", batch_size=None)) + + results = [] + + # --- Legacy: materialised Xavg ---------------------------------------- + with measure() as m: + for _ in range(n_iters): + _compute_pair_corrs(*args, corr_mode="pearson", batch_size=None) + results.append( + Result( + n_cells=n_cells, + phase="correlation", + config="legacy", + wall_time_s=m["wall_time_s"], + peak_mem_mb=m["peak_mem_mb"], + n_iters=n_iters, + ) + ) + + # --- Optimized: streaming batches ------------------------------------- + with measure() as m: + for _ in range(n_iters): + _compute_pair_corrs(*args, corr_mode="pearson", batch_size=512) + results.append( + Result( + n_cells=n_cells, + phase="correlation", + config="optimized", + wall_time_s=m["wall_time_s"], + peak_mem_mb=m["peak_mem_mb"], + n_iters=n_iters, + ) + ) + + return results + + +BENCHMARKS: dict[str, Callable[..., list[Result]]] = { + "expand": bench_expand, + "projection": bench_projection, + "correlation": bench_correlation, +} + + +# --------------------------------------------------------------------------- +# Driver +# --------------------------------------------------------------------------- + + +def write_csv(path: Path, results: list[Result]) -> None: + """Dump results to CSV — one row per (scale, phase, config).""" + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", newline="") as f: + w = csv.writer(f) + w.writerow(["n_cells", "phase", "config", "wall_time_s", "peak_mem_mb", "n_iters"]) + for r in results: + w.writerow( + [ + r.n_cells, + r.phase, + r.config, + f"{r.wall_time_s:.6f}", + f"{r.peak_mem_mb:.3f}", + r.n_iters, + ] + ) + + +def print_summary(results: list[Result]) -> None: + """Pretty-print speedup / memory-reduction table to stderr.""" + # Group by (phase, n_cells) → compare legacy vs optimized. 
+ grouped: dict[tuple[str, int], dict[str, Result]] = {} + for r in results: + key = (r.phase, r.n_cells) + grouped.setdefault(key, {})[r.config] = r + + print("\n=== Summary ===", file=sys.stderr) + print( + f"{'phase':<12} {'n_cells':>8} {'legacy_s':>10} {'opt_s':>10} " + f"{'speedup':>8} {'legacy_mb':>10} {'opt_mb':>10} {'mem_ratio':>9}", + file=sys.stderr, + ) + for (phase, n_cells), by_cfg in sorted(grouped.items()): + if "legacy" not in by_cfg or "optimized" not in by_cfg: + continue + leg, opt = by_cfg["legacy"], by_cfg["optimized"] + speedup = leg.wall_time_s / opt.wall_time_s if opt.wall_time_s > 0 else float("inf") + mem_ratio = leg.peak_mem_mb / opt.peak_mem_mb if opt.peak_mem_mb > 0 else float("inf") + print( + f"{phase:<12} {n_cells:>8d} {leg.wall_time_s:>10.3f} " + f"{opt.wall_time_s:>10.3f} {speedup:>7.1f}x " + f"{leg.peak_mem_mb:>10.1f} {opt.peak_mem_mb:>10.1f} {mem_ratio:>8.1f}x", + file=sys.stderr, + ) + + +def main(argv: list[str] | None = None) -> int: + p = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + p.add_argument( + "--max-cells", + type=int, + default=10000, + help="stop at the first scale whose n_cells exceeds this (default: 10000)", + ) + p.add_argument( + "--phases", + default="expand,projection,correlation", + help="comma-separated list of phases to run (expand,projection,correlation; default: all)", + ) + p.add_argument( + "--out", + type=Path, + default=None, + help="output CSV path (default: benchmarks/results/bench_.csv)", + ) + p.add_argument( + "--seed", + type=int, + default=42, + help="RNG seed for synthetic data (default: 42)", + ) + args = p.parse_args(argv) + + phases = [p.strip() for p in args.phases.split(",") if p.strip()] + for ph in phases: + if ph not in BENCHMARKS: + print( + f"error: unknown phase {ph!r}; choose from {sorted(BENCHMARKS)}", + file=sys.stderr, + ) + return 2 + + scales = [s for s in [1000, 3000, 10000, 30000] if s <= args.max_cells] + if not 
scales: + scales = [args.max_cells] + + ts = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") + out_path = args.out or (Path(__file__).parent / "results" / f"bench_{ts}.csv") + + rng = np.random.default_rng(args.seed) + all_results: list[Result] = [] + + for n in scales: + for ph in phases: + print(f"[scale n_cells={n}] {ph}", file=sys.stderr) + try: + rs = BENCHMARKS[ph](n, rng=rng) + except MemoryError: + print( + f" OOM at n_cells={n}, phase={ph} — stopping this phase", + file=sys.stderr, + ) + continue + all_results.extend(rs) + + write_csv(out_path, all_results) + print(f"\nResults → {out_path}", file=sys.stderr) + print_summary(all_results) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/benchmarks/plot_bench.py b/benchmarks/plot_bench.py new file mode 100644 index 0000000..3efa844 --- /dev/null +++ b/benchmarks/plot_bench.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python +"""Plot SAMap benchmark results — log-log scaling curves. + +Reads a CSV produced by ``bench_samap.py`` and renders one figure with +two rows (time, memory) and one column per benchmark phase. Each subplot +overlays the legacy and optimized curves on log-log axes so the scaling +difference is visible. 
+ +Usage +----- +:: + + python benchmarks/plot_bench.py benchmarks/results/bench_.csv + python benchmarks/plot_bench.py latest # most recent CSV in results/ + python benchmarks/plot_bench.py bench.csv -o scaling.png +""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd + +_CONFIG_STYLE = { + "legacy": {"marker": "o", "linestyle": "-", "color": "#d62728"}, + "optimized": {"marker": "s", "linestyle": "-", "color": "#2ca02c"}, +} + + +def _resolve_csv(arg: str) -> Path: + """Resolve 'latest' to the most recent bench CSV in results/.""" + if arg == "latest": + results_dir = Path(__file__).parent / "results" + candidates = sorted(results_dir.glob("bench_*.csv")) + if not candidates: + raise FileNotFoundError(f"no bench_*.csv files in {results_dir}") + return candidates[-1] + return Path(arg) + + +def plot(csv_path: Path, out_path: Path | None = None) -> Path: + """Render scaling curves from a benchmark CSV. + + Returns the written figure path (derived from ``csv_path`` if ``out_path`` + is None). + """ + df = pd.read_csv(csv_path) + + # Normalise timings to per-iteration so scales are comparable across phases + # that use different n_iters. 
+ df["time_per_iter_s"] = df["wall_time_s"] / df["n_iters"] + + phases = sorted(df["phase"].unique()) + n_ph = len(phases) + + fig, axes = plt.subplots( + 2, + n_ph, + figsize=(4.5 * n_ph, 8), + squeeze=False, + sharex="col", + ) + + for j, phase in enumerate(phases): + sub = df[df["phase"] == phase].sort_values("n_cells") + + # --- Row 0: wall time --- + ax_t = axes[0, j] + for cfg, style in _CONFIG_STYLE.items(): + s = sub[sub["config"] == cfg] + if s.empty: + continue + ax_t.loglog( + s["n_cells"], + s["time_per_iter_s"], + label=cfg, + **style, + ) + ax_t.set_title(phase, fontsize=12, fontweight="bold") + ax_t.set_ylabel("wall time / iter (s)") + ax_t.grid(True, which="both", ls="--", alpha=0.3) + if j == 0: + ax_t.legend(loc="upper left", fontsize=10) + + # --- Row 1: peak memory --- + ax_m = axes[1, j] + for cfg, style in _CONFIG_STYLE.items(): + s = sub[sub["config"] == cfg] + if s.empty: + continue + ax_m.loglog( + s["n_cells"], + s["peak_mem_mb"], + label=cfg, + **style, + ) + ax_m.set_xlabel("n_cells") + ax_m.set_ylabel("peak memory (MiB)") + ax_m.grid(True, which="both", ls="--", alpha=0.3) + + fig.suptitle( + f"SAMap Phase-3 optimizations — {csv_path.name}", + fontsize=13, + ) + fig.tight_layout(rect=(0, 0, 1, 0.96)) + + if out_path is None: + out_path = csv_path.with_suffix(".png") + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + return out_path + + +def main(argv: list[str] | None = None) -> int: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument( + "csv", + help="path to bench CSV, or 'latest' for most recent in results/", + ) + p.add_argument( + "-o", + "--out", + type=Path, + default=None, + help="output figure path (default: same as CSV with .png suffix)", + ) + args = p.parse_args(argv) + + csv_path = _resolve_csv(args.csv) + out = plot(csv_path, args.out) + print(f"Figure → {out}", file=sys.stderr) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git 
a/benchmarks/results/.gitignore b/benchmarks/results/.gitignore new file mode 100644 index 0000000..8f4da06 --- /dev/null +++ b/benchmarks/results/.gitignore @@ -0,0 +1,3 @@ +# Timestamped bench outputs — keep the directory, not the contents. +*.csv +*.png diff --git a/docs/performance.md b/docs/performance.md new file mode 100644 index 0000000..f42911c --- /dev/null +++ b/docs/performance.md @@ -0,0 +1,237 @@ +# SAMap Performance Guide + +SAMap 3.0 is a substantial rewrite of the core algorithm's memory model. +The algorithm is mathematically unchanged, but every step that previously +materialized an **N × G** or **N × N** intermediate has been rewritten to +stream or precompose. This document explains the new memory model, how to +select and tune the compute backend, and what to expect when scaling. + +--- + +## Summary: what changed in 3.0 + +| Phase | Legacy (≤2.x) | 3.0 | Memory complexity | +|---|---|---|---| +| **Feature translation** | Materialize `Xtr = X @ G` (cells × genes, ~30% dense) per iteration | Precompose projection, one SpMM per species pair | O(N × G) → O(N × npcs) | +| **Neighbourhood expansion** | Matrix powers + LIL zeroing | BFS with per-row budget, numba kernel | O(N²) peak → O(N · k · NH) | +| **Mutual-NN stitching** | Full N × N products, chunked | Streaming per-species-pair with direct CSR output | O(N²) dense peak → O(N · k) sparse | +| **Correlation refinement** | Materialize full smoothed expression `Xavg = nnms @ Xs` | Stream pair batches; compute only columns referenced per batch | O(N × G_active) → O(N × 2·batch_size) | +| **Sparse PCA** | ARPACK | ARPACK or randomized SVD with implicit centering | Same; randomized is faster on GPU | + +In all cases the **numeric output is equivalent** (correlation streaming is +bit-identical to materialized; projection precomposition agrees to ~1e-15) +with one exception: the BFS neighbourhood expansion avoids a self-loop +artefact of the matrix-power path and selects slightly different marginal 
+neighbours (~1% edge difference on the golden-suite data). We consider BFS +strictly more correct; pass `_smart_expand(..., legacy=True)` for exact 2.x +reproduction. + +--- + +## The memory model + +SAMap iterates a two-phase loop: **project** cells into a joint latent +space via the current homology graph, then **refine** the homology graph +from the resulting cross-species neighbourhood. Each phase had an N × G +or N × N dense chokepoint. + +### Projection: precomposed feature translation + +The legacy code computed, per iteration and per species pair, + +``` +Xtr = X_i @ G_ij # N_i × G_j, ~30% dense — the bottleneck +Xscaled = Xtr / σ +wpca = (Xscaled * W_j) @ PCs_j +``` + +and assembled a block matrix of these before projecting through the PC +loadings. For realistic data this intermediate dominates both memory and +wall time. + +The 3.0 path observes that the entire chain is a linear operator and can +be precomposed: + +``` +P_ij = G_ij · diag(W_j / σ) · PCs_j # G_i × npcs — a few MB, regardless of N +wpca = X_i @ P_ij # ONE SpMM +``` + +The per-column standard deviation `σ` (which depends on `Xtr`) is recovered +without materializing `Xtr` via a quadratic form in the columns of `G_ij`, +using iteration-invariant precomputes of `X_iᵀX_i` and `X_i.mean(0)`. The +own-species contribution `X_i @ PCs_i` does not depend on the homology +graph at all and is computed once at the start of the run. + +**When it kicks in:** Always on. The iteration-invariant state is built in +`_Samap_Iter.__init__`; each iteration runs `_mapping_window_fast` which +consumes the cached state. + +### Coarsening: streaming mutual-NN + +The legacy mutual-NN step built intermediate N × N products (kNN graph × +expanded neighbourhood) chunked by rows but still O(N²)-dense at peak. The +3.0 path streams per-species-pair blocks and emits the final sparse kNN +directly via a COO builder — no dense intermediate. + +**When it kicks in:** Always on. 
Tunable via the internal `chunksize` +parameter on `_mapper` (default 20 000 rows), but this is not currently +exposed in the public `SAMAP.run()` API. + +### Correlation refinement: batched smoothed expression + +The legacy path materialized `Xavg = nnms @ Xs` — an N × G_active dense +matrix — so the per-pair correlation kernel could pull columns by index. +At million-cell scale this is multiple GB. + +The 3.0 streaming path (`batch_size=512`) processes 512 gene pairs at a +time: compute only the ≤1024 columns of `Xavg` actually referenced, +correlate, discard. Peak memory drops to O(N × 1024) regardless of how +many genes are active. Columns that appear in multiple batches are +recomputed — this is a cheap single-column SpMV and empirically <5% +overhead at scale. + +**Trade-off:** At small scale (<10k cells), where the full `Xavg` fits +comfortably in memory, the streaming path is ~3-5× *slower* than the +materialized path (benchmark: 3.6× at 3k cells, fixed per-batch dispatch +overhead). + +**Auto-selection (the default):** `batch_size="auto"` estimates the +materialized `Xavg` size from cell count, gene count, expression density, +and kNN degree (output density ≈ `1 - (1-p)^k` where `p` is input +density and `k` is average neighbour degree). If the estimate is under +`correlation_mem_threshold` (default 2 GB), materialise — fast path, +one big SpMM. Otherwise stream at `batch_size=512`. The decision is +logged at INFO level. + +**Tuning:** +- `correlation_mem_threshold` is exposed on `SAMAP.refine_homology_graph`. + Raise it on large-memory nodes (say 8 GB on a 64 GB box) to keep the + faster materialised path for larger datasets. Lower it on + memory-constrained environments. +- `batch_size=None` forces the materialised path unconditionally. +- `batch_size=` forces streaming at that size. + +### Neighbourhood expansion: BFS + +The legacy `_smart_expand` used repeated matrix powers with LIL zeroing +to collect an NH-hop neighbourhood per cell. 
This wastes one budget slot +per cell on a self-loop artefact (a cell's 2-hop neighbourhood always +includes itself) and has O(N²) peak memory for the power products. + +The 3.0 default is a numba BFS kernel that walks neighbours directly, +tracks a per-row visited set, and respects the budget exactly. + +**When it kicks in:** Default-on (`legacy=False`). Pass `legacy=True` to +`_smart_expand` for bit-exact 2.x reproduction. + +--- + +## Backend selection (CPU / GPU) + +```python +from samap import SAMAP + +sm = SAMAP(sams={...}, backend="auto") # pick CUDA if available, else CPU +sm = SAMAP(sams={...}, backend="cpu") # force numpy/scipy +sm = SAMAP(sams={...}, backend="cuda") # force cupy/cupyx — raises if unavailable +``` + +`"auto"` resolves to `"cuda"` if `cupy` is importable and a GPU is +detected, otherwise `"cpu"`. The resolved device is logged at construction. + +### GPU installation + +```bash +pip install "sc-samap[gpu]" +``` + +This pulls: + +- `cupy-cuda12x` — numpy/scipy dispatch on CUDA 12.x. For CUDA 11.x, + install `cupy-cuda11x` directly. +- `faiss-gpu` — GPU approximate kNN. **Note:** wheels are not on PyPI; + install via conda (`pytorch` or `conda-forge` channel). The pip extra + is advisory. +- `rapids-singlecell` — GPU Leiden/UMAP. Best installed from the + `rapidsai` conda channel. + +The kNN dispatch (`approximate_knn` in `samap.core.knn`) picks FAISS on +GPU and hnswlib on CPU automatically. + +--- + +## Tuning parameters + +Most of these live on internal functions; they are not (yet) plumbed +through the public `SAMAP.run()` API. + +| Parameter | Location | Default | When to change | +|---|---|---|---| +| `batch_size` | `_refine_corr`, `_refine_corr_parallel` | `512` | Lower (256, 128) on severe memory pressure. `None` for speed on small datasets. | +| `chunksize` | `_mapper` (coarsening) | `20000` | Lower if the streaming mutual-NN step OOMs on the row-chunk. Rarely needed. 
| +| `legacy` | `_smart_expand` | `False` | `True` only for bit-exact 2.x reproduction. | +| `svd_solver` | `samap.sam.pca._pca_with_sparse` | `"arpack"` | `"randomized"` is faster on GPU and at high `npcs`. Slightly different numerics (randomized is an approximation). Not plumbed to public API. | +| `backend` | `SAMAP.__init__` | `"auto"` | Force `"cpu"` for reproducibility; `"cuda"` to fail loudly if GPU is missing. | + +--- + +## Expected scaling + +These are **estimates** from synthetic benchmarks and informal testing. +Actual limits depend heavily on species count, gene-set overlap, data +density, and kNN parameters. + +| Setup | Approx. cell-count ceiling | Notes | +|---|---|---| +| 64 GB CPU, ≤2.x code | ~500k | `Xtr` and `Xavg` materialization are the walls | +| 64 GB CPU, 3.0 | ~2-3M | Limited by `X_iᵀX_i` Gram matrix and streaming overhead | +| 256 GB CPU + A100, 3.0 | ~5-10M | Randomized SVD helps; kNN moves to FAISS on GPU | + +Dominant memory costs in 3.0, in rough order: + +1. The input `X` matrices themselves (CSR, unavoidable) +2. `X_iᵀX_i` Gram matrices (G × G sparse, per species — precomputed once) +3. Per-iteration `P_ij` (G_i × npcs dense, but tiny) +4. Streaming correlation working set (N × 2·batch_size dense) + +If you OOM at step 2, your gene set is too large — consider pre-filtering +to highly variable genes before running SAMap. 
+
+---
+
+## Benchmark results
+
+From `benchmarks/bench_samap.py`, synthetic data, 2-species, timed over
+3 SAMap iterations (1 for correlation), measured on CPU (tracemalloc peak):
+
+| Phase | n_cells | Legacy wall (s) | Optimized wall (s) | Speedup | Legacy mem (MB) | Optimized mem (MB) | Mem ratio |
+|---|---:|---:|---:|---:|---:|---:|---:|
+| expand | 1 000 | 1.18 | 0.34 | **3.5×** | 11.2 | 8.4 | 1.3× |
+| expand | 3 000 | 5.61 | 1.16 | **4.8×** | 62.8 | 40.3 | 1.6× |
+| projection | 1 000 | 6.61 | 3.32 | **2.0×** | 1157 | 571 | 2.0× |
+| projection | 3 000 | 11.79 | 6.33 | **1.9×** | 1180 | 573 | 2.1× |
+| correlation | 1 000 | 0.25 | 1.18 | 0.21× | 103 | 24 | **4.3×** |
+| correlation | 3 000 | 0.37 | 1.33 | 0.27× | 307 | 71 | **4.3×** |
+
+**Interpretation:**
+
+- **expand**: Pure win. BFS is faster and smaller at every scale tested.
+- **projection**: Pure win. ~2× on both axes; gains grow with N as the
+  `Xtr` materialization would grow linearly in N but the precomposed
+  `P_ij` does not.
+- **correlation**: Memory win at the cost of speed — by design. The
+  speedup axis will flip positive at the scale where materialized `Xavg`
+  spills to swap or OOMs outright. On the toy benchmark sizes, streaming
+  is slower.
+
+Re-run the benchmark locally:
+
+```bash
+python benchmarks/bench_samap.py --max-cells 10000
+# Plot the CSV written by the benchmark run (timestamped filename):
+python benchmarks/plot_bench.py benchmarks/results/bench_<timestamp>.csv
+```
+
+The `tracemalloc` peak catches numpy/scipy sparse allocations (the bulk
+of legacy materialization) but misses allocations inside numba-nopython
+kernels — so actual peak RSS may differ.
diff --git a/pyproject.toml b/pyproject.toml index 4f700e7..40e016b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "sc-samap" -version = "2.0.2" +version = "3.0.0" description = "The SAMap algorithm" readme = "README.md" license = "MIT" @@ -33,7 +33,6 @@ classifiers = [ "Topic :: Scientific/Engineering :: Bio-Informatics", ] dependencies = [ - "sc-sam>=2.0.0", "numpy>=1.26.0", "scipy>=1.11.0", "scanpy>=1.10.0", @@ -59,6 +58,19 @@ viz = [ "seaborn>=0.13", "plotly>=5.18", ] +gpu = [ + # cupy for GPU array ops + cupyx.scipy.sparse for SpMM. + # Pick the cuda12x variant as the sensible default; users on cuda11x + # should install cupy-cuda11x directly and skip this extra. + "cupy-cuda12x>=12.0", + # FAISS-GPU for brute-force kNN on the CUDA backend. NOTE: faiss-gpu + # wheels are not on PyPI — install via conda (faiss-gpu from the + # pytorch/conda-forge channels). This pin is advisory only. + "faiss-gpu>=1.7", + # rapids-singlecell for GPU UMAP + Leiden. Same caveat: best installed + # from the rapidsai conda channel for CUDA version matching. 
+ "rapids-singlecell>=0.10", +] all = ["samap[dev,viz]"] [project.urls] @@ -104,6 +116,9 @@ ignore = [ "B008", # do not perform function calls in argument defaults "B905", # zip without explicit strict "NPY002", # legacy numpy random (not critical) + "RUF002", # ambiguous unicode in docstrings (math notation: ×, σ, ⊙ are intentional) + "RUF003", # ambiguous unicode in comments (same) + "SIM108", # ternary operator (multi-branch math is clearer as if/else) ] [tool.ruff.lint.isort] diff --git a/setup.py b/setup.py deleted file mode 100755 index 9ae71d7..0000000 --- a/setup.py +++ /dev/null @@ -1,30 +0,0 @@ -from setuptools import setup, find_packages - -setup( - name="samap", - version="1.0.15", - description="The SAMap algorithm", - long_description="The Self-Assembling Manifold Mapping algorithm for mapping single-cell datasets across species.", - long_description_content_type="text/markdown", - author="Alexander J. Tarashansky", - url="https://github.com/atarashansky/SAMap", - author_email="tarashan@stanford.edu", - keywords="scrnaseq analysis manifold reconstruction cross-species mapping", - python_requires=">=3.7", - install_requires=[ - "sam-algorithm==1.0.2", - "scipy<1.13.0", - "numpy==1.23.5", - "scanpy==1.9.3", - "hnswlib==0.7.0", - "dill", - "numba==0.56.3", - "h5py==3.8.0", - "leidenalg", - "fast-histogram", - "holoviews-samap" - ], - packages=find_packages(), - include_package_data=True, - zip_safe=False, -) diff --git a/src/samap/__init__.py b/src/samap/__init__.py index cad2e7c..10f0083 100644 --- a/src/samap/__init__.py +++ b/src/samap/__init__.py @@ -18,7 +18,13 @@ from __future__ import annotations -__version__ = "2.0.0" +from importlib.metadata import PackageNotFoundError, version + +try: + __version__ = version("sc-samap") +except PackageNotFoundError: + # Package not installed (e.g. 
running from source checkout)
+    __version__ = "3.0.0"  # keep in sync with [project] version in pyproject.toml
 
 # Core imports
 # Analysis imports
diff --git a/src/samap/_rsc_compat.py b/src/samap/_rsc_compat.py
new file mode 100644
index 0000000..e9661ce
--- /dev/null
+++ b/src/samap/_rsc_compat.py
@@ -0,0 +1,107 @@
+"""Optional rapids-singlecell dispatch for UMAP and Leiden.
+
+rapids-singlecell (rsc) provides GPU-accelerated implementations of the
+scanpy tools suite. When both a CUDA backend is active *and* rsc is
+installed, we dispatch to it — otherwise fall back to CPU scanpy.
+
+This module imports cleanly on machines without rsc: ``HAS_RSC`` is False
+and all wrappers take the CPU path. No GPU dependency is imposed.
+
+Known upstream issues handled here:
+
+* rsc's Leiden occasionally returns a degenerate clustering (one cluster
+  per cell, or a single cluster) at certain resolution values — a known
+  cugraph edge case. When detected we fall back to CPU scanpy, which is
+  slower but always well-behaved.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from samap._logging import get_logger
+
+if TYPE_CHECKING:
+    from anndata import AnnData
+
+    from samap.core._backend import Backend
+
+logger = get_logger("samap.rsc")
+
+# --- Optional import --------------------------------------------------------
+
+try:
+    import rapids_singlecell as rsc
+
+    HAS_RSC: bool = True
+except ImportError:
+    rsc = None  # type: ignore[assignment]
+    HAS_RSC = False
+
+
+# --- Dispatch wrappers ------------------------------------------------------
+
+
+def umap(adata: AnnData, bk: Backend, **kwargs: Any) -> None:
+    """Compute UMAP embedding — rsc on GPU, scanpy on CPU.
+
+    Both implementations write to ``adata.obsm['X_umap']`` in place.
+
+    Parameters
+    ----------
+    adata
+        Annotated data with neighbors already computed.
+    bk
+        Active backend. rsc is used only when ``bk.gpu and HAS_RSC``.
+    **kwargs
+        Forwarded to ``rsc.tl.umap`` or ``scanpy.tl.umap`` (same signature).
+ """ + if bk.gpu and HAS_RSC: + rsc.tl.umap(adata, **kwargs) + else: + import scanpy as sc + + sc.tl.umap(adata, **kwargs) + + +def leiden(adata: AnnData, bk: Backend, key_added: str = "leiden", **kwargs: Any) -> None: + """Compute Leiden clustering — rsc on GPU with fallback, scanpy on CPU. + + rapids-singlecell's cugraph-backed Leiden has a known failure mode + where it collapses to one cluster or explodes to one-cluster-per-cell + at certain resolution values. When we detect that, we warn and re-run + on CPU. A successful clustering is defined heuristically as producing + between 2 and n_cells/2 clusters. + + Parameters + ---------- + adata + Annotated data with neighbors already computed. + bk + Active backend. rsc is attempted only when ``bk.gpu and HAS_RSC``. + key_added + Key under which the cluster assignments are stored in ``adata.obs``. + Defaults to ``"leiden"`` (matching scanpy's default). + **kwargs + Forwarded to ``rsc.tl.leiden`` or ``scanpy.tl.leiden``. + """ + if bk.gpu and HAS_RSC: + rsc.tl.leiden(adata, key_added=key_added, **kwargs) + n_clusters = adata.obs[key_added].nunique() + n_cells = adata.shape[0] + # Degenerate: everything in one bucket, or everything in its own bucket. + # A reasonable clustering of single-cell data sits well inside this band. 
+ if n_clusters < 2 or n_clusters > n_cells // 2: + logger.warning( + "rsc leiden returned %d clusters for %d cells (degenerate); " + "falling back to CPU scanpy.", + n_clusters, + n_cells, + ) + import scanpy as sc + + sc.tl.leiden(adata, key_added=key_added, **kwargs) + else: + import scanpy as sc + + sc.tl.leiden(adata, key_added=key_added, **kwargs) diff --git a/src/samap/analysis/enrichment.py b/src/samap/analysis/enrichment.py index aaad8d3..40394d6 100644 --- a/src/samap/analysis/enrichment.py +++ b/src/samap/analysis/enrichment.py @@ -215,7 +215,7 @@ def __init__( logger.info("Finding enriched gene pairs...") gpf = GenePairFinder(sm, keys=keys) - gene_pairs = gpf.find_all(thr=align_thr, n_top=n_top) + gene_pairs = gpf.find_all(align_thr=align_thr, n_top=n_top) self.DICT: dict[str, NDArray[Any]] = {} for c in gene_pairs.columns: diff --git a/src/samap/analysis/gene_pairs.py b/src/samap/analysis/gene_pairs.py index d90dc6b..23d94ce 100644 --- a/src/samap/analysis/gene_pairs.py +++ b/src/samap/analysis/gene_pairs.py @@ -11,20 +11,16 @@ import sklearn.utils.sparsefuncs as sf from samap._logging import logger +from samap.utils import q as _q from samap.utils import substr, to_vn if TYPE_CHECKING: from typing import Any from numpy.typing import NDArray - from samalg import SAM from samap.core.mapping import SAMAP - - -def _q(x: Any) -> NDArray[Any]: - """Convert input to numpy array.""" - return np.array(list(x)) + from samap.sam import SAM class GenePairFinder: diff --git a/src/samap/analysis/plotting.py b/src/samap/analysis/plotting.py index b8475cc..770b581 100644 --- a/src/samap/analysis/plotting.py +++ b/src/samap/analysis/plotting.py @@ -7,16 +7,11 @@ import numpy as np import pandas as pd +from samap.utils import q as _q + if TYPE_CHECKING: from typing import Any - from numpy.typing import NDArray - - -def _q(x: Any) -> NDArray[Any]: - """Convert input to numpy array.""" - return np.array(list(x)) - def sankey_plot( M: pd.DataFrame, diff --git 
a/src/samap/analysis/scores.py b/src/samap/analysis/scores.py index 2bdc7f6..da5cf03 100644 --- a/src/samap/analysis/scores.py +++ b/src/samap/analysis/scores.py @@ -9,9 +9,10 @@ import numpy as np import pandas as pd import scipy as sp -from samalg import SAM -from samap.utils import df_to_dict, substr, to_vn, to_vo +from samap.sam import SAM +from samap.utils import coo_to_csr_overwrite, df_to_dict, substr, to_vn, to_vo +from samap.utils import q as _q if TYPE_CHECKING: from typing import Any @@ -21,11 +22,6 @@ from samap.core.mapping import SAMAP -def _q(x: Any) -> NDArray[Any]: - """Convert input to numpy array.""" - return np.array(list(x)) - - def _compute_csim( samap: SAM, key: str, @@ -370,9 +366,9 @@ def convert_eggnog_to_homologs( X = np.array(X) Y = np.array(Y) - B = sp.sparse.lil_matrix((og.size, D.size)) - B[Y, X] = 1 - B = B.tocsr() + # Binary membership matrix; only the nonzero pattern matters downstream + # (result is binarised after dot), so COO sum-duplicates is safe here. 
+ B = sp.sparse.coo_matrix((np.ones(Y.size), (Y, X)), shape=(og.size, D.size)).tocsr() B = B.dot(B.T) B.data[:] = 1 pairs = gn[np.vstack(B.nonzero()).T] @@ -415,10 +411,14 @@ def CellTypeTriangles( x, y = substr(all_pairs, ";") ctu = np.unique(np.concatenate((x, y))) Z = pd.DataFrame(data=np.arange(ctu.size)[None, :], columns=ctu) - nnm = sp.sparse.lil_matrix((ctu.size,) * 2) - nnm[Z[x].values.flatten(), Z[y].values.flatten()] = alignment - nnm[Z[y].values.flatten(), Z[x].values.flatten()] = alignment - nnm = nnm.tocsr() + zx = Z[x].values.flatten() + zy = Z[y].values.flatten() + nnm = coo_to_csr_overwrite( + np.concatenate([zx, zy]), + np.concatenate([zy, zx]), + np.concatenate([alignment, alignment]), + (ctu.size, ctu.size), + ) G = nx.Graph() gps = ctu[np.vstack(nnm.nonzero()).T] @@ -624,9 +624,7 @@ def GeneTriangles( all_genes = np.unique(pairs.flatten()) Z = pd.DataFrame(data=np.arange(all_genes.size)[None, :], columns=all_genes) x, y = Z[pairs[:, 0]].values.flatten(), Z[pairs[:, 1]].values.flatten() - GNNM = sp.sparse.lil_matrix((all_genes.size,) * 2) - GNNM[x, y] = data - GNNM = GNNM.tocsr() + GNNM = coo_to_csr_overwrite(x, y, data, (all_genes.size, all_genes.size)) GNNM.data[GNNM.data < corr_thr] = 0 GNNM.eliminate_zeros() diff --git a/src/samap/core/_backend.py b/src/samap/core/_backend.py new file mode 100644 index 0000000..299fcd3 --- /dev/null +++ b/src/samap/core/_backend.py @@ -0,0 +1,423 @@ +"""Backend dispatch layer for CPU (numpy/scipy) ↔ GPU (cupy/cupyx). + +This module provides a thin abstraction so SAMap's hot path can run on either +CPU or GPU with a single code path. The ``Backend`` class exposes the active +array and sparse namespaces (``xp``, ``sp``) and provides compatibility shims +for operations where the scipy and cupy APIs diverge. + +cupy is an **optional** dependency. If it is not installed, importing this +module still succeeds; ``Backend("cpu")`` and ``Backend("auto")`` work, and +``Backend("cuda")`` raises a clear error. 
+ +Known scipy ↔ cupy divergences handled here: + +* No ``.nonzero()`` on cupy sparse matrices — use ``nonzero()`` shim. +* No LIL format on cupy — use :class:`COOBuilder` instead of ``lil_matrix``. +* ``svds`` on cupy accepts no ``solver=`` / ``v0=`` / ``random_state=`` kwargs + (cupy uses a thick-restart Lanczos on the normal equations; scipy uses + ARPACK). The shim filters unsupported kwargs on GPU. +* cuSPARSE SpMM internally forces a Fortran-order copy of the dense RHS; use + :meth:`Backend.asfortran_if_gpu` to pre-convert and avoid an implicit copy + at matmul time. +* ``sum_duplicates()`` on cupy sorts indices in a different order than scipy. + This usually does not affect numerical results but can break exact-bytes + golden-output comparisons — compare with a tolerance instead. +""" + +from __future__ import annotations + +import warnings +from types import ModuleType +from typing import TYPE_CHECKING, Any, Literal + +import numpy as np +import scipy.sparse as scipy_sparse +import scipy.sparse.linalg as scipy_spla + +if TYPE_CHECKING: + from numpy.typing import ArrayLike + +# --- Optional cupy import --------------------------------------------------- +# cupy may not be installed (e.g. on macOS dev machines with no CUDA support). +# We try to import it at module load and set flags accordingly. All cupy-using +# code paths are guarded on these flags so the module imports cleanly either +# way. 
+ +_cupy: ModuleType | None +_cupyx_sparse: ModuleType | None +_cupyx_spla: ModuleType | None + +try: + import cupy as _cupy + import cupyx.scipy.sparse as _cupyx_sparse + import cupyx.scipy.sparse.linalg as _cupyx_spla + + HAS_CUPY: bool = True +except ImportError: + _cupy = None + _cupyx_sparse = None + _cupyx_spla = None + HAS_CUPY = False + + +def _cuda_available() -> bool: + """Return True iff cupy is importable *and* a CUDA device is present.""" + if not HAS_CUPY: + return False + try: + return bool(_cupy.is_available()) # type: ignore[union-attr] + except Exception: + # cupy can raise from is_available() if the CUDA driver is present but + # incompatible. Treat any failure as "no GPU". + return False + + +# Kwargs that scipy.sparse.linalg.svds accepts but cupyx does not. +# cupy's svds signature: (a, k=6, *, ncv, tol, which, maxiter, +# return_singular_vectors) +_CUPY_SVDS_UNSUPPORTED = frozenset({"solver", "v0", "random_state", "rng", "options"}) + + +__all__ = ["HAS_CUPY", "Backend", "COOBuilder"] + + +class Backend: + """Dispatch between numpy/scipy and cupy/cupyx. + + Parameters + ---------- + device + ``"cpu"`` forces the numpy/scipy backend. ``"cuda"`` forces the + cupy/cupyx backend (raises :class:`RuntimeError` if cupy is not + installed or no CUDA device is visible). ``"auto"`` (default) picks + cuda when available, else cpu. + + Attributes + ---------- + xp + Array namespace — :mod:`numpy` or :mod:`cupy`. + sp + Sparse namespace — :mod:`scipy.sparse` or :mod:`cupyx.scipy.sparse`. + spla + Sparse linear-algebra namespace — :mod:`scipy.sparse.linalg` or + :mod:`cupyx.scipy.sparse.linalg`. + gpu + ``True`` if the cuda backend is active. + device + The resolved device string, ``"cpu"`` or ``"cuda"``. 
+ """ + + __slots__ = ("_faiss_res", "device", "gpu", "sp", "spla", "xp") + + def __init__(self, device: Literal["cpu", "cuda", "auto"] = "auto") -> None: + self._faiss_res: Any = None # lazy faiss.StandardGpuResources cache + if device == "auto": + device = "cuda" if _cuda_available() else "cpu" + + if device == "cuda": + if not HAS_CUPY: + raise RuntimeError( + "Backend('cuda') requested but cupy is not installed. " + "Install cupy (e.g. 'pip install cupy-cuda12x') or use " + "device='cpu'." + ) + if not _cuda_available(): + raise RuntimeError( + "Backend('cuda') requested but no CUDA device is " + "available. Check your GPU drivers or use device='cpu'." + ) + self.xp = _cupy + self.sp = _cupyx_sparse + self.spla = _cupyx_spla + self.gpu = True + elif device == "cpu": + self.xp = np + self.sp = scipy_sparse + self.spla = scipy_spla + self.gpu = False + else: + raise ValueError(f"device must be 'cpu', 'cuda', or 'auto'; got {device!r}") + + self.device: str = device + + def __repr__(self) -> str: + return f"Backend(device={self.device!r}, gpu={self.gpu})" + + # ----------------------------------------------------------------------- + # Compat shims for scipy/cupy API differences + # ----------------------------------------------------------------------- + + def nonzero(self, A: Any) -> tuple[ArrayLike, ArrayLike]: + """Return (row, col) indices of stored entries. + + cupy sparse matrices lack a ``.nonzero()`` method. On GPU this goes + through COO format; on CPU it calls the native method. In both cases + the result includes *explicit zeros* (i.e. this is structural + nonzero, matching scipy's behaviour on sparse matrices). + """ + if self.gpu and self.sp.issparse(A): + coo = A.tocoo() + return coo.row, coo.col + return A.nonzero() + + def sparse_from_coo( + self, + data: ArrayLike, + row: ArrayLike, + col: ArrayLike, + shape: tuple[int, int], + fmt: str = "csr", + ) -> Any: + """Build a sparse matrix from COO triplets on the active backend. 
+ + Duplicate (row, col) entries are summed. Note that cupy's + ``sum_duplicates`` sorts indices differently from scipy's, which + should not affect numerical results but may break byte-exact + comparisons. + """ + data = self.xp.asarray(data) + row = self.xp.asarray(row) + col = self.xp.asarray(col) + coo = self.sp.coo_matrix((data, (row, col)), shape=shape) + return coo.asformat(fmt) + + def setdiag(self, A: Any, val: Any) -> Any: + """Set the main diagonal of ``A`` to ``val``, in CSR, no LIL round-trip. + + scipy's CSR ``setdiag`` works but emits a ``SparseEfficiencyWarning`` + when it has to change the sparsity structure. We suppress that + warning and call ``eliminate_zeros()`` when zeroing the diagonal so + the structural nnz shrinks. + + cupy CSR supports ``setdiag`` directly as well; the same path + handles both backends. + + Returns + ------- + The input matrix ``A`` (modified in place), converted to CSR if it + was not already. + """ + if A.format != "csr": + A = A.tocsr() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", scipy_sparse.SparseEfficiencyWarning) + A.setdiag(val) + if np.isscalar(val) and val == 0: + A.eliminate_zeros() + return A + + def svds(self, A: Any, k: int, **kwargs: Any) -> Any: + """Compute ``k`` largest singular values/vectors. + + On GPU, strips kwargs that cupy does not support (``solver``, ``v0``, + ``random_state``, ``rng``, ``options``). cupy's implementation uses a + thick-restart Lanczos on the normal equations (AᴴA or AAᴴ), not the + Golub-Kahan bidiagonalisation that scipy's ARPACK path uses. This + squares the condition number — fine for top-k singular values, less + accurate for the smallest. 
+ """ + if self.gpu: + kwargs = {k_: v for k_, v in kwargs.items() if k_ not in _CUPY_SVDS_UNSUPPORTED} + return self.spla.svds(A, k=k, **kwargs) + + def LinearOperator( + self, + shape: tuple[int, int], + matvec: Any, + rmatvec: Any = None, + matmat: Any = None, + rmatmat: Any = None, + dtype: Any = None, + ) -> Any: + """Dispatch to the backend's ``LinearOperator`` constructor. + + Both scipy and cupy share the same constructor signature, so this is + a straight pass-through to the active ``spla`` namespace. + """ + return self.spla.LinearOperator( + shape=shape, + matvec=matvec, + rmatvec=rmatvec, + matmat=matmat, + rmatmat=rmatmat, + dtype=dtype, + ) + + # ----------------------------------------------------------------------- + # Data movement + # ----------------------------------------------------------------------- + + def to_device(self, A: Any) -> Any: + """Move array or sparse matrix to the active backend. + + * On a CPU backend this is a no-op (returns ``A`` unchanged). + * On a GPU backend, uploads numpy/scipy data to cupy/cupyx. Objects + already on-device are returned unchanged. + + Handles dense ndarrays and scipy CSR/CSC/COO sparse matrices. + """ + if not self.gpu: + return A + + # Already on device? + if isinstance(A, _cupy.ndarray) or _cupyx_sparse.issparse(A): + return A + + if scipy_sparse.issparse(A): + # cupy sparse constructors accept scipy sparse matrices directly + # and copy to device. + fmt = A.format + if fmt == "csr": + return _cupyx_sparse.csr_matrix(A) + if fmt == "csc": + return _cupyx_sparse.csc_matrix(A) + if fmt == "coo": + return _cupyx_sparse.coo_matrix(A) + # Unsupported format on GPU (lil, dok, bsr, dia handled via csr) + return _cupyx_sparse.csr_matrix(A.tocsr()) + + # Dense array → cupy + return _cupy.asarray(A) + + def to_host(self, A: Any) -> Any: + """Move array or sparse matrix back to numpy/scipy (host memory). + + If ``A`` is already host-resident (numpy/scipy), returns it unchanged. 
+ If ``A`` is a cupy array or cupyx sparse matrix, calls ``.get()`` to + transfer to host. + """ + if HAS_CUPY and isinstance(A, _cupy.ndarray): + return A.get() + if HAS_CUPY and _cupyx_sparse.issparse(A): + return A.get() + return A + + def asfortran_if_gpu(self, A: Any) -> Any: + """Pre-convert a dense array to Fortran (column-major) order on GPU. + + cuSPARSE's SpMM (sparse-times-dense) path requires the dense RHS to be + Fortran-ordered and will silently make a copy if it is not. When the + same dense block is re-used across several SpMM calls (e.g. inside a + Lanczos loop), pre-converting once avoids repeated implicit copies. + + On CPU this is a no-op (scipy's sparse dot handles C-order fine). + """ + if not self.gpu: + return A + if isinstance(A, _cupy.ndarray) and not A.flags.f_contiguous: + return _cupy.asfortranarray(A) + return A + + def free_pool(self) -> None: + """Release unused blocks from cupy's memory pools. + + cupy caches GPU allocations in a memory pool for reuse. After a large + transient allocation, calling this reclaims device memory. No-op on + CPU. + """ + if not self.gpu: + return + _cupy.get_default_memory_pool().free_all_blocks() + _cupy.get_default_pinned_memory_pool().free_all_blocks() + + def faiss_gpu_resources(self) -> Any: + """Lazily create and cache a ``faiss.StandardGpuResources`` instance. + + FAISS's ``StandardGpuResources`` holds cuBLAS handles, streams, and a + temporary-memory pool. Creating one is expensive (~100ms) and the + object is safe to reuse across index builds, so we cache a single + instance per Backend. + + Returns ``None`` on a CPU backend or if faiss has no GPU support. 
+ """ + if not self.gpu: + return None + if self._faiss_res is not None: + return self._faiss_res + try: + import faiss + except ImportError: + return None + if not hasattr(faiss, "StandardGpuResources"): + # faiss-cpu build — no GPU index support + return None + self._faiss_res = faiss.StandardGpuResources() + return self._faiss_res + + +class COOBuilder: + """Accumulate (row, col, val) triplets on the host, finalise to CSR/CSC. + + This replaces the ``lil_matrix`` + fancy-index-assignment pattern, which + does not work on GPU because cupy has no LIL format. Triplets are buffered + in Python lists (host side, O(1) append), then concatenated and converted + to the target format in one shot at :meth:`finalize` time. + + Parameters + ---------- + bk + Backend that determines the target namespace for :meth:`finalize`. + shape + Shape of the output matrix. + dtype + Data dtype for values. Defaults to ``float64``. + + Examples + -------- + >>> bk = Backend("cpu") + >>> b = COOBuilder(bk, shape=(3, 3)) + >>> b.add(0, 1, 5.0) + >>> b.add_batch([1, 2], [2, 0], [3.0, 7.0]) + >>> A = b.finalize(fmt="csr") + >>> A.toarray() + array([[0., 5., 0.], + [0., 0., 3.], + [7., 0., 0.]]) + """ + + __slots__ = ("_bk", "_cols", "_dtype", "_rows", "_shape", "_vals") + + def __init__(self, bk: Backend, shape: tuple[int, int], dtype: Any = None) -> None: + self._bk = bk + self._shape = shape + self._dtype = np.float64 if dtype is None else np.dtype(dtype) + # Buffer as lists of numpy arrays — cheap to append, single concat at end. 
+ self._rows: list[np.ndarray] = [] + self._cols: list[np.ndarray] = [] + self._vals: list[np.ndarray] = [] + + def add(self, i: int, j: int, v: Any) -> None: + """Add a single (row, col, value) triplet.""" + self._rows.append(np.asarray([i], dtype=np.int64)) + self._cols.append(np.asarray([j], dtype=np.int64)) + self._vals.append(np.asarray([v], dtype=self._dtype)) + + def add_batch(self, ii: ArrayLike, jj: ArrayLike, vv: ArrayLike) -> None: + """Add arrays of (row, col, value) triplets at once. + + Inputs may be on host or device; they are brought to host for + buffering. This keeps accumulation cheap and defers the single + host→device transfer to :meth:`finalize`. + """ + ii_h = self._bk.to_host(ii) + jj_h = self._bk.to_host(jj) + vv_h = self._bk.to_host(vv) + self._rows.append(np.ascontiguousarray(ii_h, dtype=np.int64)) + self._cols.append(np.ascontiguousarray(jj_h, dtype=np.int64)) + self._vals.append(np.ascontiguousarray(vv_h, dtype=self._dtype)) + + def finalize(self, fmt: str = "csr") -> Any: + """Concatenate buffered triplets and build a sparse matrix. + + Duplicate (row, col) entries are summed (standard COO semantics). + On a GPU backend the result lives on device. + """ + if self._rows: + row = np.concatenate(self._rows) + col = np.concatenate(self._cols) + val = np.concatenate(self._vals) + else: + row = np.empty(0, dtype=np.int64) + col = np.empty(0, dtype=np.int64) + val = np.empty(0, dtype=self._dtype) + return self._bk.sparse_from_coo(val, row, col, self._shape, fmt=fmt) diff --git a/src/samap/core/coarsening.py b/src/samap/core/coarsening.py new file mode 100644 index 0000000..c0247af --- /dev/null +++ b/src/samap/core/coarsening.py @@ -0,0 +1,471 @@ +"""Cross-species kNN graph stitching and manifold assembly. + +The `_mapper` function here is the core graph-coarsening step: it takes +per-species neighbourhoods and the cross-species projection kNN, stitches them +together via in-degree coarsening, and produces the combined SAM manifold. 
+ +Implementation notes +-------------------- +The mutual-NN construction exploits block structure to avoid materialising the +full N×N intermediate ``D = B @ nnm_internal.T``: + +* ``B`` (cross-species kNN, from projection) is **block-off-diagonal** — + within-species blocks are zero by construction. +* ``nnm_internal`` (expanded within-species kNN) is **block-diagonal**. +* Therefore ``D`` is also block-off-diagonal: ``D[a,b] = B[a,b] @ nnm_b.T``. +* The mutualisation ``M = sqrt(D ⊙ D.T)`` factors per species pair: + ``M[a,b] = sqrt(D[a,b] ⊙ D[b,a].T)``, and the two factors can be computed + chunk-by-chunk for the source species ``a`` without ever holding the full D. + +This brings peak memory from O(N²) down to O(N_a × N_b) per pair (and further +down to O(chunk × N_b) when chunking within a large species). +""" + +from __future__ import annotations + +import gc +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd +import scanpy as sc +import scipy as sp + +from samap._constants import ( + UMAP_MAXITER_LARGE, + UMAP_MAXITER_SMALL, + UMAP_MIN_DIST, + UMAP_SIZE_THRESHOLD, +) +from samap._logging import logger +from samap.sam import SAM +from samap.utils import q as _q +from samap.utils import sparse_knn + +from ._backend import Backend, COOBuilder +from .correlation import _replace +from .expand import _smart_expand +from .homology import _tanh_scale +from .projection import _mapping_window, _mapping_window_fast + +if TYPE_CHECKING: + from typing import Any + + from numpy.typing import NDArray + + +def _generate_coclustering_matrix(cl: NDArray[Any]) -> sp.sparse.csr_matrix: + """Generate a co-clustering indicator matrix.""" + from samap.sam.utils import convert_annotations + + cl_arr = convert_annotations(np.array(list(cl))) + clu, _cluc = np.unique(cl_arr, return_counts=True) + v = np.zeros((cl_arr.size, clu.size)) + v[np.arange(v.shape[0]), cl_arr] = 1 + return sp.sparse.csr_matrix(v) + + +def _scale_by_corr( + M_chunk: Any, + 
global_rows: NDArray[np.int64], + wPCA: NDArray[Any], +) -> Any: + """Rescale mutual-NN edge weights by cell-cell correlation in wPCA space. + + Operates on a chunk of rows (all columns present for those rows), so + per-row maxima are exact. Returns a CSR with the same sparsity pattern as + the input and rescaled data — matches the original full-matrix path + exactly. + """ + M_chunk = M_chunk.tocsr() + x, y = M_chunk.nonzero() + # map chunk-local row indices → global cell indices for the correlation lookup + vals = _replace(wPCA, global_rows[x], y) + # floor at 1e-3 (no eliminate_zeros — preserve M_chunk's sparsity pattern) + vals[vals < 1e-3] = 1e-3 + + F = M_chunk.copy() + F.data[:] = vals + + Fmax = np.asarray(F.max(1).todense()).flatten() + Fmax[Fmax == 0] = 1 + F = F.multiply(1 / Fmax[:, None]).tocsr() + F.data[:] = _tanh_scale(F.data, center=0.7, scale=10) + + Mmax = np.asarray(M_chunk.max(1).todense()).flatten() + Mmax[Mmax == 0] = 1 + + scaled = F.multiply(M_chunk).tocsr() + scaled.data[:] = np.sqrt(scaled.data) + + scaled_max = np.asarray(scaled.max(1).todense()).flatten() + scaled_max[scaled_max == 0] = 1 + + return scaled.multiply((Mmax / scaled_max)[:, None]).tocsr() + + +def _compute_mutual_graph( + nnms_in: dict[str, Any], + neigh_from_keys: dict[str, bool], + B: Any, + offsets: dict[str, int], + n_cells: dict[str, int], + sids: list[str], + k1: int, + N: int, + *, + pairwise: bool, + chunksize: int, + threshold: float, + scale_edges_by_corr: bool, + wPCA: NDArray[Any] | None, + bk: Backend | None = None, +) -> Any: + """Streaming per-species-pair mutual-NN construction. 
+ + For each source species ``a``, iterates over row chunks and over partner + species ``b ≠ a``, computing:: + + left = D[a,b][chunk] = B[a,b][chunk] @ nnm_b.T + right = D[b,a].T[chunk] = nnm_a[chunk] @ B[b,a].T + M[a,b][chunk] = sqrt(left ⊙ right) # mutual geometric mean + + then assembles the chunk's full row (all partners), optionally rescales by + wPCA correlation, top-k sparsifies, and accumulates into a COO builder. + + Parameters + ---------- + nnms_in + Per-species within-species neighbour matrices. For a species with + ``neigh_from_keys[sid]`` false, this is an (N_i × N_i) expanded kNN. + For ``neigh_from_keys[sid]`` true, this is an (N_i × n_clusters) + one-hot cluster-membership matrix; the effective neighbour block is + ``M @ M.T`` (cells sharing a cluster), kept factored to avoid + materialising a potentially dense N_i² block. + neigh_from_keys + Per-species flag for the coclustering path (see above). + B + Cross-species kNN, (N × N), block-off-diagonal in global indices. + offsets, n_cells, sids, N + Species layout in global index space. + k1 + Neighbours to keep per row (per species-pair if ``pairwise`` and + more than two species; otherwise global per row). + pairwise + If True and ``len(sids) > 2``, top-k is applied per species-pair + block rather than globally per row. + chunksize + Row-chunk size for the source species loop. + threshold + Elementwise floor applied to both ``left`` and ``right`` before + mutualisation (entries below it are zeroed). Set to 0 to disable. + scale_edges_by_corr, wPCA + If True, rescale mutualised weights by tanh-scaled cell-cell + correlation in ``wPCA`` space. + + Returns + ------- + scipy.sparse.csr_matrix + The mutualised, sparsified cross-species graph (N × N). + """ + if bk is None: + bk = Backend("cpu") + builder = COOBuilder(bk, shape=(N, N)) + pairwise_topk = pairwise and len(sids) > 2 + + # Precompute per-species slices into B for cheap block extraction. 
+ gslice: dict[str, slice] = { + sid: slice(offsets[sid], offsets[sid] + n_cells[sid]) for sid in sids + } + + for a in sids: + partners = [b for b in sids if b != a] + if not partners: + continue + + na = n_cells[a] + off_a = offsets[a] + nnm_a = nnms_in[a] + nfk_a = neigh_from_keys[a] + + # Cache per-partner blocks of B once (row slicing is cheap on CSR). + # B_ab[b]: (N_a × N_b), B_baT[b]: (N_a × N_b) = B[b,a].T + B_ab: dict[str, Any] = {} + B_baT: dict[str, Any] = {} + for b in partners: + B_ab[b] = B[gslice[a], gslice[b]].tocsr() + B_baT[b] = B[gslice[b], gslice[a]].T.tocsr() + + # For nfk_a, precompute Ma.T @ B_ba.T per partner + # (n_clusters_a × N_b, small). Reused across all chunks of species a. + pre_right: dict[str, Any] = {} + if nfk_a: + for b in partners: + pre_right[b] = nnm_a.T.dot(B_baT[b]) + + for start in range(0, na, chunksize): + end = min(start + chunksize, na) + local = slice(start, end) + chunk_len = end - start + global_rows = np.arange(off_a + start, off_a + end, dtype=np.int64) + + row_l: list[NDArray[np.intp]] = [] + col_l: list[NDArray[np.int64]] = [] + val_l: list[NDArray[np.float64]] = [] + + for b in partners: + nnm_b = nnms_in[b] + nfk_b = neigh_from_keys[b] + B_ab_chunk = B_ab[b][local] # (chunk × N_b) + + # left = D_ab[chunk] = B_ab[chunk] @ nnm_block_b.T + if nfk_b: + # nnm_block_b = M_b @ M_b.T → left = (B_ab_chunk @ M_b) @ M_b.T + left = B_ab_chunk.dot(nnm_b).dot(nnm_b.T) + else: + left = B_ab_chunk.dot(nnm_b.T) + + # right = D_ba.T[chunk] = nnm_block_a[chunk] @ B_ba.T + if nfk_a: + # nnm_block_a = M_a @ M_a.T → right = M_a[chunk] @ (M_a.T @ B_ba.T) + right = nnm_a[local].dot(pre_right[b]) + else: + right = nnm_a[local].dot(B_baT[b]) + + if threshold > 0: + left = left.tocsr() + left.data[left.data < threshold] = 0 + left.eliminate_zeros() + right = right.tocsr() + right.data[right.data < threshold] = 0 + right.eliminate_zeros() + + Mb = left.multiply(right).tocsr() + if Mb.nnz == 0: + continue + Mb.data[:] = np.sqrt(Mb.data) + 
+ coo = Mb.tocoo() + row_l.append(coo.row) + col_l.append(coo.col.astype(np.int64) + offsets[b]) + val_l.append(coo.data) + + if not row_l: + continue + + M_chunk = sp.sparse.csr_matrix( + (np.concatenate(val_l), (np.concatenate(row_l), np.concatenate(col_l))), + shape=(chunk_len, N), + ) + + if scale_edges_by_corr: + M_chunk = _scale_by_corr(M_chunk, global_rows, wPCA) + + if pairwise_topk: + out_rows: list[NDArray[np.intp]] = [] + out_cols: list[NDArray[np.int64]] = [] + out_vals: list[NDArray[np.float64]] = [] + for b in partners: + Msub = M_chunk[:, gslice[b]] + if Msub.nnz == 0: + continue + Mk = sparse_knn(Msub, k1).tocoo() + out_rows.append(Mk.row) + out_cols.append(Mk.col.astype(np.int64) + offsets[b]) + out_vals.append(Mk.data) + if not out_rows: + continue + rows = np.concatenate(out_rows) + cols = np.concatenate(out_cols) + vals = np.concatenate(out_vals) + else: + Mk = sparse_knn(M_chunk, k1).tocoo() + rows, cols, vals = Mk.row, Mk.col.astype(np.int64), Mk.data + + builder.add_batch(global_rows[rows], cols, vals) + + return builder.finalize("csr") + + +def _mapper( + sams: dict[str, SAM], + gnnm: sp.sparse.csr_matrix | None = None, + gn: NDArray[Any] | None = None, + NHS: dict[str, int] | None = None, + umap: bool = False, + mdata: dict[str, Any] | None = None, + k: int | None = None, + K: int = 20, + chunksize: int = 20000, + coarsen: bool = True, + keys: dict[str, str] | None = None, + scale_edges_by_corr: bool = False, + neigh_from_keys: dict[str, bool] | None = None, + pairwise: bool = True, + proj_cache: dict[str, Any] | None = None, + bk: Backend | None = None, + **kwargs: Any, +) -> SAM: + """Map cells between species.""" + if NHS is None: + NHS = dict.fromkeys(sams.keys(), 3) + + if neigh_from_keys is None: + neigh_from_keys = dict.fromkeys(sams.keys(), False) + + if mdata is None: + if proj_cache is not None: + # Fast path: precomputed iteration-invariant state; the expensive + # ss/XtX/wpca_own are read from cache, not rebuilt. 
+ mdata = _mapping_window_fast(gnnm, proj_cache, K=K, pairwise=pairwise) + else: + # Legacy path: rebuild precompute on the fly (wasteful but correct). + mdata = _mapping_window(sams, gnnm, gn, K=K, pairwise=pairwise) + + k1 = K + + if keys is None: + keys = dict.fromkeys(sams.keys(), "leiden_clusters") + + nnms_in: dict[str, Any] = {} + nnms_in0: dict[str, Any] = {} + any_nfk = False + for sid in sams: + logger.info("Expanding neighbourhoods of species %s...", sid) + cl = sams[sid].get_labels(keys[sid]) + _, ix, cluc = np.unique(cl, return_counts=True, return_inverse=True) + K_arr = cluc[ix] + nnms_in0[sid] = sams[sid].adata.obsp["connectivities"].copy() + if not neigh_from_keys[sid]: + nnm_in = _smart_expand(nnms_in0[sid], K_arr, NH=NHS[sid], bk=bk) + nnm_in.data[:] = 1 + nnms_in[sid] = nnm_in + else: + nnms_in[sid] = _generate_coclustering_matrix(cl) + any_nfk = True + + # --- Species layout in global index space ------------------------------- + sids = list(sams.keys()) + n_cells: dict[str, int] = {sid: nnms_in0[sid].shape[0] for sid in sids} + offsets: dict[str, int] = {} + _off = 0 + for sid in sids: + offsets[sid] = _off + _off += n_cells[sid] + N = _off + + nnm_internal0 = sp.sparse.block_diag(list(nnms_in0.values())).tocsr() + + logger.info("Indegree coarsening") + + # Original non-coclustering path applied a 0.1 floor to D before + # mutualisation; the coclustering path did not. Preserve that asymmetry. 
+ threshold = 0.0 if any_nfk else 0.1 + + if scale_edges_by_corr: + logger.info("Rescaling edge weights by expression correlations.") + + Dk = _compute_mutual_graph( + nnms_in, + neigh_from_keys, + mdata["knn"], + offsets, + n_cells, + sids, + k1, + N, + pairwise=pairwise, + chunksize=chunksize, + threshold=threshold, + scale_edges_by_corr=scale_edges_by_corr, + wPCA=mdata["wPCA"] if scale_edges_by_corr else None, + bk=bk, + ) + + del nnms_in + gc.collect() + + if not pairwise or len(sids) == 2: + denom = k1 + else: + denom = k1 * (len(sids) - 1) + + species_list = [] + for sid in sids: + species_list += [sid] * n_cells[sid] + species_list = np.array(species_list) + + sr = np.asarray(Dk.sum(1)) + + x = 1 - sr.flatten() / denom + + omp = nnm_internal0.tocsr() + omp.data[:] = 1 + NNM = omp.multiply(x[:, None]) + NNM = (NNM + Dk).tolil() + NNM.setdiag(0) + + logger.info("Concatenating SAM objects...") + sam3 = _concatenate_sam(sams, NNM) + + sam3.adata.obs["species"] = pd.Categorical(species_list) + + sam3.adata.uns["gnnm_corr"] = mdata.get("gnnm_corr", None) + + if umap: + logger.info("Computing UMAP projection...") + maxiter = ( + UMAP_MAXITER_SMALL if sam3.adata.shape[0] <= UMAP_SIZE_THRESHOLD else UMAP_MAXITER_LARGE + ) + sc.tl.umap(sam3.adata, min_dist=UMAP_MIN_DIST, maxiter=maxiter) + return sam3 + + +def _concatenate_sam(sams: dict[str, SAM], nnm: sp.sparse.lil_matrix) -> SAM: + """Concatenate SAM objects.""" + acns = [] + exps = [] + agns = [] + sps = [] + for i, sid in enumerate(sams.keys()): + acns.append(_q(sams[sid].adata.obs_names)) + sps.append([sid] * acns[-1].size) + exps.append(sams[sid].adata.X) + agns.append(_q(sams[sid].adata.var_names)) + + acn = np.concatenate(acns) + agn = np.concatenate(agns) + sps_arr = np.concatenate(sps) + + xx = sp.sparse.block_diag(exps, format="csr") + + sam = SAM(counts=(xx, agn, acn)) + + sam.adata.uns["neighbors"] = {} + nnm = nnm.tocsr() + nnm.eliminate_zeros() + sam.adata.obsp["connectivities"] = nnm + 
sam.adata.uns["neighbors"]["params"] = { + "n_neighbors": 15, + "method": "umap", + "use_rep": "X", + "metric": "euclidean", + } + for i in sams: + for k in sams[i].adata.obs: + if sams[i].adata.obs[k].dtype.name == "category": + z = np.array(["unassigned"] * sam.adata.shape[0], dtype="object") + z[sps_arr == i] = _q(sams[i].adata.obs[k]) + sam.adata.obs[i + "_" + k] = pd.Categorical(z) + + a = [] + for i, sid in enumerate(sams.keys()): + a.extend(["batch" + str(i + 1)] * sams[sid].adata.shape[0]) + sam.adata.obs["batch"] = pd.Categorical(np.array(a)) + sam.adata.obs.columns = sam.adata.obs.columns.astype("str") + sam.adata.var.columns = sam.adata.var.columns.astype("str") + + for i in sam.adata.obs: + sam.adata.obs[i] = sam.adata.obs[i].astype("str") + + return sam diff --git a/src/samap/core/correlation.py b/src/samap/core/correlation.py new file mode 100644 index 0000000..e6bec8d --- /dev/null +++ b/src/samap/core/correlation.py @@ -0,0 +1,734 @@ +"""Gene-gene correlation refinement for the homology graph. + +Contains the numba-accelerated kernels for computing Pearson / Xi correlations +between homologous gene pairs across the stitched manifold, and the driver +routines that chunk the graph for parallel refinement. + +Implementation notes +-------------------- +The memory bottleneck here is ``Xavg = nnms @ Xs`` — an N × G_active matrix at +10-60% density (multiple GB at million-cell scale). It exists solely to feed +per-gene-pair Pearson correlations. We offer two paths: + +* **Materialized** (``batch_size=None``): builds the full ``Xavg`` up front. + Faster for moderate-scale runs (~3-5× at <10k cells); the right choice when + the estimated ``Xavg`` fits comfortably in memory. +* **Streaming** (``batch_size=int``): processes pairs in batches. For each + batch, computes ``Xavg`` only for the genes appearing in that batch's pairs + (at most ``2 * batch_size`` columns), correlates, discards. Peak memory + drops from O(N × G_active) to O(N × 2·batch_size). 
Some columns are + recomputed across batches if a gene appears in multiple pair-batches; + this is a cheap single-column SpMV and empirically <5% overhead. +* **Auto** (``batch_size="auto"``, the default): estimates the materialised + ``Xavg`` size from cell/gene counts, expression density, and kNN degree. + If the estimate is under ``correlation_mem_threshold`` (default 2 GB), + materialise; otherwise stream at ``batch_size=512``. See + :func:`_resolve_batch_size`. + +Separately, the numba kernel no longer uses a Python dict for species lookup. +Species cell ranges are passed as integer ``sp_starts`` / ``sp_lens`` arrays, +indexed by integer species ID. This is a prerequisite for a future CUDA port. +""" + +from __future__ import annotations + +import gc +import os +import warnings +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd +import scipy as sp +from numba import njit, prange +from numba.core.errors import NumbaPerformanceWarning, NumbaWarning + +from samap._logging import logger +from samap.utils import q as _q +from samap.utils import to_vn + +if TYPE_CHECKING: + from typing import Any + + from numpy.typing import NDArray + + from samap.sam import SAM + +warnings.filterwarnings("ignore", category=NumbaPerformanceWarning) +warnings.filterwarnings("ignore", category=NumbaWarning) + + +@njit(parallel=True) +def _replace(X: NDArray[Any], xi: NDArray[Any], yi: NDArray[Any]) -> NDArray[np.float64]: + """Per-pair Pearson over rows of a dense matrix (CPU fast path, numba). + + For each pair ``(xi[i], yi[i])``, gather the two rows of ``X`` and compute + Pearson correlation. ``prange`` parallelises over pairs. + + On CPU this outperforms the vectorised form below because numba can fuse + the mean/std/dot into a single tight loop per pair with no intermediate + ``(n_pairs × d)`` allocation. On GPU, use :func:`_replace_vectorized`. 
+ """ + data = np.zeros(xi.size) + for i in prange(xi.size): + x = X[xi[i]] + y = X[yi[i]] + data[i] = ((x - x.mean()) * (y - y.mean()) / x.std() / y.std()).sum() / x.size + return data + + +def _replace_vectorized( + X: Any, + xi: Any, + yi: Any, + bk: Any, + batch_size: int | None = None, +) -> Any: + """Per-pair Pearson over rows of a dense matrix — vectorised, backend-agnostic. + + Algebraically identical to :func:`_replace`. Works on both numpy and cupy + via ``bk.xp`` dispatch. Each batch gathers ``2 × batch`` rows of ``X`` + and computes Pearson in one shot — one reduction kernel on GPU instead of + ``n_pairs`` launches. + + Parameters + ---------- + X + Dense (N × d) array on the active backend. + xi, yi + Integer index arrays of shape (n_pairs,) on the active backend. + bk + :class:`Backend` instance for xp dispatch. + batch_size + If ``None``, process all pairs at once (requires ``2 * n_pairs * d`` + floats of scratch). If an integer, chunk pairs to cap the working + set at ``2 * batch_size * d`` floats. Use when ``n_pairs × d × 8`` + bytes approaches device memory. + + Returns + ------- + Array of shape (n_pairs,) on the active backend, dtype float64. + """ + xp = bk.xp + n_pairs = xi.shape[0] + + def _one_batch(ii: Any, jj: Any) -> Any: + Xa = X[ii].astype(xp.float64, copy=True) + Xb = X[jj].astype(xp.float64, copy=True) + Xa -= Xa.mean(axis=1, keepdims=True) + Xb -= Xb.mean(axis=1, keepdims=True) + num = (Xa * Xb).sum(axis=1) + den = xp.sqrt((Xa * Xa).sum(axis=1) * (Xb * Xb).sum(axis=1)) + # den==0 (zero-variance row) → 0/0 → nan. Match _replace's behaviour + # (callers already do vals[isnan]=0). Suppress the expected warning. 
+ with np.errstate(invalid="ignore"): + return num / den + + if batch_size is None or batch_size >= n_pairs: + return _one_batch(xi, yi) + + out = xp.empty(n_pairs, dtype=xp.float64) + for start in range(0, n_pairs, batch_size): + end = min(start + batch_size, n_pairs) + out[start:end] = _one_batch(xi[start:end], yi[start:end]) + return out + + +def replace_corr( + X: Any, + xi: Any, + yi: Any, + bk: Any = None, + batch_size: int | None = None, +) -> Any: + """Dispatch per-pair Pearson to numba (CPU) or vectorised (GPU). + + Drop-in entry point for callers that have a :class:`Backend`. When + ``bk is None`` or ``not bk.gpu``, uses the numba :func:`_replace` (fastest + on CPU: fused loop, no intermediate allocation). When ``bk.gpu`` is True, + uses :func:`_replace_vectorized` (single large reduction kernel on + device). + + Numerically equivalent to :func:`_replace` to machine precision. + """ + if bk is None or not bk.gpu: + # CPU fast path — numba prange beats numpy vectorisation here + return _replace(np.asarray(X), np.asarray(xi), np.asarray(yi)) + return _replace_vectorized(X, xi, yi, bk, batch_size=batch_size) + + +@njit +def nb_unique1d(ar: NDArray[Any]) -> tuple[NDArray[Any], NDArray[Any], NDArray[Any], NDArray[Any]]: + """Find unique elements of an array (numba-optimized).""" + ar = ar.flatten() + perm = ar.argsort(kind="mergesort") + aux = ar[perm] + mask = np.empty(aux.shape, dtype=np.bool_) + mask[:1] = True + mask[1:] = aux[1:] != aux[:-1] + + imask = np.cumsum(mask) - 1 + inv_idx = np.empty(mask.shape, dtype=np.intp) + inv_idx[perm] = imask + idx = np.append(np.nonzero(mask)[0], mask.size) + + return aux[mask], perm[mask], inv_idx, np.diff(idx) + + +@njit +def _xicorr(X: NDArray[Any], Y: NDArray[Any]) -> float: + """Xi correlation coefficient.""" + n = X.size + xi = np.argsort(X, kind="quicksort") + Y = Y[xi] + _, _, b, c = nb_unique1d(Y) + r = np.cumsum(c)[b] + _, _, b, c = nb_unique1d(-Y) + left_counts = np.cumsum(c)[b] + denominator = 2 * 
(left_counts * (n - left_counts)).sum() + if denominator > 0: + return 1 - n * np.abs(np.diff(r)).sum() / denominator + else: + return 0.0 + + +# --------------------------------------------------------------------------- +# CUDA-porting notes for _corr_kernel +# --------------------------------------------------------------------------- +# The kernel below is structurally dict-free and uses only integer species +# indexing, but is NOT yet directly compilable under @cuda.jit. Remaining +# blockers and the restructuring they imply: +# +# 1. Dynamic allocation. `np.zeros(n)`, `np.empty(total)` allocate per-call +# with runtime sizes. CUDA kernels cannot heap-allocate. +# → Fix: do not densify. Instead of scattering CSC column → dense N-vector +# → slicing by species, iterate directly over the CSC nonzeros, +# test each index against the two species ranges [s1,s1+l1)∪[s2,s2+l2), +# and accumulate Pearson sums (sum_x, sum_y, sum_xx, sum_xy, sum_yy, +# count) in scalar registers. This is a two-pass loop per pair (mean +# first, then centred products) but needs only O(1) per-thread state. +# +# 2. Fancy-index scatter `x[pl1i] = pl1d`. Not supported on cuda device +# arrays. Becomes irrelevant once (1) removes the dense scatter. +# +# 3. `prange` → `cuda.grid(1)` + `if j >= n_pairs: return`. Trivial. +# +# 4. `.mean()`, `.std()` on arrays. Not available on cuda.local arrays. +# Irrelevant after (1) — sums are accumulated manually. +# +# 5. `_xicorr` device call. Xi correlation uses `np.argsort` and rank +# computations that have no @cuda.jit equivalent. A GPU Xi path would +# need cub/thrust sort (via cupy) + a separate device kernel for the +# rank statistics. For a first CUDA port, gate the kernel to +# `pearson=True` only and fall back to CPU for Xi. +# +# With (1), the kernel becomes a CSR/CSC-walking Pearson that reads CSC +# columns twice (two passes) and writes one float per pair. Workspace: +# ~6 float64 + ~6 int64 registers per thread. No shared memory needed. 
+# --------------------------------------------------------------------------- + + +@njit(parallel=True) +def _corr_kernel( + p1: NDArray[np.int64], + p2: NDArray[np.int64], + ps1: NDArray[np.int64], + ps2: NDArray[np.int64], + sp_starts: NDArray[np.int64], + sp_lens: NDArray[np.int64], + indptr: NDArray[Any], + indices: NDArray[Any], + data: NDArray[Any], + n: int, + pearson: bool, +) -> NDArray[np.float64]: + """Compute per-pair gene correlations (dict-free, GPU-portable). + + This replaces the original dict-based ``_refine_corr_kernel``. Species + cell ranges are passed as integer ``sp_starts`` / ``sp_lens`` arrays + (contiguous by construction in the stitched manifold), indexed by integer + species ID — no Python dict inside the hot loop. Numerically bit-identical + to the original (same Pearson formula, same Xi path). + + Parameters + ---------- + p1, p2 + Column indices into the CSC ``Xavg``, one pair per entry. + ps1, ps2 + Integer species IDs for each gene (index into ``sp_starts``). + sp_starts, sp_lens + Species ``s`` spans cells ``[sp_starts[s], sp_starts[s]+sp_lens[s])``. + indptr, indices, data + CSC components of ``Xavg``. + n + Number of cells (rows in ``Xavg``). + pearson + True → Pearson; False → Xi correlation. 
+ """ + res = np.zeros(p1.size) + + for j in prange(len(p1)): + j1, j2 = p1[j], p2[j] + pl1d = data[indptr[j1] : indptr[j1 + 1]] + pl1i = indices[indptr[j1] : indptr[j1 + 1]] + + sc1d = data[indptr[j2] : indptr[j2 + 1]] + sc1i = indices[indptr[j2] : indptr[j2 + 1]] + + x = np.zeros(n) + x[pl1i] = pl1d + y = np.zeros(n) + y[sc1i] = sc1d + + s1 = sp_starts[ps1[j]] + l1 = sp_lens[ps1[j]] + s2 = sp_starts[ps2[j]] + l2 = sp_lens[ps2[j]] + + total = l1 + l2 + xx = np.empty(total) + yy = np.empty(total) + xx[:l1] = x[s1 : s1 + l1] + xx[l1:] = x[s2 : s2 + l2] + yy[:l1] = y[s1 : s1 + l1] + yy[l1:] = y[s2 : s2 + l2] + + if pearson: + c = ((xx - xx.mean()) * (yy - yy.mean()) / xx.std() / yy.std()).sum() / xx.size + else: + c = _xicorr(xx, yy) + res[j] = c + return res + + +def _compute_pair_corrs( + nnms: Any, + Xs: Any, + p: NDArray[np.int64], + ps_int: NDArray[np.int64], + sp_starts: NDArray[np.int64], + sp_lens: NDArray[np.int64], + n: int, + corr_mode: str, + batch_size: int | None, +) -> NDArray[np.float64]: + """Compute correlations for all gene pairs, materialised or streaming. + + Parameters + ---------- + nnms + (N × N) row-normalised neighbour-averaging operator (CSR). + Xs + (N × G_active) block-diagonal expression matrix. CSC preferred for + column slicing in the streaming path. + p + (n_pairs × 2) integer column-indices into ``Xs`` for each gene pair. + ps_int + (n_pairs × 2) integer species IDs for each gene pair. + sp_starts, sp_lens, n + Species layout (see :func:`_corr_kernel`). + corr_mode + ``"pearson"`` or ``"xi"``. + batch_size + ``None`` → materialise full ``Xavg = nnms @ Xs`` (legacy path, + golden-compatible). ``int`` → stream in pair-batches; per batch, + compute only the ≤ ``2 * batch_size`` columns of ``Xavg`` actually + needed, correlate, discard. Peak memory O(N × 2·batch_size) instead + of O(N × G_active). 
+ """ + pearson = corr_mode == "pearson" + p1 = np.ascontiguousarray(p[:, 0], dtype=np.int64) + p2 = np.ascontiguousarray(p[:, 1], dtype=np.int64) + ps1 = np.ascontiguousarray(ps_int[:, 0], dtype=np.int64) + ps2 = np.ascontiguousarray(ps_int[:, 1], dtype=np.int64) + + if batch_size is None: + # --- Materialised path (golden-compatible) -------------------------- + Xavg = nnms.dot(Xs).tocsc() + return _corr_kernel( + p1, + p2, + ps1, + ps2, + sp_starts, + sp_lens, + Xavg.indptr, + Xavg.indices, + Xavg.data, + n, + pearson, + ) + + # --- Streaming path ----------------------------------------------------- + if not sp.sparse.isspmatrix_csc(Xs): + Xs = Xs.tocsc() + + n_pairs = p1.size + res = np.zeros(n_pairs) + + for start in range(0, n_pairs, batch_size): + end = min(start + batch_size, n_pairs) + + p1b = p1[start:end] + p2b = p2[start:end] + + # Genes needed for this batch (sorted, unique) — at most 2·batch_size. + needed = np.unique(np.concatenate((p1b, p2b))) + # Map original column index → local column index in Xavg_batch. + # `needed` is sorted so searchsorted gives the exact local position. + p1_local = np.searchsorted(needed, p1b).astype(np.int64) + p2_local = np.searchsorted(needed, p2b).astype(np.int64) + + # One SpMM for just the needed columns, then CSC for kernel access. + Xavg_batch = nnms.dot(Xs[:, needed]).tocsc() + + res[start:end] = _corr_kernel( + p1_local, + p2_local, + np.ascontiguousarray(ps1[start:end]), + np.ascontiguousarray(ps2[start:end]), + sp_starts, + sp_lens, + Xavg_batch.indptr, + Xavg_batch.indices, + Xavg_batch.data, + n, + pearson, + ) + + del Xavg_batch + gc.collect() + + return res + + +def _resolve_batch_size( + batch_size: int | str | None, + nnms: Any, + Xs: Any, + mem_threshold_gb: float = 2.0, +) -> int | None: + """Auto-select batched vs materialised correlation based on estimated memory. 
+ + The materialised path (``batch_size=None``) is 3-5× faster on small data + (fewer SpMM dispatches, better cache reuse) but requires holding the full + ``Xavg = nnms @ Xs`` in memory. The streaming path (``batch_size=int``) + caps memory but pays per-batch overhead. + + Heuristic: estimate the materialised ``Xavg`` size. The output density + after kNN smoothing is roughly ``1 - (1 - p)^k`` where ``p`` is the input + expression density and ``k`` is the average neighbour degree — each output + entry is zero only if all ``k`` contributing inputs are zero. If the + estimate is comfortably under ``mem_threshold_gb``, materialise; otherwise + stream at 512. + + Parameters + ---------- + batch_size + User-supplied batch_size. If not the string ``"auto"``, returned + unchanged (respects explicit user choice including ``None``). + nnms + Row-normalised averaging operator (N × N sparse). + Xs + Block-diagonal expression (N × G_active sparse). + mem_threshold_gb + If estimated ``Xavg`` size is below this, materialise. Default 2 GB + leaves ample headroom on a 16 GB laptop; large-memory nodes can + raise it via ``correlation_mem_threshold``. + + Returns + ------- + ``None`` to materialise, or an integer batch size to stream. + """ + if batch_size != "auto": + return batch_size + + n_cells, n_genes = Xs.shape + if n_cells == 0 or n_genes == 0: + return None # trivial — materialise + + k_avg = nnms.nnz / max(nnms.shape[0], 1) + expr_density = Xs.nnz / (n_cells * n_genes) + # Output entry is zero iff all k contributing inputs are zero: (1-p)^k. + # This overestimates density slightly (ignores structural zeros from + # block-diag Xs outside a species' gene range), which is the safe direction. + out_density = min(1.0, 1.0 - (1.0 - expr_density) ** max(k_avg, 1.0)) + + # CSC storage: data (float64) + indices (int32) + indptr. Dominated by + # data + indices ≈ 12 bytes/nonzero. Use 12 not 8 to be conservative. 
+ est_bytes = n_cells * n_genes * out_density * 12.0 + est_gb = est_bytes / 1e9 + + if est_gb < mem_threshold_gb: + logger.info( + "Correlation: estimated Xavg %.3f GB (density~%.1f%%, %d cells x %d genes) " + "< %.1f GB threshold — using materialised path.", + est_gb, + out_density * 100, + n_cells, + n_genes, + mem_threshold_gb, + ) + return None + + logger.info( + "Correlation: estimated Xavg %.3f GB (density~%.1f%%, %d cells x %d genes) " + ">= %.1f GB threshold — using streaming path (batch_size=512).", + est_gb, + out_density * 100, + n_cells, + n_genes, + mem_threshold_gb, + ) + return 512 + + +def _refine_corr( + sams: dict[str, SAM], + st: SAM, + gnnm: sp.sparse.csr_matrix, + gns_dict: dict[str, NDArray[Any]], + corr_mode: str = "pearson", + THR: float = 0, + use_seq: bool = False, + T1: float = 0.25, + NCLUSTERS: int = 1, + ncpus: int | None = None, + wscale: bool = False, + batch_size: int | str | None = "auto", + correlation_mem_threshold: float = 2.0, +) -> sp.sparse.csr_matrix: + """Refine correlation matrix for homology graph. + + Parameters + ---------- + batch_size : int | str | None + ``"auto"`` (default) → :func:`_resolve_batch_size` decides: materialise + when the estimated ``Xavg`` fits under ``correlation_mem_threshold`` + GB (3-5× faster on small data), stream at 512 otherwise. ``None`` + forces the materialised path unconditionally. An integer forces + streaming at that batch size. See the module docstring for the full + memory model. + correlation_mem_threshold : float + Memory threshold (GB) for ``batch_size="auto"``. Default 2.0 — leaves + ample headroom on a 16 GB laptop. Raise on large-memory nodes to keep + the faster materialised path for larger datasets; lower for + memory-constrained environments. 
+ (other parameters unchanged) + """ + if ncpus is None: + ncpus = os.cpu_count() or 1 + + gns = np.concatenate(list(gns_dict.values())) + + x, y = gnnm.nonzero() + sam = next(iter(sams.values())) + cl = sam.leiden_clustering(gnnm, res=0.5) + ix = np.argsort(cl) + NGPC = gns.size // NCLUSTERS + 1 + + ixs = [] + for i in range(NCLUSTERS): + ixs.append(np.sort(ix[i * NGPC : (i + 1) * NGPC])) + + assert np.concatenate(ixs).size == gns.size + + GNNMSUBS = [] + GNSUBS = [] + for i in range(len(ixs)): + ixs[i] = np.unique(np.append(ixs[i], gnnm[ixs[i], :].nonzero()[1])) + gnnm_sub = gnnm[ixs[i], :][:, ixs[i]] + gnsub = gns[ixs[i]] + gns_dict_sub = {} + for sid in gns_dict: + gn = gns_dict[sid] + gns_dict_sub[sid] = gn[np.isin(gn, gnsub)] + + gnnm2_sub = _refine_corr_parallel( + sams, + st, + gnnm_sub, + gns_dict_sub, + corr_mode=corr_mode, + THR=THR, + use_seq=use_seq, + T1=T1, + ncpus=ncpus, + wscale=wscale, + batch_size=batch_size, + correlation_mem_threshold=correlation_mem_threshold, + ) + GNNMSUBS.append(gnnm2_sub) + GNSUBS.append(gnsub) + gc.collect() + + indices_list = [] + pairs_list = [] + for i in range(len(GNNMSUBS)): + indices_list.append(np.unique(np.sort(np.vstack(GNNMSUBS[i].nonzero()).T, axis=1), axis=0)) + pairs_list.append(GNSUBS[i][indices_list[-1]]) + + GNS = pd.DataFrame(data=np.arange(gns.size)[None, :], columns=gns) + gnnm3 = sp.sparse.lil_matrix(gnnm.shape) + for i in range(len(indices_list)): + x, y = GNS[pairs_list[i][:, 0]].values.flatten(), GNS[pairs_list[i][:, 1]].values.flatten() + gnnm3[x, y] = np.asarray( + GNNMSUBS[i][indices_list[i][:, 0], indices_list[i][:, 1]] + ).flatten() + + gnnm3 = gnnm3.tocsr() + x, y = gnnm3.nonzero() + gnnm3 = gnnm3.tolil() + gnnm3[y, x] = np.asarray(gnnm3[x, y].tocsr().todense()).flatten() + return gnnm3.tocsr() + + +def _refine_corr_parallel( + sams: dict[str, SAM], + st: SAM, + gnnm: sp.sparse.csr_matrix, + gns_dict: dict[str, NDArray[Any]], + corr_mode: str = "pearson", + THR: float = 0, + use_seq: bool = 
False, + T1: float = 0.0, + ncpus: int | None = None, + wscale: bool = False, + batch_size: int | str | None = "auto", + correlation_mem_threshold: float = 2.0, +) -> sp.sparse.csr_matrix: + """Parallel correlation refinement. + + Parameters + ---------- + batch_size + ``"auto"`` (default) → pick materialised vs streaming based on the + estimated ``Xavg`` memory footprint (see :func:`_resolve_batch_size`). + ``None`` → force materialised (legacy, fast on small data). Integer → + force streaming at that batch size. + correlation_mem_threshold + GB threshold for auto-selection. Default 2.0. + (other parameters unchanged) + """ + if ncpus is None: + ncpus = os.cpu_count() or 1 + + gn = np.concatenate(list(gns_dict.values())) + + Ws = [] + ix = [] + for sid in sams: + Ws.append(sams[sid].adata.var["weights"][gns_dict[sid]].values) + ix += [sid] * gns_dict[sid].size + ix = np.array(ix) + w = np.concatenate(Ws) + + w[w > T1] = 1 + w[w < 1] = 0 + + gnO = gn[w > 0] + ix = ix[w > 0] + gns_dictO = {} + for sid in gns_dict: + gns_dictO[sid] = gnO[ix == sid] + + gnnmO = gnnm[w > 0, :][:, w > 0] + x, y = gnnmO.nonzero() + + pairs = np.unique(np.sort(np.vstack((x, y)).T, axis=1), axis=0) + + xs = _q([i.split("_")[0] for i in gnO[pairs[:, 0]]]) + ys = _q([i.split("_")[0] for i in gnO[pairs[:, 1]]]) + pairs_species = np.vstack((xs, ys)).T + + nnm = st.adata.obsp["connectivities"] + xs_list = [] + nnms = [] + for i, sid in enumerate(sams.keys()): + batch_mask = (st.adata.obs["batch"] == f"batch{i + 1}").values + nnms.append(nnm[:, batch_mask]) + s1 = np.asarray(nnms[-1].sum(1)) + s1[s1 < 1e-3] = 1 + s1 = s1.flatten()[:, None] + nnms[-1] = nnms[-1].multiply(1 / s1) + + xs_list.append(sams[sid].adata[:, gns_dictO[sid]].X.astype("float32")) + + Xs = sp.sparse.block_diag(xs_list).tocsc() + nnms = sp.sparse.hstack(nnms).tocsr() + + # Resolve "auto" to a concrete batch_size now that nnms/Xs shapes are known. 
+ batch_size = _resolve_batch_size(batch_size, nnms, Xs, correlation_mem_threshold) + + p = pairs + ps = pairs_species + + gnnm2 = gnnm.multiply(w[:, None]).multiply(w[None, :]).tocsr() + x, y = gnnm2.nonzero() + pairs = np.unique(np.sort(np.vstack((x, y)).T, axis=1), axis=0) + + # --- Species layout as integer start/len arrays (dict-free kernel) ------ + species = _q(st.adata.obs["species"]) + sidss = np.unique(species) + sp_starts = np.empty(sidss.size, dtype=np.int64) + sp_lens = np.empty(sidss.size, dtype=np.int64) + for i, sid in enumerate(sidss): + where = np.where(species == sid)[0] + sp_starts[i] = where[0] + sp_lens[i] = where.size + # Contiguity sanity check — true by construction in _concatenate_sam + # (cells are grouped by species); guard against future changes. + if where[-1] - where[0] + 1 != where.size: + raise RuntimeError( + f"Species {sid!r} cells are not contiguous in the stitched " + f"manifold. This violates the dict-free kernel's assumption " + f"(cells must be species-grouped). Please report." + ) + + # Map string species IDs in `ps` → integer positions in `sidss`. 
+ sid_to_int = pd.Series(np.arange(sidss.size, dtype=np.int64), index=sidss) + ps_int = np.column_stack( + (sid_to_int.loc[ps[:, 0]].values, sid_to_int.loc[ps[:, 1]].values) + ).astype(np.int64) + + vals = _compute_pair_corrs( + nnms, + Xs, + p.astype(np.int64), + ps_int, + sp_starts, + sp_lens, + nnms.shape[0], + corr_mode, + batch_size, + ) + vals[np.isnan(vals)] = 0 + + CORR = dict(zip(to_vn(np.vstack((gnO[p[:, 0]], gnO[p[:, 1]])).T), vals)) + + for k in CORR: + CORR[k] = 0 if CORR[k] < THR else CORR[k] + if wscale: + id1, id2 = [x.split("_")[0] for x in k.split(";")] + weight1 = sams[id1].adata.var["weights"][k.split(";")[0]] + weight2 = sams[id2].adata.var["weights"][k.split(";")[1]] + CORR[k] = np.sqrt(CORR[k] * np.sqrt(weight1 * weight2)) + + CORR_arr = np.array([CORR[x] for x in to_vn(gn[pairs])]) + + gnnm3 = sp.sparse.lil_matrix(gnnm.shape) + + if use_seq: + gnnm3[pairs[:, 0], pairs[:, 1]] = ( + CORR_arr * np.asarray(gnnm2[pairs[:, 0], pairs[:, 1]]).flatten() + ) + gnnm3[pairs[:, 1], pairs[:, 0]] = ( + CORR_arr * np.asarray(gnnm2[pairs[:, 1], pairs[:, 0]]).flatten() + ) + else: + gnnm3[pairs[:, 0], pairs[:, 1]] = CORR_arr + gnnm3[pairs[:, 1], pairs[:, 0]] = CORR_arr + + gnnm3 = gnnm3.tocsr() + gnnm3.eliminate_zeros() + return gnnm3 diff --git a/src/samap/core/expand.py b/src/samap/core/expand.py new file mode 100644 index 0000000..c6edfc8 --- /dev/null +++ b/src/samap/core/expand.py @@ -0,0 +1,328 @@ +"""Neighborhood expansion for cluster-adaptive kNN stitching. + +Two implementations: + +* :func:`_smart_expand_matpow` — the original algorithm. Computes iterated + sparse matrix powers ``nnm^i`` to materialize hop-``i`` rings, then trims + each ring to a per-cell budget. Simple but densifies badly at scale: the + nnz of ``nnm^NH`` grows geometrically with NH and the kNN degree. + +* :func:`_smart_expand_bfs` — a per-cell budget-capped BFS. 
Each cell
+ independently walks its neighbourhood hop by hop, collecting nodes in
+ (hop, weight) priority order up to its budget. Never materializes matrix
+ powers; working set per cell is O(budget * k). Numba-parallel over cells.
+
+Both return a sparse matrix whose *structure* is what matters — the caller
+immediately binarizes the output. The two algorithms agree exactly when
+every cell's budget ≥ its reachable-within-(NH+1)-hops count (direct
+neighbours plus NH extra hops — see ``_smart_expand_bfs``). When budgets
+truncate, they may select different neighbours at the margin because
+``matpow`` ranks ring members by path-sum weight while ``bfs`` ranks by
+max incoming edge weight. See ``tests/unit/test_expand.py`` for a
+characterisation.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+import scipy as sp
+from numba import njit, prange
+
+if TYPE_CHECKING:
+ from typing import Any
+
+ from numpy.typing import NDArray
+
+
+# ---------------------------------------------------------------------------
+# Legacy matpow implementation
+# ---------------------------------------------------------------------------
+
+
+def _sparse_knn_ks(D: sp.sparse.coo_matrix, ks: NDArray[Any]) -> sp.sparse.coo_matrix:
+ """Keep variable top-k values per row in sparse matrix."""
+ D1 = D.tocoo()
+ idr = np.argsort(D1.row)
+ D1.row[:] = D1.row[idr]
+ D1.col[:] = D1.col[idr]
+ D1.data[:] = D1.data[idr]
+
+ row, ind = np.unique(D1.row, return_index=True)
+ ind = np.append(ind, D1.data.size)
+ for i in range(ind.size - 1):
+ idx = np.argsort(D1.data[ind[i] : ind[i + 1]])
+ k = ks[row[i]]
+ if idx.size > k:
+ idx = idx[:-k] if k != 0 else idx
+ D1.data[np.arange(ind[i], ind[i + 1])[idx]] = 0
+ D1.eliminate_zeros()
+ return D1
+
+
+def _smart_expand_matpow(
+ nnm: sp.sparse.csr_matrix, K: NDArray[Any], NH: int = 3
+) -> sp.sparse.csr_matrix:
+ """Original matrix-power neighbourhood expansion.
+
+ Builds hop-``i`` rings via sparse matrix powers, then greedily fills
+ each cell's budget ring by ring. 
Kept for regression testing and as a + fallback. + """ + stage0 = nnm.copy() + S = [stage0] + running = stage0 + for i in range(1, NH + 1): + stage = running.dot(stage0) + running = stage + stage = stage.tolil() + for j in range(i): + stage[S[j].nonzero()] = 0 + stage = stage.tocsr() + S.append(stage) + + for i in range(len(S)): + s = _sparse_knn_ks(S[i], K).tocsr() + a, c = np.unique(s.nonzero()[0], return_counts=True) + numnz = np.zeros(s.shape[0], dtype="int32") + numnz[a] = c + K = K - numnz + K[K < 0] = 0 + S[i] = s + res = S[0] + for i in range(1, len(S)): + res = res + S[i] + return res + + +# --------------------------------------------------------------------------- +# BFS implementation +# --------------------------------------------------------------------------- + + +@njit(parallel=True, cache=True) +def _bfs_expand_kernel( + indptr: NDArray[np.int64], + indices: NDArray[np.int64], + data: NDArray[np.float64], + K_arr: NDArray[np.int64], + NH: np.int64, + n_cells: np.int64, + buf_size: np.int64, + out_cols: NDArray[np.int64], + out_offsets: NDArray[np.int64], + out_counts: NDArray[np.int64], +) -> None: + """Per-cell budget-capped BFS over a CSR adjacency matrix. + + For each cell ``c`` (parallelized over cells): + + 1. Seed the frontier with ``c``'s direct neighbours, weighted by the + corresponding edge weights in ``nnm``. + 2. At each hop, sort the frontier by weight (descending) and collect + unvisited nodes in that order until the cell's budget is met or the + ring is exhausted. + 3. Expand only from *collected* nodes — their neighbours form the next + frontier. Duplicates are resolved at collection time (first hit by + highest weight wins, since the frontier is sorted). + 4. Stop after ``NH+1`` hops or when the budget is filled. + + Writes collected column indices for cell ``c`` into + ``out_cols[out_offsets[c] : out_offsets[c] + out_counts[c]]``. Slots + beyond ``out_counts[c]`` are unused (budget unfilled). 
+ """ + for c in prange(n_cells): + budget = K_arr[c] + if budget == 0: + out_counts[c] = 0 + continue + + # Per-cell visited mask. Allocated inside the prange body so each + # parallel iteration gets its own — numba makes this thread-local. + visited = np.zeros(n_cells, dtype=np.bool_) + visited[c] = True + + # Frontier double-buffer. buf_size is a safe upper bound on the + # number of (node, weight) entries in a single hop's frontier: + # at most ``budget`` nodes are collected per hop, each contributing + # at most ``max_deg`` neighbours → budget * max_deg. The extra + # +max_deg headroom covers the initial seed from ``c`` itself. + front_idx = np.empty(buf_size, dtype=np.int64) + front_w = np.empty(buf_size, dtype=np.float64) + next_idx = np.empty(buf_size, dtype=np.int64) + next_w = np.empty(buf_size, dtype=np.float64) + + # Seed: direct neighbours of c. + n_front = 0 + for p in range(indptr[c], indptr[c + 1]): + front_idx[n_front] = indices[p] + front_w[n_front] = data[p] + n_front += 1 + + out_base = out_offsets[c] + n_collected = 0 + + for _hop in range(NH + 1): + if n_front == 0 or n_collected >= budget: + break + + # Rank this hop's candidates by weight, descending. + # argsort is ascending → negate to get descending. + order = np.argsort(-front_w[:n_front]) + + n_next = 0 + for oi in range(n_front): + node = front_idx[order[oi]] + if visited[node]: + continue + if n_collected >= budget: + break + + # Collect. + out_cols[out_base + n_collected] = node + n_collected += 1 + visited[node] = True + + # Expand: push this node's neighbours into next frontier. + # Visited-filtering here is a *conservative* prune — more + # filtering happens at collection time on the next hop + # (handles duplicates within next_idx too). + for p in range(indptr[node], indptr[node + 1]): + nb = indices[p] + if not visited[nb] and n_next < buf_size: + next_idx[n_next] = nb + next_w[n_next] = data[p] + n_next += 1 + + # Swap buffers. 
+ front_idx, next_idx = next_idx, front_idx + front_w, next_w = next_w, front_w + n_front = n_next + + out_counts[c] = n_collected + + +def _smart_expand_bfs( + nnm: sp.sparse.csr_matrix, K: NDArray[Any], NH: int = 3 +) -> sp.sparse.csr_matrix: + """BFS-based neighbourhood expansion. + + Algorithmically equivalent to :func:`_smart_expand_matpow` *when every + cell's budget covers its full reachable set within ``NH+1`` hops*. When + budgets truncate, the two may pick different marginal neighbours because + they rank ring members differently (path-sum vs. max-edge weight). The + output is binarized by the caller so only membership matters. + + Parameters + ---------- + nnm + ``(n, n)`` CSR adjacency / connectivity matrix. Need not be + symmetric; only outgoing edges are walked. + K + ``(n,)`` per-cell collection budget (typically the cell's cluster + size). + NH + Maximum number of *extra* hops beyond direct neighbours. Total + hops walked is ``NH + 1``. + + Returns + ------- + ``(n, n)`` CSR matrix with ``1.0`` at every collected ``(cell, + neighbour)`` pair. + """ + nnm = nnm.tocsr() + n = nnm.shape[0] + + indptr = np.ascontiguousarray(nnm.indptr, dtype=np.int64) + indices = np.ascontiguousarray(nnm.indices, dtype=np.int64) + data = np.ascontiguousarray(nnm.data, dtype=np.float64) + K = np.ascontiguousarray(K, dtype=np.int64) + + if n == 0: + return sp.sparse.csr_matrix((0, 0), dtype=np.float64) + + # Preallocate output by budget. A cell may collect fewer than its + # budget if its reachable component is small; out_counts records actuals. + total = int(K.sum()) + out_cols = np.empty(max(total, 1), dtype=np.int64) + out_offsets = np.zeros(n + 1, dtype=np.int64) + np.cumsum(K, out=out_offsets[1:]) + out_counts = np.zeros(n, dtype=np.int64) + + # Kernel buffer sizing: each hop expands from ≤ max_K collected nodes, + # each with ≤ max_deg outgoing edges. 
+ degs = np.diff(indptr) + max_deg = int(degs.max()) if degs.size else 0 + max_K = int(K.max()) if K.size else 0 + buf_size = max(max_K * max_deg + max_deg, 1) + + if total > 0 and max_deg > 0: + _bfs_expand_kernel( + indptr, + indices, + data, + K, + np.int64(NH), + np.int64(n), + np.int64(buf_size), + out_cols, + out_offsets, + out_counts, + ) + + # Compact output: each cell's block in out_cols is sized by budget but + # only the first out_counts[c] entries are valid. Build a mask. + if total == 0: + return sp.sparse.csr_matrix((n, n), dtype=np.float64) + + block_ids = np.repeat(np.arange(n, dtype=np.int64), K) + within = np.arange(total, dtype=np.int64) - out_offsets[block_ids] + valid = within < out_counts[block_ids] + + rows = block_ids[valid] + cols = out_cols[:total][valid] + vals = np.ones(rows.size, dtype=np.float64) + + return sp.sparse.csr_matrix((vals, (rows, cols)), shape=(n, n)) + + +# --------------------------------------------------------------------------- +# Public dispatch +# --------------------------------------------------------------------------- + + +def _smart_expand( + nnm: sp.sparse.csr_matrix, + K: NDArray[Any], + NH: int = 3, + *, + legacy: bool = False, + bk: Any = None, +) -> sp.sparse.csr_matrix: + """Expand each cell's neighbourhood to a per-cell budget via multi-hop walk. + + Parameters + ---------- + nnm + ``(n, n)`` sparse connectivity matrix. + K + ``(n,)`` per-cell budget (number of neighbours to collect). + NH + Number of extra hops beyond direct neighbours (default 3 → walks up + to 4 hops). + legacy + If ``False`` (default), use the BFS algorithm — ~5× faster at 3k cells + and memory-bounded. If ``True``, use the original matrix-power + algorithm. Note: matpow wastes ~1 budget slot per cell on self-loops + (a cell's 2-hop neighbourhood always includes itself); BFS avoids this + and is arguably more correct, but will select slightly different + marginal neighbours (~1% edge difference on the golden-suite data). 
+ Set ``legacy=True`` only if you need bit-exact reproduction of + pre-3.0 SAMap output. + bk + Array backend. Currently unused (both paths are CPU-only numba); + threaded through for future GPU work. + """ + if legacy: + return _smart_expand_matpow(nnm, K, NH=NH) + return _smart_expand_bfs(nnm, K, NH=NH) diff --git a/src/samap/core/homology.py b/src/samap/core/homology.py new file mode 100644 index 0000000..898e728 --- /dev/null +++ b/src/samap/core/homology.py @@ -0,0 +1,279 @@ +"""Gene homology graph construction from BLAST results. + +Functions for building, coarsening, and filtering the cross-species gene +homology graph that seeds the SAMap iteration. +""" + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd +import scipy as sp + +from samap.utils import coo_to_csr_overwrite, df_to_dict +from samap.utils import q as _q + +if TYPE_CHECKING: + from typing import Any + + from numpy.typing import NDArray + + from samap.sam import SAM + + +def _tanh_scale(x: NDArray[Any], scale: float = 10, center: float = 0.5) -> NDArray[Any]: + """Apply tanh scaling to values.""" + return center + (1 - center) * np.tanh(scale * (x - center)) + + +def _calculate_blast_graph( + ids: list[str], + f_maps: str = "maps/", + eval_thr: float = 1e-6, + reciprocate: bool = False, +) -> tuple[sp.sparse.csr_matrix, NDArray[Any], dict[str, NDArray[Any]]]: + """Calculate gene homology graph from BLAST results.""" + gns: list[str] = [] + Xs: list[Any] = [] + Ys: list[Any] = [] + Vs: list[Any] = [] + + for i in range(len(ids)): + id1 = ids[i] + for j in range(i, len(ids)): + id2 = ids[j] + if i != j: + if os.path.exists(f_maps + f"{id1}{id2}"): + fA = f_maps + f"{id1}{id2}/{id1}_to_{id2}.txt" + fB = f_maps + f"{id1}{id2}/{id2}_to_{id1}.txt" + elif os.path.exists(f_maps + f"{id2}{id1}"): + fA = f_maps + f"{id2}{id1}/{id1}_to_{id2}.txt" + fB = f_maps + f"{id2}{id1}/{id2}_to_{id1}.txt" + else: + raise FileNotFoundError( + 
f"BLAST mapping tables with the input IDs ({id1} and {id2}) " + f"not found in the specified path." + ) + + A = pd.read_csv(fA, sep="\t", header=None, index_col=0) + B = pd.read_csv(fB, sep="\t", header=None, index_col=0) + + A.columns = A.columns.astype(" NDArray[np.str_]: + """Add species prefix to gene names.""" + x = [str(item).split("_")[0] for item in data] + vn = [] + for i, g in enumerate(data): + if x[i] != pre: + vn.append(pre + "_" + g) + else: + vn.append(g) + return np.array(vn).astype("str").astype("object") + + +def _coarsen_blast_graph( + gnnm: sp.sparse.csr_matrix, + gns: NDArray[Any], + names: dict[str, Any], +) -> tuple[sp.sparse.csr_matrix, dict[str, NDArray[Any]], NDArray[Any]]: + """Coarsen BLAST graph by collapsing transcripts to genes.""" + gnnm = gnnm.tocsr() + gnnm.eliminate_zeros() + + sps = np.array([x.split("_")[0] for x in gns]) + sids = np.unique(sps) + ss = [] + for sid in sids: + n = names.get(sid) + if n is not None: + n = np.array(n) + n = (sid + "_" + n.astype("object")).astype("str") + s1 = pd.Series(index=n[:, 0], data=n[:, 1]) + g = gns[sps == sid] + g = g[np.isin(g, n[:, 0], invert=True)] + s2 = pd.Series(index=g, data=g) + s = pd.concat([s1, s2]) + else: + s = pd.Series(index=gns[sps == sid], data=gns[sps == sid]) + ss.append(s) + ss_combined = pd.concat(ss) + ss_combined = ss_combined[np.unique(_q(ss_combined.index), return_index=True)[1]] + x, y = gnnm.nonzero() + s = pd.Series(data=gns, index=np.arange(gns.size)) + xn, yn = s[x].values, s[y].values + xg, yg = ss_combined[xn].values, ss_combined[yn].values + + da = gnnm.data + + zgu, ix, _ivx, cu = np.unique( + np.array([xg, yg]).astype("str"), + axis=1, + return_counts=True, + return_index=True, + return_inverse=True, + ) + + xgu, ygu = zgu[:, cu > 1] + xgyg = _q(xg.astype("object") + ";" + yg.astype("object")) + xguygu = _q(xgu.astype("object") + ";" + ygu.astype("object")) + + filt = np.isin(xgyg, xguygu) + + DF = pd.DataFrame(data=xgyg[filt][:, None], columns=["key"]) 
+ DF["val"] = da[filt] + + dic = df_to_dict(DF, key_key="key") + + xgu = _q([x.split(";")[0] for x in dic]) + ygu = _q([x.split(";")[1] for x in dic]) + replz = _q([max(dic[x]) for x in dic]) + + xgu1, ygu1 = zgu[:, cu == 1] + xg = np.append(xgu1, xgu) + yg = np.append(ygu1, ygu) + da = np.append(da[ix][cu == 1], replz) + gn = np.unique(np.append(xg, yg)) + + s = pd.Series(data=np.arange(gn.size), index=gn) + xn, yn = s[xg].values, s[yg].values + gnnm = sp.sparse.coo_matrix((da, (xn, yn)), shape=(gn.size,) * 2).tocsr() + + f = np.asarray(gnnm.sum(1)).flatten() != 0 + gn = gn[f] + sps = np.array([x.split("_")[0] for x in gn]) + + gns_dict: dict[str, NDArray[Any]] = {} + for sid in sids: + gns_dict[sid] = gn[sps == sid] + + return gnnm, gns_dict, gn + + +def _filter_gnnm(gnnm: sp.sparse.csr_matrix, thr: float = 0.25) -> sp.sparse.csr_matrix: + """Filter edges in homology graph below threshold.""" + x, y = gnnm.nonzero() + mas = np.asarray(gnnm.max(1).todense()).flatten() + gnnm4 = gnnm.copy() + # Use np.asarray to handle both sparse matrix and numpy.matrix returns + edge_values = np.asarray(gnnm4[x, y]).flatten() + gnnm4.data[edge_values < mas[x] * thr] = 0 + gnnm4.eliminate_zeros() + x, y = gnnm4.nonzero() + z = gnnm4.data + # Symmetrise: ensure (y, x) has the (x, y) value. Original entries first, + # transpose second — last-write-wins matches the old LIL [y,x]=z behaviour. 
+ return coo_to_csr_overwrite( + np.concatenate([x, y]), + np.concatenate([y, x]), + np.concatenate([z, z]), + gnnm4.shape, + ) + + +def _get_pairs( + sams: dict[str, SAM], + gnnm: sp.sparse.csr_matrix, + gns_dict: dict[str, NDArray[Any]], + NOPs1: int = 0, + NOPs2: int = 0, +) -> sp.sparse.csr_matrix: + """Get gene pairs weighted by SAM weights.""" + su = np.asarray(gnnm.max(1).todense()) + su[su == 0] = 1 + gnnm = gnnm.multiply(1 / su).tocsr() + Ws = {} + for sid in sams: + Ws[sid] = sams[sid].adata.var["weights"][gns_dict[sid]].values + + W = np.concatenate(list(Ws.values())) + W[W < 0.0] = 0 + W[W > 0.0] = 1 + + B = gnnm.multiply(W[None, :]).multiply(W[:, None]).tocsr() + B.eliminate_zeros() + + return B diff --git a/src/samap/core/knn.py b/src/samap/core/knn.py new file mode 100644 index 0000000..751fd54 --- /dev/null +++ b/src/samap/core/knn.py @@ -0,0 +1,218 @@ +"""Cross-species k-nearest-neighbour dispatch: CPU HNSW vs GPU brute-force. + +Why GPU brute-force instead of a GPU approximate index +------------------------------------------------------ +SAMap's ``_united_proj`` rebuilds its kNN index **every iteration** because +the joint-embedding ``wpca`` changes on each pass. HNSW graph construction is +O(n log n · M) with M≈48 — all of which is discarded after one query batch. +For n in the hundreds of thousands and d≈600, a single GPU GEMM (N_q × d +times d × N_d) followed by a per-row top-k is faster than building a CPU +HNSW graph and querying it once. ``GpuIndexFlatIP`` is also **exact**, so +there is no recall trade-off. + +TODO: at >1M points the O(N_q · N_d) brute-force memory for the distance +matrix starts to hurt. At that scale switch to ``GpuIndexIVFFlat`` (coarse +quantiser + short inverted lists), trading a small amount of recall for a +linear-in-N footprint. Not implemented here — current SAMap datasets top out +well below that threshold. + +Both FAISS and its GPU extensions are **optional** dependencies. 
If FAISS is +absent or is a CPU-only build, the GPU path gracefully falls back to hnswlib +with a warning. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import hnswlib +import numpy as np + +from samap._logging import logger +from samap.core._backend import Backend + +if TYPE_CHECKING: + from numpy.typing import NDArray + +# --- Optional faiss import -------------------------------------------------- +# faiss may be absent, or present as a CPU-only build (no StandardGpuResources). +# We detect both conditions at module load and gate the GPU path on them. + +try: + import faiss as _faiss + + HAS_FAISS: bool = True + _FAISS_GPU: bool = hasattr(_faiss, "StandardGpuResources") +except ImportError: + _faiss = None # type: ignore[assignment] + HAS_FAISS = False + _FAISS_GPU = False + + +__all__ = ["HAS_FAISS", "approximate_knn"] + + +def approximate_knn( + queries: Any, + database: Any, + k: int, + metric: str = "cosine", + bk: Backend | None = None, +) -> tuple[NDArray[Any], NDArray[Any]]: + """Cross-species approximate k-nearest-neighbours. + + Dispatches to FAISS-GPU brute-force on a CUDA backend (when faiss-gpu is + available) and falls back to hnswlib otherwise. Returns results in + hnswlib's convention: ``(indices, distances)`` where distances are + ``1 - cos(q, d)`` for the cosine metric. + + Parameters + ---------- + queries : array-like, shape (n_q, d) + Query vectors. May be numpy or cupy. + database : array-like, shape (n_d, d) + Database (index) vectors. May be numpy or cupy. + k : int + Neighbours to return per query. + metric : str + Distance metric. Only ``'cosine'`` is supported on the GPU path; + hnswlib additionally supports ``'l2'`` and ``'ip'``. + bk : Backend or None + Backend instance. ``None`` → a fresh CPU backend. + + Returns + ------- + indices : int ndarray, shape (n_q, k) + Row indices into ``database``. 
+ distances : float ndarray, shape (n_q, k) + Distances — for cosine, ``1 - cos(q, d)`` in ``[0, 2]``. + Always returned on host (numpy). + """ + if bk is None: + bk = Backend("cpu") + + if bk.gpu and _FAISS_GPU: + return _faiss_gpu_knn(queries, database, k, metric, bk) + + if bk.gpu and not _FAISS_GPU: + logger.warning( + "GPU backend requested but faiss-gpu is not available; " + "falling back to CPU hnswlib for kNN." + ) + + return _hnswlib_knn(queries, database, k, metric) + + +# --------------------------------------------------------------------------- +# CPU path — HNSW via hnswlib +# --------------------------------------------------------------------------- + +# Default HNSW parameters — kept identical to the legacy inline implementation +# in projection._united_proj so the golden regression test is bit-stable. +_HNSW_EF: int = 200 +_HNSW_M: int = 48 + + +def _hnswlib_knn( + queries: Any, + database: Any, + k: int, + metric: str = "cosine", + *, + ef: int = _HNSW_EF, + M: int = _HNSW_M, + num_threads: int = -1, +) -> tuple[NDArray[Any], NDArray[Any]]: + """CPU approximate kNN via hnswlib. + + Builds a fresh HNSW index over ``database`` and queries it. The index is + discarded on return — callers that need the same database across many + queries should use hnswlib directly. + + ``num_threads`` controls parallelism for both index construction and + querying (``-1`` → all cores). The golden regression test monkeypatches + the ``hnswlib`` module reference in this file to force single-threaded + deterministic behaviour; keep the top-level ``import hnswlib`` intact + for that patch to work. + """ + # hnswlib requires host numpy arrays (any float dtype). 
+ q = np.ascontiguousarray(np.asarray(queries, dtype=np.float32)) + db = np.ascontiguousarray(np.asarray(database, dtype=np.float32)) + + n_d, dim = db.shape + labels = np.arange(n_d) + + index = hnswlib.Index(space=metric, dim=dim) + index.init_index(max_elements=n_d, ef_construction=ef, M=M) + index.add_items(db, labels, num_threads=num_threads) + index.set_ef(ef) + + idx, dist = index.knn_query(q, k=k, num_threads=num_threads) + return idx, dist + + +# --------------------------------------------------------------------------- +# GPU path — FAISS GpuIndexFlatIP (exact brute-force) +# --------------------------------------------------------------------------- + + +def _faiss_gpu_knn( + queries: Any, + database: Any, + k: int, + metric: str, + bk: Backend, +) -> tuple[NDArray[Any], NDArray[Any]]: + """GPU exact kNN via FAISS ``GpuIndexFlatIP``. + + For cosine similarity we L2-normalise both ``queries`` and ``database`` + so that their inner product equals ``cos(q, d)``. FAISS returns the + top-k by *descending* IP; we convert to cosine *distance* (``1 - ip``) to + match hnswlib's output convention. + + FAISS requires ``float32`` C-contiguous inputs. Both numpy and cupy + arrays are accepted — FAISS reads cupy arrays directly via the CUDA + array interface (zero-copy when the layout already matches). + """ + if metric != "cosine": + raise ValueError( + f"_faiss_gpu_knn only supports metric='cosine', got {metric!r}. " + "Use the CPU (hnswlib) path for other metrics." + ) + + res = bk.faiss_gpu_resources() + if res is None: + # Should not happen — caller checks _FAISS_GPU — but be defensive. + logger.warning("faiss_gpu_resources() returned None; falling back to hnswlib.") + return _hnswlib_knn(queries, database, k, metric) + + # FAISS insists on float32, C-contiguous. We upload to device first so + # normalisation runs on GPU, then hand device arrays to FAISS. 
+ xp = bk.xp + q_dev = xp.ascontiguousarray(bk.to_device(queries), dtype=xp.float32) + db_dev = xp.ascontiguousarray(bk.to_device(database), dtype=xp.float32) + + # L2-normalise rows in-place (safe: we own these copies) + q_norm = xp.linalg.norm(q_dev, axis=1, keepdims=True) + q_norm = xp.where(q_norm == 0, 1.0, q_norm) + q_dev /= q_norm + + db_norm = xp.linalg.norm(db_dev, axis=1, keepdims=True) + db_norm = xp.where(db_norm == 0, 1.0, db_norm) + db_dev /= db_norm + + dim = int(db_dev.shape[1]) + + # Flat inner-product index — exact search, no training needed. + cfg = _faiss.GpuIndexFlatConfig() + cfg.device = 0 # TODO: multi-GPU device selection + index = _faiss.GpuIndexFlatIP(res, dim, cfg) + index.add(db_dev) + + sims, idx = index.search(q_dev, k) + # sims is inner-product == cos(q,d) since inputs are unit-norm. + # Convert to cosine distance; bring to host for downstream CSR assembly + # which runs on CPU regardless of backend. + dists_host = 1.0 - bk.to_host(sims) + idx_host = bk.to_host(idx) + return idx_host, dists_host diff --git a/src/samap/core/mapping.py b/src/samap/core/mapping.py index 45363ce..83c188c 100644 --- a/src/samap/core/mapping.py +++ b/src/samap/core/mapping.py @@ -6,18 +6,13 @@ import os import time import warnings -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal -import hnswlib import numpy as np import pandas as pd -import scanpy as sc import scipy as sp -from numba import njit, prange -from numba.core.errors import NumbaPerformanceWarning, NumbaWarning -from samalg import SAM -from sklearn.preprocessing import StandardScaler +from samap import _rsc_compat from samap._constants import ( DEFAULT_CROSS_K, DEFAULT_EVAL_THRESHOLD, @@ -37,21 +32,26 @@ UMAP_SIZE_THRESHOLD, ) from samap._logging import logger -from samap.utils import df_to_dict, prepend_var_prefix, sparse_knn, to_vn +from samap.sam import SAM +from samap.utils import prepend_var_prefix +from samap.utils import q as _q + +from ._backend import Backend 
+from .coarsening import _mapper +from .correlation import _refine_corr +from .homology import ( + _calculate_blast_graph, + _coarsen_blast_graph, + _filter_gnnm, + _get_pairs, +) +from .projection import _projection_precompute, prepare_SAMap_loadings if TYPE_CHECKING: from typing import Any from numpy.typing import NDArray -warnings.filterwarnings("ignore", category=NumbaPerformanceWarning) -warnings.filterwarnings("ignore", category=NumbaWarning) - - -def _q(x: Any) -> NDArray[Any]: - """Convert input to numpy array.""" - return np.array(list(x)) - class SAMAP: """Self-Assembling Manifold Mapping for cross-species single-cell analysis. @@ -89,6 +89,9 @@ class SAMAP: eval_thr : float, optional E-value threshold for BLAST results filtering. Default 1e-6. + backend : {"auto", "cpu", "cuda"}, optional + Compute backend. "auto" picks CUDA if a GPU is available, else CPU. + Attributes ---------- sams : dict @@ -115,7 +118,11 @@ def __init__( gnnm: tuple[Any, NDArray[Any], dict[str, NDArray[Any]]] | None = None, save_processed: bool = True, eval_thr: float = DEFAULT_EVAL_THRESHOLD, + backend: Literal["auto", "cpu", "cuda"] = "auto", ) -> None: + self._bk = Backend(backend) + logger.info("Using backend: %s", self._bk.device) + for key, data in sams.items(): if not (isinstance(data, str | SAM)): raise TypeError(f"Input data {key} must be either a path or a SAM object.") @@ -209,7 +216,7 @@ def __init__( if not sp.sparse.issparse(sams[sid].adata.X): sams[sid].adata.X = sp.sparse.csr_matrix(sams[sid].adata.X) - smap = _Samap_Iter(sams, gnnm_matrix, gns_dict, keys=keys) + smap = _Samap_Iter(sams, gnnm_matrix, gns_dict, keys=keys, bk=self._bk) self.sams = sams self.gnnm = gnnm_matrix self.gns_dict = gns_dict @@ -340,7 +347,13 @@ def run( if self.samap.adata.shape[0] <= UMAP_SIZE_THRESHOLD else UMAP_MAXITER_LARGE ) - sc.tl.umap(self.samap.adata, min_dist=UMAP_MIN_DIST, init_pos="random", maxiter=maxiter) + _rsc_compat.umap( + self.samap.adata, + self._bk, + 
min_dist=UMAP_MIN_DIST, + init_pos="random", + maxiter=maxiter, + ) ix = pd.Series(data=np.arange(samap.adata.shape[1]), index=samap.adata.var_names)[ gns @@ -393,7 +406,13 @@ def run_umap(self) -> None: if self.samap.adata.shape[0] <= UMAP_SIZE_THRESHOLD else UMAP_MAXITER_LARGE ) - sc.tl.umap(self.samap.adata, min_dist=UMAP_MIN_DIST, init_pos="random", maxiter=maxiter) + _rsc_compat.umap( + self.samap.adata, + self._bk, + min_dist=UMAP_MIN_DIST, + init_pos="random", + maxiter=maxiter, + ) for sid in ids: sams[sid].adata.obsm["X_umap_samap"] = self.samap.adata[sams[sid].adata.obs_names].obsm[ "X_umap" @@ -621,26 +640,6 @@ def hex_to_rgb(value: str) -> list[float]: return ax - def gui(self) -> Any: - """Launch a SAMGUI instance containing the SAM objects.""" - if "SamapGui" not in self.__dict__: - try: - from samalg.gui import SAMGUI - except ImportError: - raise ImportError( - "Please install SAMGUI dependencies. See the README in the SAM github repository." - ) from None - - sg = SAMGUI( - sam=list(self.sams.values()), - title=list(self.ids), - default_proj="X_umap_samap", - ) - self.SamapGui = sg - return sg.SamPlot - else: - return self.SamapGui.SamPlot - def refine_homology_graph( self, thr: float = 0, @@ -648,6 +647,7 @@ def refine_homology_graph( ncpus: int | None = None, corr_mode: str = "pearson", wscale: bool = False, + correlation_mem_threshold: float = 2.0, ) -> sp.sparse.csr_matrix: """Refine the homology graph using expression correlations. @@ -663,6 +663,13 @@ def refine_homology_graph( Correlation mode: 'pearson'. Default 'pearson'. wscale : bool, optional Whether to scale by weights. Default False. + correlation_mem_threshold : float, optional + Memory threshold (GB) for auto-selecting materialised vs + streaming correlation. The materialised path is 3-5× faster on + small data but requires holding the full smoothed-expression + matrix in memory. Default 2.0 GB — raise on large-memory nodes + to keep the faster path for larger datasets. 
See + ``docs/performance.md`` for details. Returns ------- @@ -673,7 +680,12 @@ def refine_homology_graph( ncpus = os.cpu_count() or 1 gnnm = self.smap.refine_homology_graph( - NCLUSTERS=n_clusters, ncpus=ncpus, THR=thr, corr_mode=corr_mode, wscale=wscale + NCLUSTERS=n_clusters, + ncpus=ncpus, + THR=thr, + corr_mode=corr_mode, + wscale=wscale, + correlation_mem_threshold=correlation_mem_threshold, ) samap = self.smap.samap gns_dict = self.smap.gns_dict @@ -704,7 +716,9 @@ def __init__( gnnm: sp.sparse.csr_matrix, gns_dict: dict[str, NDArray[Any]], keys: dict[str, str] | None = None, + bk: Backend | None = None, ) -> None: + self._bk = bk if bk is not None else Backend("cpu") self.sams = sams self.gnnm = gnnm self.gnnmu = gnnm @@ -726,6 +740,13 @@ def __init__( ] self.iter = 0 + # Iteration-invariant projection state: standardised expression matrices, + # their Gram matrices/means (for the sigma quadratic form), and the + # own-species PC projection. Built once here, consumed every iteration + # inside _mapper → _mapping_window_fast. 
+ self._gns = np.concatenate(list(gns_dict.values())) + self._proj_cache = _projection_precompute(sams, self._gns, self._bk) + def refine_homology_graph( self, NCLUSTERS: int = 1, @@ -733,6 +754,7 @@ def refine_homology_graph( THR: float = 0, corr_mode: str = "pearson", wscale: bool = False, + correlation_mem_threshold: float = 2.0, ) -> sp.sparse.csr_matrix: """Refine homology graph using correlations.""" if ncpus is None: @@ -750,6 +772,7 @@ def refine_homology_graph( ncpus=ncpus, corr_mode=corr_mode, wscale=wscale, + correlation_mem_threshold=correlation_mem_threshold, ) return gnnmu @@ -779,7 +802,7 @@ def run( NHS = dict.fromkeys(sams.keys(), 2) if neigh_from_keys is None: neigh_from_keys = dict.fromkeys(sams, False) - gns = np.concatenate(list(gns_dict.values())) + gns = self._gns if self.iter > 0: sam4 = self.samap @@ -809,6 +832,8 @@ def run( scale_edges_by_corr=scale_edges_by_corr, neigh_from_keys=neigh_from_keys, pairwise=pairwise, + proj_cache=self._proj_cache, + bk=self._bk, ) sam4.adata.uns["mapping_K"] = K self.samap = sam4 @@ -848,1053 +873,3 @@ def _avg_as(s: SAM) -> pd.DataFrame: / s.adata.uns["mapping_K"] ) return pd.DataFrame(data=a, index=xu, columns=xu) - - -@njit(parallel=True) -def _replace(X: NDArray[Any], xi: NDArray[Any], yi: NDArray[Any]) -> NDArray[np.float64]: - """Compute correlations for pairs in parallel.""" - data = np.zeros(xi.size) - for i in prange(xi.size): - x = X[xi[i]] - y = X[yi[i]] - data[i] = ((x - x.mean()) * (y - y.mean()) / x.std() / y.std()).sum() / x.size - return data - - -def _generate_coclustering_matrix(cl: NDArray[Any]) -> sp.sparse.csr_matrix: - """Generate a co-clustering indicator matrix.""" - import samalg.utilities as ut - - cl_arr = ut.convert_annotations(np.array(list(cl))) - clu, _cluc = np.unique(cl_arr, return_counts=True) - v = np.zeros((cl_arr.size, clu.size)) - v[np.arange(v.shape[0]), cl_arr] = 1 - return sp.sparse.csr_matrix(v) - - -def prepare_SAMap_loadings(sam: SAM, npcs: int = 300) -> None: - 
"""Prepare SAM object with PC loadings for manifold. - - Parameters - ---------- - sam : SAM - SAM object to prepare. - npcs : int, optional - Number of PCs to calculate. Default 300. - """ - ra = sam.adata.uns["run_args"] - preprocessing = ra.get("preprocessing", "StandardScaler") - weight_PCs = ra.get("weight_PCs", False) - A, _ = sam.calculate_nnm( - n_genes=sam.adata.shape[1], - preprocessing=preprocessing, - npcs=npcs, - weight_PCs=weight_PCs, - sparse_pca=True, - update_manifold=False, - weight_mode="dispersion", - ) - sam.adata.varm["PCs_SAMap"] = A - - -# Include remaining internal functions from original mapping.py -# These are simplified versions with proper type hints and error handling - - -def _calculate_blast_graph( - ids: list[str], - f_maps: str = "maps/", - eval_thr: float = 1e-6, - reciprocate: bool = False, -) -> tuple[sp.sparse.csr_matrix, NDArray[Any], dict[str, NDArray[Any]]]: - """Calculate gene homology graph from BLAST results.""" - gns: list[str] = [] - Xs: list[Any] = [] - Ys: list[Any] = [] - Vs: list[Any] = [] - - for i in range(len(ids)): - id1 = ids[i] - for j in range(i, len(ids)): - id2 = ids[j] - if i != j: - if os.path.exists(f_maps + f"{id1}{id2}"): - fA = f_maps + f"{id1}{id2}/{id1}_to_{id2}.txt" - fB = f_maps + f"{id1}{id2}/{id2}_to_{id1}.txt" - elif os.path.exists(f_maps + f"{id2}{id1}"): - fA = f_maps + f"{id2}{id1}/{id1}_to_{id2}.txt" - fB = f_maps + f"{id2}{id1}/{id2}_to_{id1}.txt" - else: - raise FileNotFoundError( - f"BLAST mapping tables with the input IDs ({id1} and {id2}) " - f"not found in the specified path." 
- ) - - A = pd.read_csv(fA, sep="\t", header=None, index_col=0) - B = pd.read_csv(fB, sep="\t", header=None, index_col=0) - - A.columns = A.columns.astype(" NDArray[np.str_]: - """Add species prefix to gene names.""" - x = [str(item).split("_")[0] for item in data] - vn = [] - for i, g in enumerate(data): - if x[i] != pre: - vn.append(pre + "_" + g) - else: - vn.append(g) - return np.array(vn).astype("str").astype("object") - - -def _coarsen_blast_graph( - gnnm: sp.sparse.csr_matrix, - gns: NDArray[Any], - names: dict[str, Any], -) -> tuple[sp.sparse.csr_matrix, dict[str, NDArray[Any]], NDArray[Any]]: - """Coarsen BLAST graph by collapsing transcripts to genes.""" - gnnm = gnnm.tocsr() - gnnm.eliminate_zeros() - - sps = np.array([x.split("_")[0] for x in gns]) - sids = np.unique(sps) - ss = [] - for sid in sids: - n = names.get(sid) - if n is not None: - n = np.array(n) - n = (sid + "_" + n.astype("object")).astype("str") - s1 = pd.Series(index=n[:, 0], data=n[:, 1]) - g = gns[sps == sid] - g = g[np.isin(g, n[:, 0], invert=True)] - s2 = pd.Series(index=g, data=g) - s = pd.concat([s1, s2]) - else: - s = pd.Series(index=gns[sps == sid], data=gns[sps == sid]) - ss.append(s) - ss_combined = pd.concat(ss) - ss_combined = ss_combined[np.unique(_q(ss_combined.index), return_index=True)[1]] - x, y = gnnm.nonzero() - s = pd.Series(data=gns, index=np.arange(gns.size)) - xn, yn = s[x].values, s[y].values - xg, yg = ss_combined[xn].values, ss_combined[yn].values - - da = gnnm.data - - zgu, ix, _ivx, cu = np.unique( - np.array([xg, yg]).astype("str"), - axis=1, - return_counts=True, - return_index=True, - return_inverse=True, - ) - - xgu, ygu = zgu[:, cu > 1] - xgyg = _q(xg.astype("object") + ";" + yg.astype("object")) - xguygu = _q(xgu.astype("object") + ";" + ygu.astype("object")) - - filt = np.isin(xgyg, xguygu) - - DF = pd.DataFrame(data=xgyg[filt][:, None], columns=["key"]) - DF["val"] = da[filt] - - dic = df_to_dict(DF, key_key="key") - - xgu = _q([x.split(";")[0] for x 
in dic]) - ygu = _q([x.split(";")[1] for x in dic]) - replz = _q([max(dic[x]) for x in dic]) - - xgu1, ygu1 = zgu[:, cu == 1] - xg = np.append(xgu1, xgu) - yg = np.append(ygu1, ygu) - da = np.append(da[ix][cu == 1], replz) - gn = np.unique(np.append(xg, yg)) - - s = pd.Series(data=np.arange(gn.size), index=gn) - xn, yn = s[xg].values, s[yg].values - gnnm = sp.sparse.coo_matrix((da, (xn, yn)), shape=(gn.size,) * 2).tocsr() - - f = np.asarray(gnnm.sum(1)).flatten() != 0 - gn = gn[f] - sps = np.array([x.split("_")[0] for x in gn]) - - gns_dict: dict[str, NDArray[Any]] = {} - for sid in sids: - gns_dict[sid] = gn[sps == sid] - - return gnnm, gns_dict, gn - - -def _filter_gnnm(gnnm: sp.sparse.csr_matrix, thr: float = 0.25) -> sp.sparse.csr_matrix: - """Filter edges in homology graph below threshold.""" - x, y = gnnm.nonzero() - mas = np.asarray(gnnm.max(1).todense()).flatten() - gnnm4 = gnnm.copy() - # Use np.asarray to handle both sparse matrix and numpy.matrix returns - edge_values = np.asarray(gnnm4[x, y]).flatten() - gnnm4.data[edge_values < mas[x] * thr] = 0 - gnnm4.eliminate_zeros() - x, y = gnnm4.nonzero() - z = gnnm4.data - gnnm4 = gnnm4.tolil() - gnnm4[y, x] = z - return gnnm4.tocsr() - - -def _get_pairs( - sams: dict[str, SAM], - gnnm: sp.sparse.csr_matrix, - gns_dict: dict[str, NDArray[Any]], - NOPs1: int = 0, - NOPs2: int = 0, -) -> sp.sparse.csr_matrix: - """Get gene pairs weighted by SAM weights.""" - su = np.asarray(gnnm.max(1).todense()) - su[su == 0] = 1 - gnnm = gnnm.multiply(1 / su).tocsr() - Ws = {} - for sid in sams: - Ws[sid] = sams[sid].adata.var["weights"][gns_dict[sid]].values - - W = np.concatenate(list(Ws.values())) - W[W < 0.0] = 0 - W[W > 0.0] = 1 - - B = gnnm.multiply(W[None, :]).multiply(W[:, None]).tocsr() - B.eliminate_zeros() - - return B - - -@njit -def nb_unique1d(ar: NDArray[Any]) -> tuple[NDArray[Any], NDArray[Any], NDArray[Any], NDArray[Any]]: - """Find unique elements of an array (numba-optimized).""" - ar = ar.flatten() - perm = 
ar.argsort(kind="mergesort") - aux = ar[perm] - mask = np.empty(aux.shape, dtype=np.bool_) - mask[:1] = True - mask[1:] = aux[1:] != aux[:-1] - - imask = np.cumsum(mask) - 1 - inv_idx = np.empty(mask.shape, dtype=np.intp) - inv_idx[perm] = imask - idx = np.append(np.nonzero(mask)[0], mask.size) - - return aux[mask], perm[mask], inv_idx, np.diff(idx) - - -@njit -def _xicorr(X: NDArray[Any], Y: NDArray[Any]) -> float: - """Xi correlation coefficient.""" - n = X.size - xi = np.argsort(X, kind="quicksort") - Y = Y[xi] - _, _, b, c = nb_unique1d(Y) - r = np.cumsum(c)[b] - _, _, b, c = nb_unique1d(-Y) - left_counts = np.cumsum(c)[b] - denominator = 2 * (left_counts * (n - left_counts)).sum() - if denominator > 0: - return 1 - n * np.abs(np.diff(r)).sum() / denominator - else: - return 0.0 - - -@njit(parallel=True) -def _refine_corr_kernel( - p: NDArray[Any], - ps: NDArray[Any], - sids: NDArray[Any], - sixs: list[NDArray[Any]], - indptr: NDArray[Any], - indices: NDArray[Any], - data: NDArray[Any], - n: int, - corr_mode: str, -) -> NDArray[np.float64]: - """Kernel for computing gene correlations in parallel.""" - p1 = p[:, 0] - p2 = p[:, 1] - - ps1 = ps[:, 0] - ps2 = ps[:, 1] - - d = {} - for i in range(len(sids)): - d[sids[i]] = sixs[i] - - res = np.zeros(p1.size) - - for j in prange(len(p1)): - j1, j2 = p1[j], p2[j] - pl1d = data[indptr[j1] : indptr[j1 + 1]] - pl1i = indices[indptr[j1] : indptr[j1 + 1]] - - sc1d = data[indptr[j2] : indptr[j2 + 1]] - sc1i = indices[indptr[j2] : indptr[j2 + 1]] - - x = np.zeros(n) - x[pl1i] = pl1d - y = np.zeros(n) - y[sc1i] = sc1d - - a1, a2 = ps1[j], ps2[j] - ix1 = d[a1] - ix2 = d[a2] - - xa, xb, ya, yb = x[ix1], x[ix2], y[ix1], y[ix2] - xx = np.append(xa, xb) - yy = np.append(ya, yb) - - if corr_mode == "pearson": - c = ((xx - xx.mean()) * (yy - yy.mean()) / xx.std() / yy.std()).sum() / xx.size - else: - c = _xicorr(xx, yy) - res[j] = c - return res - - -def _tanh_scale(x: NDArray[Any], scale: float = 10, center: float = 0.5) -> 
NDArray[Any]: - """Apply tanh scaling to values.""" - return center + (1 - center) * np.tanh(scale * (x - center)) - - -def _refine_corr( - sams: dict[str, SAM], - st: SAM, - gnnm: sp.sparse.csr_matrix, - gns_dict: dict[str, NDArray[Any]], - corr_mode: str = "pearson", - THR: float = 0, - use_seq: bool = False, - T1: float = 0.25, - NCLUSTERS: int = 1, - ncpus: int | None = None, - wscale: bool = False, -) -> sp.sparse.csr_matrix: - """Refine correlation matrix for homology graph.""" - if ncpus is None: - ncpus = os.cpu_count() or 1 - - gns = np.concatenate(list(gns_dict.values())) - - x, y = gnnm.nonzero() - sam = next(iter(sams.values())) - cl = sam.leiden_clustering(gnnm, res=0.5) - ix = np.argsort(cl) - NGPC = gns.size // NCLUSTERS + 1 - - ixs = [] - for i in range(NCLUSTERS): - ixs.append(np.sort(ix[i * NGPC : (i + 1) * NGPC])) - - assert np.concatenate(ixs).size == gns.size - - GNNMSUBS = [] - GNSUBS = [] - for i in range(len(ixs)): - ixs[i] = np.unique(np.append(ixs[i], gnnm[ixs[i], :].nonzero()[1])) - gnnm_sub = gnnm[ixs[i], :][:, ixs[i]] - gnsub = gns[ixs[i]] - gns_dict_sub = {} - for sid in gns_dict: - gn = gns_dict[sid] - gns_dict_sub[sid] = gn[np.isin(gn, gnsub)] - - gnnm2_sub = _refine_corr_parallel( - sams, - st, - gnnm_sub, - gns_dict_sub, - corr_mode=corr_mode, - THR=THR, - use_seq=use_seq, - T1=T1, - ncpus=ncpus, - wscale=wscale, - ) - GNNMSUBS.append(gnnm2_sub) - GNSUBS.append(gnsub) - gc.collect() - - indices_list = [] - pairs_list = [] - for i in range(len(GNNMSUBS)): - indices_list.append(np.unique(np.sort(np.vstack(GNNMSUBS[i].nonzero()).T, axis=1), axis=0)) - pairs_list.append(GNSUBS[i][indices_list[-1]]) - - GNS = pd.DataFrame(data=np.arange(gns.size)[None, :], columns=gns) - gnnm3 = sp.sparse.lil_matrix(gnnm.shape) - for i in range(len(indices_list)): - x, y = GNS[pairs_list[i][:, 0]].values.flatten(), GNS[pairs_list[i][:, 1]].values.flatten() - gnnm3[x, y] = np.asarray( - GNNMSUBS[i][indices_list[i][:, 0], indices_list[i][:, 1]] - 
).flatten() - - gnnm3 = gnnm3.tocsr() - x, y = gnnm3.nonzero() - gnnm3 = gnnm3.tolil() - gnnm3[y, x] = np.asarray(gnnm3[x, y].tocsr().todense()).flatten() - return gnnm3.tocsr() - - -def _refine_corr_parallel( - sams: dict[str, SAM], - st: SAM, - gnnm: sp.sparse.csr_matrix, - gns_dict: dict[str, NDArray[Any]], - corr_mode: str = "pearson", - THR: float = 0, - use_seq: bool = False, - T1: float = 0.0, - ncpus: int | None = None, - wscale: bool = False, -) -> sp.sparse.csr_matrix: - """Parallel correlation refinement.""" - if ncpus is None: - ncpus = os.cpu_count() or 1 - - gn = np.concatenate(list(gns_dict.values())) - - Ws = [] - ix = [] - for sid in sams: - Ws.append(sams[sid].adata.var["weights"][gns_dict[sid]].values) - ix += [sid] * gns_dict[sid].size - ix = np.array(ix) - w = np.concatenate(Ws) - - w[w > T1] = 1 - w[w < 1] = 0 - - gnO = gn[w > 0] - ix = ix[w > 0] - gns_dictO = {} - for sid in gns_dict: - gns_dictO[sid] = gnO[ix == sid] - - gnnmO = gnnm[w > 0, :][:, w > 0] - x, y = gnnmO.nonzero() - - pairs = np.unique(np.sort(np.vstack((x, y)).T, axis=1), axis=0) - - xs = _q([i.split("_")[0] for i in gnO[pairs[:, 0]]]) - ys = _q([i.split("_")[0] for i in gnO[pairs[:, 1]]]) - pairs_species = np.vstack((xs, ys)).T - - nnm = st.adata.obsp["connectivities"] - xs_list = [] - nnms = [] - for i, sid in enumerate(sams.keys()): - batch_mask = (st.adata.obs["batch"] == f"batch{i + 1}").values - nnms.append(nnm[:, batch_mask]) - s1 = np.asarray(nnms[-1].sum(1)) - s1[s1 < 1e-3] = 1 - s1 = s1.flatten()[:, None] - nnms[-1] = nnms[-1].multiply(1 / s1) - - xs_list.append(sams[sid].adata[:, gns_dictO[sid]].X.astype("float32")) - - Xs = sp.sparse.block_diag(xs_list).tocsc() - nnms = sp.sparse.hstack(nnms).tocsr() - Xavg = nnms.dot(Xs).tocsc() - - p = pairs - ps = pairs_species - - gnnm2 = gnnm.multiply(w[:, None]).multiply(w[None, :]).tocsr() - x, y = gnnm2.nonzero() - pairs = np.unique(np.sort(np.vstack((x, y)).T, axis=1), axis=0) - - species = _q(st.adata.obs["species"]) - 
sixs = [] - sidss = np.unique(species) - for sid in sidss: - sixs.append(np.where(species == sid)[0]) - - vals = _refine_corr_kernel( - p, ps, sidss, sixs, Xavg.indptr, Xavg.indices, Xavg.data, Xavg.shape[0], corr_mode - ) - vals[np.isnan(vals)] = 0 - - CORR = dict(zip(to_vn(np.vstack((gnO[p[:, 0]], gnO[p[:, 1]])).T), vals)) - - for k in CORR: - CORR[k] = 0 if CORR[k] < THR else CORR[k] - if wscale: - id1, id2 = [x.split("_")[0] for x in k.split(";")] - weight1 = sams[id1].adata.var["weights"][k.split(";")[0]] - weight2 = sams[id2].adata.var["weights"][k.split(";")[1]] - CORR[k] = np.sqrt(CORR[k] * np.sqrt(weight1 * weight2)) - - CORR_arr = np.array([CORR[x] for x in to_vn(gn[pairs])]) - - gnnm3 = sp.sparse.lil_matrix(gnnm.shape) - - if use_seq: - gnnm3[pairs[:, 0], pairs[:, 1]] = ( - CORR_arr * np.asarray(gnnm2[pairs[:, 0], pairs[:, 1]]).flatten() - ) - gnnm3[pairs[:, 1], pairs[:, 0]] = ( - CORR_arr * np.asarray(gnnm2[pairs[:, 1], pairs[:, 0]]).flatten() - ) - else: - gnnm3[pairs[:, 0], pairs[:, 1]] = CORR_arr - gnnm3[pairs[:, 1], pairs[:, 0]] = CORR_arr - - gnnm3 = gnnm3.tocsr() - gnnm3.eliminate_zeros() - return gnnm3 - - -def _united_proj( - wpca1: NDArray[Any], - wpca2: NDArray[Any], - k: int = 20, - metric: str = "cosine", - ef: int = 200, - M: int = 48, -) -> sp.sparse.csr_matrix: - """Project between feature spaces using HNSW.""" - metric = "l2" if metric == "euclidean" else metric - metric = "cosine" if metric == "correlation" else metric - labels2 = np.arange(wpca2.shape[0]) - p2 = hnswlib.Index(space=metric, dim=wpca2.shape[1]) - p2.init_index(max_elements=wpca2.shape[0], ef_construction=ef, M=M) - p2.add_items(wpca2, labels2) - p2.set_ef(ef) - idx1, dist1 = p2.knn_query(wpca1, k=k) - - if metric == "cosine": - dist1 = 1 - dist1 - dist1[dist1 < 1e-3] = 1e-3 - dist1 = dist1 / dist1.max(1)[:, None] - dist1 = _tanh_scale(dist1, scale=10, center=0.7) - else: - sigma1 = dist1[:, 4] - sigma1[sigma1 < 1e-3] = 1e-3 - dist1 = np.exp(-dist1 / sigma1[:, None]) - - 
Sim1 = dist1 - knn1v2 = sp.sparse.lil_matrix((wpca1.shape[0], wpca2.shape[0])) - x1 = np.tile(np.arange(idx1.shape[0])[:, None], (1, idx1.shape[1])).flatten() - knn1v2[x1.astype("int32"), idx1.flatten().astype("int32")] = Sim1.flatten() - return knn1v2.tocsr() - - -def _mapper( - sams: dict[str, SAM], - gnnm: sp.sparse.csr_matrix | None = None, - gn: NDArray[Any] | None = None, - NHS: dict[str, int] | None = None, - umap: bool = False, - mdata: dict[str, Any] | None = None, - k: int | None = None, - K: int = 20, - chunksize: int = 20000, - coarsen: bool = True, - keys: dict[str, str] | None = None, - scale_edges_by_corr: bool = False, - neigh_from_keys: dict[str, bool] | None = None, - pairwise: bool = True, - **kwargs: Any, -) -> SAM: - """Map cells between species.""" - if NHS is None: - NHS = dict.fromkeys(sams.keys(), 3) - - if neigh_from_keys is None: - neigh_from_keys = dict.fromkeys(sams.keys(), False) - - if mdata is None: - mdata = _mapping_window(sams, gnnm, gn, K=K, pairwise=pairwise) - - k1 = K - - if keys is None: - keys = dict.fromkeys(sams.keys(), "leiden_clusters") - - nnms_in: dict[str, Any] = {} - nnms_in0: dict[str, Any] = {} - flag = False - species_indexer = [] - for sid in sams: - logger.info("Expanding neighbourhoods of species %s...", sid) - cl = sams[sid].get_labels(keys[sid]) - _, ix, cluc = np.unique(cl, return_counts=True, return_inverse=True) - K_arr = cluc[ix] - nnms_in0[sid] = sams[sid].adata.obsp["connectivities"].copy() - species_indexer.append(np.arange(sams[sid].adata.shape[0])) - if not neigh_from_keys[sid]: - nnm_in = _smart_expand(nnms_in0[sid], K_arr, NH=NHS[sid]) - nnm_in.data[:] = 1 - nnms_in[sid] = nnm_in - else: - nnms_in[sid] = _generate_coclustering_matrix(cl) - flag = True - - for i in range(1, len(species_indexer)): - species_indexer[i] += species_indexer[i - 1].max() + 1 - - if not flag: - nnm_internal = sp.sparse.block_diag(list(nnms_in.values())).tocsr() - nnm_internal0 = 
sp.sparse.block_diag(list(nnms_in0.values())).tocsr() - - ovt = mdata["knn"] - ovt0 = ovt.copy() - ovt0.data[:] = 1 - - B = ovt - - logger.info("Indegree coarsening") - - numiter = nnm_internal0.shape[0] // chunksize + 1 - - D = sp.sparse.csr_matrix((0, nnm_internal0.shape[0])) - if flag: - Cs = [] - for it, sid in enumerate(sams.keys()): - nfk = neigh_from_keys[sid] - if nfk: - Cs.append(nnms_in[sid].dot(nnms_in[sid].T.dot(B.T[species_indexer[it]]))) - else: - Cs.append(nnms_in[sid].dot(B.T[species_indexer[it]])) - D = sp.sparse.vstack(Cs).T - del Cs - gc.collect() - else: - for bl in range(numiter): - logger.debug("%d/%d, shape %s", bl, numiter, D.shape) - C = B[bl * chunksize : (bl + 1) * chunksize].dot(nnm_internal.T) - C.data[C.data < 0.1] = 0 - C.eliminate_zeros() - - D = sp.sparse.vstack((D, C)) - del C - gc.collect() - - D = D.multiply(D.T).tocsr() - D.data[:] = D.data**0.5 - mdata["xsim"] = D - - if scale_edges_by_corr: - logger.info("Rescaling edge weights by expression correlations.") - x, y = D.nonzero() - vals = _replace(mdata["wPCA"], x, y) - vals[vals < 1e-3] = 1e-3 - - F = D.copy() - F.data[:] = vals - - ma = np.asarray(F.max(1).todense()) - ma[ma == 0] = 1 - F = F.multiply(1 / ma).tocsr() - F.data[:] = _tanh_scale(F.data, center=0.7, scale=10) - - ma = np.asarray(D.max(1).todense()) - ma[ma == 0] = 1 - - D = F.multiply(D).tocsr() - D.data[:] = np.sqrt(D.data) - - ma2 = np.asarray(D.max(1).todense()) - ma2[ma2 == 0] = 1 - - D = D.multiply(ma / ma2).tocsr() - - species_list = [] - for sid in sams: - species_list += [sid] * sams[sid].adata.shape[0] - species_list = np.array(species_list) - - if not pairwise or len(sams.keys()) == 2: - Dk = sparse_knn(D, k1).tocsr() - denom = k1 - else: - Dk = [] - for sid1 in sams: - row = [] - for sid2 in sams: - if sid1 != sid2: - Dsubk = sparse_knn(D[species_list == sid1][:, species_list == sid2], k1).tocsr() - else: - Dsubk = sp.sparse.csr_matrix((sams[sid1].adata.shape[0],) * 2) - row.append(Dsubk) - 
Dk.append(sp.sparse.hstack(row)) - Dk = sp.sparse.vstack(Dk).tocsr() - denom = k1 * (len(sams.keys()) - 1) - - sr = np.asarray(Dk.sum(1)) - - x = 1 - sr.flatten() / denom - - sr[sr == 0] = 1 - st = np.asarray(Dk.sum(0)).flatten()[None, :] - st[st == 0] = 1 - proj = Dk.multiply(1 / sr).dot(Dk.multiply(1 / st)).tocsr() - z = proj.copy() - z.data[:] = 1 - idx = np.where(np.asarray(z.sum(1)).flatten() >= k1)[0] - - omp = nnm_internal0 - omp.data[:] = 1 - s = np.asarray(proj.max(1).todense()) - s[s == 0] = 1 - proj = proj.multiply(1 / s).tocsr() - X, Y = omp.nonzero() - X2 = X[np.isin(X, idx)] - Y2 = Y[np.isin(X, idx)] - - omp = omp.tolil() - omp[X2, Y2] = np.vstack((np.asarray(proj[X2, Y2]).flatten(), np.ones(X2.size) * 0.3)).max(0) - - omp = nnm_internal0.tocsr() - NNM = omp.multiply(x[:, None]) - NNM = (NNM + Dk).tolil() - NNM.setdiag(0) - - logger.info("Concatenating SAM objects...") - sam3 = _concatenate_sam(sams, NNM) - - sam3.adata.obs["species"] = pd.Categorical(species_list) - - sam3.adata.uns["gnnm_corr"] = mdata.get("gnnm_corr", None) - - if umap: - logger.info("Computing UMAP projection...") - maxiter = ( - UMAP_MAXITER_SMALL if sam3.adata.shape[0] <= UMAP_SIZE_THRESHOLD else UMAP_MAXITER_LARGE - ) - sc.tl.umap(sam3.adata, min_dist=UMAP_MIN_DIST, maxiter=maxiter) - return sam3 - - -def _concatenate_sam(sams: dict[str, SAM], nnm: sp.sparse.lil_matrix) -> SAM: - """Concatenate SAM objects.""" - acns = [] - exps = [] - agns = [] - sps = [] - for i, sid in enumerate(sams.keys()): - acns.append(_q(sams[sid].adata.obs_names)) - sps.append([sid] * acns[-1].size) - exps.append(sams[sid].adata.X) - agns.append(_q(sams[sid].adata.var_names)) - - acn = np.concatenate(acns) - agn = np.concatenate(agns) - sps_arr = np.concatenate(sps) - - xx = sp.sparse.block_diag(exps, format="csr") - - sam = SAM(counts=(xx, agn, acn)) - - sam.adata.uns["neighbors"] = {} - nnm = nnm.tocsr() - nnm.eliminate_zeros() - sam.adata.obsp["connectivities"] = nnm - 
sam.adata.uns["neighbors"]["params"] = { - "n_neighbors": 15, - "method": "umap", - "use_rep": "X", - "metric": "euclidean", - } - for i in sams: - for k in sams[i].adata.obs: - if sams[i].adata.obs[k].dtype.name == "category": - z = np.array(["unassigned"] * sam.adata.shape[0], dtype="object") - z[sps_arr == i] = _q(sams[i].adata.obs[k]) - sam.adata.obs[i + "_" + k] = pd.Categorical(z) - - a = [] - for i, sid in enumerate(sams.keys()): - a.extend(["batch" + str(i + 1)] * sams[sid].adata.shape[0]) - sam.adata.obs["batch"] = pd.Categorical(np.array(a)) - sam.adata.obs.columns = sam.adata.obs.columns.astype("str") - sam.adata.var.columns = sam.adata.var.columns.astype("str") - - for i in sam.adata.obs: - sam.adata.obs[i] = sam.adata.obs[i].astype("str") - - return sam - - -def _mapping_window( - sams: dict[str, SAM], - gnnm: sp.sparse.csr_matrix | None = None, - gns: NDArray[Any] | None = None, - K: int = 20, - pairwise: bool = True, -) -> dict[str, Any]: - """Create mapping window for cross-species projection.""" - k = K - output_dict: dict[str, Any] = {} - if gnnm is not None and gns is not None: - logger.info("Prepping datasets for translation.") - gnnm_corr = gnnm.copy() - gnnm_corr.data[:] = _tanh_scale(gnnm_corr.data) - - std = StandardScaler(with_mean=False) - - gs = {} - adatas = {} - Ws = {} - ss = {} - species_indexer = [] - genes_indexer = [] - for sid in sams: - gs[sid] = gns[np.isin(gns, _q(sams[sid].adata.var_names))] - adatas[sid] = sams[sid].adata[:, gs[sid]] - Ws[sid] = adatas[sid].var["weights"].values - ss[sid] = std.fit_transform(adatas[sid].X).multiply(Ws[sid][None, :]).tocsr() - species_indexer.append(np.arange(ss[sid].shape[0])) - genes_indexer.append(np.arange(gs[sid].size)) - - for i in range(1, len(species_indexer)): - species_indexer[i] = species_indexer[i] + species_indexer[i - 1].max() + 1 - genes_indexer[i] = genes_indexer[i] + genes_indexer[i - 1].max() + 1 - - su = np.asarray(gnnm_corr.sum(0)) - su[su == 0] = 1 - gnnm_corr = 
gnnm_corr.multiply(1 / su).tocsr() - - X = sp.sparse.block_diag(list(ss.values())).tocsr() - W = np.concatenate(list(Ws.values())).flatten() - - ttt = time.time() - if pairwise: - logger.info("Translating feature spaces pairwise.") - Xtr = [] - for i, _sid1 in enumerate(sams.keys()): - xtr = [] - for j, _sid2 in enumerate(sams.keys()): - if i != j: - gnnm_corr_sub = gnnm_corr[genes_indexer[i]][:, genes_indexer[j]] - su = np.asarray(gnnm_corr_sub.sum(0)) - su[su == 0] = 1 - gnnm_corr_sub = gnnm_corr_sub.multiply(1 / su).tocsr() - xtr.append(X[species_indexer[i]][:, genes_indexer[i]].dot(gnnm_corr_sub)) - xtr[-1] = std.fit_transform(xtr[-1]).multiply(W[genes_indexer[j]][None, :]) - else: - xtr.append( - sp.sparse.csr_matrix((species_indexer[i].size, genes_indexer[i].size)) - ) - Xtr.append(sp.sparse.hstack(xtr)) - Xtr = sp.sparse.vstack(Xtr) - else: - logger.info("Translating feature spaces all-to-all.") - - Xtr = [] - for i, sid in enumerate(sams.keys()): - Xtr.append(X[species_indexer[i]].dot(gnnm_corr)) - Xtr[-1] = std.fit_transform(Xtr[-1]).multiply(W[None, :]) - Xtr = sp.sparse.vstack(Xtr) - Xc = (X + Xtr).tocsr() - - mus = [] - for i, sid in enumerate(sams.keys()): - mus.append(np.asarray(Xc[species_indexer[i]].mean(0)).flatten()) - - gc.collect() - - logger.info("Projecting data into joint latent space. %.2fs", time.time() - ttt) - C = sp.linalg.block_diag(*[adatas[sid].varm["PCs_SAMap"] for sid in sams]) - M = np.vstack(mus).dot(C) - ttt = time.time() - it = 0 - PCAs = [] - for sid in sams: - PCAs.append(Xc[:, it : it + gs[sid].size].dot(adatas[sid].varm["PCs_SAMap"])) - it += gs[sid].size - wpca = np.hstack(PCAs) - - logger.info("Correcting data with means. 
%.2fs", time.time() - ttt) - for i, sid in enumerate(sams.keys()): - ixq = species_indexer[i] - wpca[ixq] -= M[i] - output_dict["gnnm_corr"] = gnnm_corr - else: - std = StandardScaler(with_mean=False) - - gs = {} - adatas = {} - Ws = {} - ss = {} - species_indexer = [] - mus = [] - for sid in sams: - adatas[sid] = sams[sid].adata - Ws[sid] = adatas[sid].var["weights"].values - ss[sid] = std.fit_transform(adatas[sid].X).multiply(Ws[sid][None, :]).tocsr() - mus.append(np.asarray(ss[sid].mean(0)).flatten()) - species_indexer.append(np.arange(ss[sid].shape[0])) - for i in range(1, len(species_indexer)): - species_indexer[i] = species_indexer[i] + species_indexer[i - 1].max() + 1 - X = sp.sparse.vstack(list(ss.values())) - C = np.hstack([adatas[sid].varm["PCs_SAMap"] for sid in sams]) - wpca = X.dot(C) - M = np.vstack(mus).dot(C) - for i, sid in enumerate(sams.keys()): - ixq = species_indexer[i] - wpca[ixq] -= M[i] - - ixg = np.arange(wpca.shape[0]) - Xs = [] - Ys = [] - Vs = [] - for i, sid in enumerate(sams.keys()): - ixq = species_indexer[i] - query = wpca[ixq] - - for j, _sid2 in enumerate(sams.keys()): - if i != j: - ixr = species_indexer[j] - reference = wpca[ixr] - - b = _united_proj(query, reference, k=k) - - su = b.sum(1).A - su[su == 0] = 1 - b = b.multiply(1 / su).tocsr() - - A = pd.Series(index=np.arange(b.shape[0]), data=ixq) - B = pd.Series(index=np.arange(b.shape[1]), data=ixr) - - x, y = b.nonzero() - x, y = A[x].values, B[y].values - Xs.extend(x) - Ys.extend(y) - Vs.extend(b.data) - - knn = sp.sparse.coo_matrix((Vs, (Xs, Ys)), shape=(ixg.size, ixg.size)) - - output_dict["knn"] = knn.tocsr() - output_dict["wPCA"] = wpca - return output_dict - - -def _sparse_knn_ks(D: sp.sparse.coo_matrix, ks: NDArray[Any]) -> sp.sparse.coo_matrix: - """Keep variable top-k values per row in sparse matrix.""" - D1 = D.tocoo() - idr = np.argsort(D1.row) - D1.row[:] = D1.row[idr] - D1.col[:] = D1.col[idr] - D1.data[:] = D1.data[idr] - - row, ind = np.unique(D1.row, 
return_index=True) - ind = np.append(ind, D1.data.size) - for i in range(ind.size - 1): - idx = np.argsort(D1.data[ind[i] : ind[i + 1]]) - k = ks[row[i]] - if idx.size > k: - idx = idx[:-k] if k != 0 else idx - D1.data[np.arange(ind[i], ind[i + 1])[idx]] = 0 - D1.eliminate_zeros() - return D1 - - -def _smart_expand(nnm: sp.sparse.csr_matrix, K: NDArray[Any], NH: int = 3) -> sp.sparse.csr_matrix: - """Expand neighborhoods progressively.""" - stage0 = nnm.copy() - S = [stage0] - running = stage0 - for i in range(1, NH + 1): - stage = running.dot(stage0) - running = stage - stage = stage.tolil() - for j in range(i): - stage[S[j].nonzero()] = 0 - stage = stage.tocsr() - S.append(stage) - - for i in range(len(S)): - s = _sparse_knn_ks(S[i], K).tocsr() - a, c = np.unique(s.nonzero()[0], return_counts=True) - numnz = np.zeros(s.shape[0], dtype="int32") - numnz[a] = c - K = K - numnz - K[K < 0] = 0 - S[i] = s - res = S[0] - for i in range(1, len(S)): - res = res + S[i] - return res diff --git a/src/samap/core/projection.py b/src/samap/core/projection.py new file mode 100644 index 0000000..fc5d123 --- /dev/null +++ b/src/samap/core/projection.py @@ -0,0 +1,538 @@ +"""Cross-species feature projection and kNN construction. + +Projects cells from each species into a joint latent space via the homology +graph, then builds the cross-species kNN graph with HNSW. + +Implementation notes — precomposed feature translation +------------------------------------------------------ +The legacy algorithm materialised a cells × genes translated-feature matrix +``Xtr = X_i @ G_ij`` per species pair, scaled it column-wise, weighted it by +gene weights, then projected it through the target species' PC loadings. +For realistic datasets that intermediate is ~30% dense and dominates both +memory and wall time. + +We now precompose the projection. 
Writing the column-wise scaling as a +diagonal matrix ``D = diag(W_j / σ)``, the cross contribution is + + wpca_cross = (X_i @ G_ij) · D · PCs_j + = X_i @ (G_ij · D · PCs_j) + = X_i @ P_ij + +where ``P_ij`` has shape (G_i × npcs) — typically a few MB — and the final +result follows from one SpMM. The per-column standard deviation ``σ`` is +recovered from iteration-invariant precomputes (``X_i^T X_i`` and +``X_i.mean(0)``) via a quadratic form in the columns of ``G_ij``, so the +dense intermediate is never materialised. + +The own-species contribution ``X_i @ PCs_i`` and its mean correction do not +depend on the homology graph at all and are cached in :func:`_projection_precompute`. +""" + +from __future__ import annotations + +import gc +import time +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd +import scipy as sp +from sklearn.preprocessing import StandardScaler + +from samap._logging import logger +from samap.core._backend import Backend +from samap.core.knn import approximate_knn +from samap.utils import q as _q + +from .homology import _tanh_scale + +if TYPE_CHECKING: + from typing import Any + + from numpy.typing import NDArray + + from samap.sam import SAM + + +def prepare_SAMap_loadings(sam: SAM, npcs: int = 300) -> None: + """Prepare SAM object with PC loadings for manifold. + + Parameters + ---------- + sam : SAM + SAM object to prepare. + npcs : int, optional + Number of PCs to calculate. Default 300. 
+ """ + ra = sam.adata.uns["run_args"] + preprocessing = ra.get("preprocessing", "StandardScaler") + weight_PCs = ra.get("weight_PCs", False) + A, _ = sam.calculate_nnm( + n_genes=sam.adata.shape[1], + preprocessing=preprocessing, + npcs=npcs, + weight_PCs=weight_PCs, + sparse_pca=True, + update_manifold=False, + weight_mode="dispersion", + ) + sam.adata.varm["PCs_SAMap"] = A + + +def _united_proj( + wpca1: NDArray[Any], + wpca2: NDArray[Any], + k: int = 20, + metric: str = "cosine", + bk: Backend | None = None, +) -> sp.sparse.csr_matrix: + """Build cross-species kNN sparse graph with similarity weights. + + Finds the ``k`` nearest neighbours of each row of ``wpca1`` in + ``wpca2``, transforms distances into similarity weights, and returns a + sparse (n_q, n_d) CSR. + + The kNN search is delegated to :func:`samap.core.knn.approximate_knn`, + which dispatches between CPU HNSW (hnswlib) and GPU brute-force (FAISS) + based on ``bk``. When ``bk`` is ``None`` a CPU backend is used. + """ + metric = "l2" if metric == "euclidean" else metric + metric = "cosine" if metric == "correlation" else metric + + idx1, dist1 = approximate_knn(wpca1, wpca2, k=k, metric=metric, bk=bk) + + if metric == "cosine": + dist1 = 1 - dist1 + dist1[dist1 < 1e-3] = 1e-3 + dist1 = dist1 / dist1.max(1)[:, None] + dist1 = _tanh_scale(dist1, scale=10, center=0.7) + else: + sigma1 = dist1[:, 4] + sigma1[sigma1 < 1e-3] = 1e-3 + dist1 = np.exp(-dist1 / sigma1[:, None]) + + Sim1 = dist1 + n_q = wpca1.shape[0] + n_d = wpca2.shape[0] + rows = np.repeat(np.arange(n_q, dtype=np.int32), k) + cols = idx1.ravel().astype(np.int32) + vals = Sim1.ravel() + return sp.sparse.coo_matrix((vals, (rows, cols)), shape=(n_q, n_d)).tocsr() + + +# --------------------------------------------------------------------------- # +# Sigma from precomputes # +# --------------------------------------------------------------------------- # + + +def _compute_sigma( + XtX: Any, + mu: NDArray[Any], + G: Any, + n: int, + bk: Backend 
| None = None, +) -> NDArray[Any]: + """Column-wise standard deviation of ``X @ G`` without materialising it. + + Equivalent to ``StandardScaler(with_mean=False).fit(X @ G).scale_``. + + Parameters + ---------- + XtX : sparse (G_i × G_i) + Precomputed Gram matrix ``X.T @ X``. + mu : 1-d array, length G_i + Precomputed column means ``X.mean(axis=0)``. + G : sparse (G_i × G_j) + The homology sub-block whose columns we're scaling. + n : int + Number of rows in ``X`` (cells in the source species). + bk : Backend or None + Array backend. If None, uses numpy/scipy directly. + + Returns + ------- + sigma : 1-d array, length G_j + Per-column biased standard deviation, with zero-variance columns + mapped to 1.0 (matching sklearn's ``_handle_zeros_in_scale``). + + Notes + ----- + Uses the identity ``diag(Gᵀ·XtX·G)_k = Σ_r (X·g_k)[r]²`` so that + ``σ_k² = diag(Gᵀ·XtX·G)_k / n − (μ·g_k)²``. The diagonal is extracted + as ``((XtX @ G) ⊙ G).sum(0)`` — one SpGEMM + one elementwise product, + never forming the full G_j × G_j outer product. + """ + xp = bk.xp if bk is not None else np + # diag(Gᵀ · XtX · G) — elementwise-multiply trick avoids G_j × G_j dense + sq = xp.asarray((XtX @ G).multiply(G).sum(0)).flatten() + mu_terms = xp.asarray(mu @ G).flatten() + var = sq / n - mu_terms * mu_terms + # numerical guard — floating-point cancellation can produce tiny negatives + var = xp.maximum(var, 0.0) + sigma = xp.sqrt(var) + # sklearn maps zero-variance columns to scale_=1.0 + sigma = xp.where(sigma == 0.0, 1.0, sigma) + return sigma + + +# --------------------------------------------------------------------------- # +# Iteration-invariant precompute # +# --------------------------------------------------------------------------- # + + +def _projection_precompute( + sams: dict[str, SAM], + gns: NDArray[Any], + bk: Backend | None = None, +) -> dict[str, Any]: + """Build the iteration-invariant state for :func:`_mapping_window_fast`. 
+ + Everything that depends only on the input SAM objects — not on the + homology graph — is computed once here and cached. Specifically: the + standardised, gene-weighted expression matrices; their Gram matrices and + column means (for the sigma quadratic form); and the own-species PC + projection (which never changes across SAMap iterations). + + Parameters + ---------- + sams : dict[str, SAM] + Input per-species SAM objects. Must have ``adata.varm['PCs_SAMap']`` + and ``adata.var['weights']`` populated. + gns : array of str + Concatenated homology-graph gene names, species-prefixed and ordered + so that species blocks are contiguous. + bk : Backend or None + Array backend for device placement. Default: CPU. + + Returns + ------- + dict + Keys: ``sids``, ``gs``, ``W``, ``species_indexer``, ``genes_indexer``, + ``ss``, ``PCs``, ``n_cells``, ``XtX``, ``mu_ss``, ``wpca_own``, + ``M_own``, ``bk``. All array-valued entries live on ``bk``'s device. + """ + if bk is None: + bk = Backend("cpu") + + std = StandardScaler(with_mean=False) + + sids = list(sams.keys()) + gs: dict[str, NDArray[Any]] = {} + W: dict[str, Any] = {} + ss: dict[str, Any] = {} + PCs: dict[str, Any] = {} + n_cells: dict[str, int] = {} + species_indexer: list[NDArray[Any]] = [] + genes_indexer: list[NDArray[Any]] = [] + + for sid in sids: + gs[sid] = gns[np.isin(gns, _q(sams[sid].adata.var_names))] + sub = sams[sid].adata[:, gs[sid]] + W[sid] = bk.to_device(bk.xp.asarray(sub.var["weights"].values)) + # StandardScaler runs on host; move result to device after + ss_host = std.fit_transform(sub.X).multiply(sub.var["weights"].values[None, :]).tocsr() + ss[sid] = bk.to_device(ss_host) + PCs[sid] = bk.to_device(bk.xp.asarray(sub.varm["PCs_SAMap"])) + n_cells[sid] = ss_host.shape[0] + species_indexer.append(np.arange(n_cells[sid])) + genes_indexer.append(np.arange(gs[sid].size)) + + for i in range(1, len(species_indexer)): + species_indexer[i] = species_indexer[i] + species_indexer[i - 1].max() + 1 + 
genes_indexer[i] = genes_indexer[i] + genes_indexer[i - 1].max() + 1 + + # Gram matrices + column means — feed the sigma quadratic form + XtX: dict[str, Any] = {} + mu_ss: dict[str, Any] = {} + for sid in sids: + XtX[sid] = (ss[sid].T @ ss[sid]).tocsr() + mu_ss[sid] = bk.xp.asarray(ss[sid].mean(0)).flatten() + + # Own-species PC projection — fully iteration-invariant + wpca_own: dict[str, Any] = {} + M_own: dict[str, Any] = {} + for sid in sids: + wpca_own[sid] = ss[sid] @ PCs[sid] # N_sid × npcs_sid + M_own[sid] = mu_ss[sid] @ PCs[sid] # npcs_sid + + return { + "sids": sids, + "gs": gs, + "W": W, + "species_indexer": species_indexer, + "genes_indexer": genes_indexer, + "ss": ss, + "PCs": PCs, + "n_cells": n_cells, + "XtX": XtX, + "mu_ss": mu_ss, + "wpca_own": wpca_own, + "M_own": M_own, + "bk": bk, + } + + +# --------------------------------------------------------------------------- # +# Fast per-iteration path # +# --------------------------------------------------------------------------- # + + +def _mapping_window_fast( + gnnm: Any, + precompute: dict[str, Any], + K: int = 20, + pairwise: bool = True, +) -> dict[str, Any]: + """Cross-species projection using precomposed feature translation. + + Per-iteration worker: consumes the current homology graph ``gnnm`` and + the cached invariants from :func:`_projection_precompute`, and produces + the cross-species kNN graph. Never materialises the cells × genes + translated feature matrix — see the module docstring for the algebra. + + Parameters + ---------- + gnnm : sparse (G_total × G_total) + Current gene-homology graph. Row/column order must match + ``precompute['gs']`` block structure. + precompute : dict + Output of :func:`_projection_precompute`. + K : int + Number of nearest neighbours per species pair. + pairwise : bool + If True (default), the homology sub-block is re-normalised per + species pair. If False, the global column normalisation is used as-is. + These differ for 3+ species. 
+ + Returns + ------- + dict + Same shape as legacy ``_mapping_window``: keys ``knn`` (CSR, host), + ``wPCA`` (dense, host), ``gnnm_corr`` (CSR, host; globally + column-normalised — downstream consumers rely on this). + """ + bk: Backend = precompute["bk"] + sids: list[str] = precompute["sids"] + n_species = len(sids) + species_indexer = precompute["species_indexer"] + genes_indexer = precompute["genes_indexer"] + ss = precompute["ss"] + PCs = precompute["PCs"] + W = precompute["W"] + n_cells = precompute["n_cells"] + XtX = precompute["XtX"] + mu_ss = precompute["mu_ss"] + wpca_own = precompute["wpca_own"] + M_own = precompute["M_own"] + + logger.info("Prepping datasets for translation.") + + # ---- Global gnnm_corr preparation (tanh-scale + column-normalise) ---- # + # This normalised graph is also a pipeline output (output_dict['gnnm_corr']). + gnnm_corr = bk.to_device(gnnm.copy()) + gnnm_corr.data[:] = _tanh_scale(bk.to_host(gnnm_corr.data)) # tanh is cheap; stay on host np + gnnm_corr = bk.to_device(gnnm_corr) + su = bk.xp.asarray(gnnm_corr.sum(0)) + su = bk.xp.where(su == 0, 1.0, su) + gnnm_corr = gnnm_corr.multiply(1.0 / su).tocsr() + + ttt = time.time() + if pairwise: + logger.info("Translating feature spaces pairwise.") + else: + logger.info("Translating feature spaces all-to-all.") + + # Per-species row-block of wpca, each decomposed into per-species column-blocks + # row_blocks[i][s] = N_i × npcs_s contribution + # M_blocks[i][s] = npcs_s mean-correction vector + row_blocks: list[list[Any]] = [[None] * n_species for _ in range(n_species)] + M_blocks: list[list[Any]] = [[None] * n_species for _ in range(n_species)] + + for i, sid_i in enumerate(sids): + gi = genes_indexer[i] + n_i = n_cells[sid_i] + + for s, sid_s in enumerate(sids): + if s == i: + # Own-species contribution — iteration-invariant, just reference + row_blocks[i][s] = wpca_own[sid_i] + M_blocks[i][s] = M_own[sid_i] + continue + + gs = genes_indexer[s] + # Extract cross-species homology 
sub-block + G_is = gnnm_corr[gi[0] : gi[-1] + 1, gs[0] : gs[-1] + 1] + + if pairwise: + # Re-normalise locally — matches legacy pairwise branch + col_sum = bk.xp.asarray(G_is.sum(0)) + col_sum = bk.xp.where(col_sum == 0, 1.0, col_sum) + G_is = G_is.multiply(1.0 / col_sum).tocsr() + + # ---- Sigma via quadratic form (no Xtr materialised) -------- # + sigma = _compute_sigma(XtX[sid_i], mu_ss[sid_i], G_is, n_i, bk) + # mu_terms needed below for the mean-correction; recompute (cheap) + mu_terms = bk.xp.asarray(mu_ss[sid_i] @ G_is).flatten() + + # ---- Precompose: P_is = G · diag(W/σ) · PCs_s --------------- # + # Shape: (G_i × npcs_s), dense — typically a few MB regardless of N + scale = W[sid_s] / sigma + P_is = G_is.multiply(scale).tocsr() @ PCs[sid_s] + + # ---- ONE SpMM replaces the N_i × G_s dense intermediate ---- # + row_blocks[i][s] = ss[sid_i] @ P_is # N_i × npcs_s + + # Mean-correction vector for this block — same identity as sigma, + # mu(Xtr_weighted) = (mu_ss · G) / σ · W + mu_cross = mu_terms / sigma * W[sid_s] + M_blocks[i][s] = mu_cross @ PCs[sid_s] + + gc.collect() + + logger.info("Projecting data into joint latent space. %.2fs", time.time() - ttt) + ttt = time.time() + + # ---- Assemble wpca: row-block i = hstack of column-blocks, minus M --- # + N_total = sum(n_cells.values()) + npcs_blocks = [PCs[sid].shape[1] for sid in sids] + npcs_total = sum(npcs_blocks) + wpca = bk.xp.zeros((N_total, npcs_total), dtype=bk.xp.float64) + + col_offsets = np.cumsum([0, *npcs_blocks]) + for i, sid_i in enumerate(sids): + r0, r1 = species_indexer[i][0], species_indexer[i][-1] + 1 + M_i = bk.xp.concatenate(M_blocks[i]) # full-width correction vector for species i + for s in range(n_species): + c0, c1 = col_offsets[s], col_offsets[s + 1] + block = row_blocks[i][s] + # row_blocks may come out sparse (rare, e.g. 
all-zero G_is); coerce + if hasattr(block, "toarray"): + block = block.toarray() + wpca[r0:r1, c0:c1] = block + wpca[r0:r1] -= M_i + + logger.info("Correcting data with means. %.2fs", time.time() - ttt) + + # ---- Cross-species kNN (hnswlib CPU or FAISS GPU via approximate_knn) # + wpca_host = bk.to_host(wpca) + gnnm_corr_host = bk.to_host(gnnm_corr) + + k = K + ixg = np.arange(wpca_host.shape[0]) + Xs: list[Any] = [] + Ys: list[Any] = [] + Vs: list[Any] = [] + for i in range(n_species): + ixq = species_indexer[i] + query = wpca_host[ixq] + for j in range(n_species): + if i == j: + continue + ixr = species_indexer[j] + reference = wpca_host[ixr] + + b = _united_proj(query, reference, k=k, bk=bk) + + su = np.asarray(b.sum(1)) + su[su == 0] = 1 + b = b.multiply(1 / su).tocsr() + + A = pd.Series(index=np.arange(b.shape[0]), data=ixq) + B = pd.Series(index=np.arange(b.shape[1]), data=ixr) + + x, y = b.nonzero() + x, y = A[x].values, B[y].values + Xs.extend(x) + Ys.extend(y) + Vs.extend(b.data) + + knn = sp.sparse.coo_matrix((Vs, (Xs, Ys)), shape=(ixg.size, ixg.size)) + + return { + "knn": knn.tocsr(), + "wPCA": wpca_host, + "gnnm_corr": gnnm_corr_host, + } + + +# --------------------------------------------------------------------------- # +# Backward-compat wrapper # +# --------------------------------------------------------------------------- # + + +def _mapping_window( + sams: dict[str, SAM], + gnnm: sp.sparse.csr_matrix | None = None, + gns: NDArray[Any] | None = None, + K: int = 20, + pairwise: bool = True, +) -> dict[str, Any]: + """Cross-species projection — backward-compatible entry point. + + Builds the precompute dict on-the-fly and delegates to + :func:`_mapping_window_fast`. For iterative use, prefer calling + :func:`_projection_precompute` once and :func:`_mapping_window_fast` + per iteration — the precompute is iteration-invariant and expensive. 
+ + When ``gnnm is None`` (own-species-only projection, used by the no-graph + bootstrap path), falls back to the legacy implementation. + """ + if gnnm is not None and gns is not None: + pre = _projection_precompute(sams, gns, bk=Backend("cpu")) + return _mapping_window_fast(gnnm, pre, K=K, pairwise=pairwise) + + # ---- No-graph path (legacy, unchanged) ------------------------------- # + # Only the own-species projection — no homology graph, no cross blocks. + std = StandardScaler(with_mean=False) + adatas: dict[str, Any] = {} + Ws: dict[str, Any] = {} + ss: dict[str, Any] = {} + species_indexer: list[NDArray[Any]] = [] + mus: list[Any] = [] + for sid in sams: + adatas[sid] = sams[sid].adata + Ws[sid] = adatas[sid].var["weights"].values + ss[sid] = std.fit_transform(adatas[sid].X).multiply(Ws[sid][None, :]).tocsr() + mus.append(np.asarray(ss[sid].mean(0)).flatten()) + species_indexer.append(np.arange(ss[sid].shape[0])) + for i in range(1, len(species_indexer)): + species_indexer[i] = species_indexer[i] + species_indexer[i - 1].max() + 1 + X = sp.sparse.vstack(list(ss.values())) + C = np.hstack([adatas[sid].varm["PCs_SAMap"] for sid in sams]) + wpca = X.dot(C) + M = np.vstack(mus).dot(C) + for i, sid in enumerate(sams.keys()): + ixq = species_indexer[i] + wpca[ixq] -= M[i] + + k = K + ixg = np.arange(wpca.shape[0]) + Xs = [] + Ys = [] + Vs = [] + for i, sid in enumerate(sams.keys()): + ixq = species_indexer[i] + query = wpca[ixq] + for j, _sid2 in enumerate(sams.keys()): + if i != j: + ixr = species_indexer[j] + reference = wpca[ixr] + + b = _united_proj(query, reference, k=k) + + su = np.asarray(b.sum(1)) + su[su == 0] = 1 + b = b.multiply(1 / su).tocsr() + + A = pd.Series(index=np.arange(b.shape[0]), data=ixq) + B = pd.Series(index=np.arange(b.shape[1]), data=ixr) + + x, y = b.nonzero() + x, y = A[x].values, B[y].values + Xs.extend(x) + Ys.extend(y) + Vs.extend(b.data) + + knn = sp.sparse.coo_matrix((Vs, (Xs, Ys)), shape=(ixg.size, ixg.size)) + + return {"knn": 
knn.tocsr(), "wPCA": wpca} diff --git a/src/samap/sam/__init__.py b/src/samap/sam/__init__.py new file mode 100644 index 0000000..ff50d9c --- /dev/null +++ b/src/samap/sam/__init__.py @@ -0,0 +1,9 @@ +"""Vendored SAM (Self-Assembling Manifold) algorithm. + +Originally from the sc-sam package (samalg module). Vendored into SAMap +to eliminate the external dependency and enable targeted optimizations. +""" + +from .core import SAM + +__all__ = ["SAM"] diff --git a/src/samap/sam/core.py b/src/samap/sam/core.py new file mode 100644 index 0000000..cbd729e --- /dev/null +++ b/src/samap/sam/core.py @@ -0,0 +1,1311 @@ +"""Self-Assembling Manifold (SAM) algorithm — vendored core. + +Vendored from samalg.sam (sc-sam v2.0.2). Contains the SAM class with +only the methods SAMap actually uses. Heavy features dropped during +vendoring: + - run_umap, run_tsne, run_diff_map, run_diff_umap + - kmeans_clustering, hdbscan_clustering, louvain_clustering + - identify_marker_genes_* + - save/load (dill-based) + +Fixes applied: + - Removed obsm["X_processed"] = D_sub (n_cells x n_genes stored every + iteration, never read — pure memory waste). + - Replaced .tolil() + .setdiag() + .tocsr() cycles with direct CSR + setdiag (no format round-trip). + - Dropped numba import (SAM has no @jit functions; import existed + only to suppress a warning). + +Copyright 2018, Alexander J. Tarashansky. 
+""" + +from __future__ import annotations + +import contextlib +import gc +import time +import warnings +from typing import TYPE_CHECKING, Any, Literal + +import anndata +import numpy as np +import pandas as pd +import scipy.sparse as sp +import sklearn.utils.sparsefuncs as sf +from anndata import AnnData +from scipy.sparse import SparseEfficiencyWarning +from sklearn.preprocessing import Normalizer + +from .._logging import get_logger +from .knn import calc_nnm +from .pca import _pca_with_sparse, weighted_PCA +from .utils import convert_annotations + +if TYPE_CHECKING: + from numpy.typing import NDArray + + from samap.core._backend import Backend + +logger = get_logger("samap.sam") + + +class DataNotLoadedError(RuntimeError): + """Raised when an operation requires data that has not been loaded.""" + + def __init__(self, msg: str | None = None) -> None: + super().__init__( + msg or "No data has been loaded. Use load_data() or pass data to the constructor." + ) + + +class InvalidParameterError(ValueError): + """Raised when a parameter has an invalid value.""" + + def __init__(self, param: str, value: Any, valid_values: list[Any] | None = None) -> None: + msg = f"Invalid value for '{param}': {value!r}." + if valid_values: + msg += f" Valid values: {valid_values}." + super().__init__(msg) + + +def _csr_setdiag(mat: sp.csr_matrix, val: float) -> sp.csr_matrix: + """Set the diagonal of a CSR matrix in place, suppressing efficiency warnings. + + scipy's CSR setdiag works natively; the lil round-trip in the original + SAM code was unnecessary. The SparseEfficiencyWarning fires only if + the diagonal entries don't already exist in the sparsity structure. + For SAM's k-NN matrices, the diagonal is always present (hnswlib + returns each point as its own nearest neighbor), so this is a no-op + structural change — just a data overwrite. 
+ """ + with warnings.catch_warnings(): + warnings.simplefilter("ignore", SparseEfficiencyWarning) + mat.setdiag(val) + if val == 0: + mat.eliminate_zeros() + return mat + + +class SAM: + """Self-Assembling Manifolds single-cell RNA sequencing analysis tool. + + SAM iteratively rescales the input gene expression matrix to emphasize + genes that are spatially variable along the intrinsic manifold of the data. + It outputs the gene weights, nearest neighbor matrix, and a 2D projection. + + Parameters + ---------- + counts : tuple | list | pd.DataFrame | AnnData | None + Input data in one of the following formats: + - tuple/list: (data, gene_names, cell_names) where data is sparse/dense matrix + - pd.DataFrame: cells x genes expression matrix + - AnnData: annotated data object + inplace : bool, optional + If True and counts is AnnData, use the object directly without copying. + Default is False. + + Attributes + ---------- + preprocess_args : dict + Dictionary of arguments used for the 'preprocess_data' function. + run_args : dict + Dictionary of arguments used for the 'run' function. + adata_raw : AnnData + An AnnData object containing the raw, unfiltered input data. + adata : AnnData + An AnnData object containing all processed data and SAM outputs. + """ + + def __init__( + self, + counts: ( + tuple[sp.spmatrix | NDArray[np.floating[Any]], NDArray[Any], NDArray[Any]] + | list[Any] + | pd.DataFrame + | AnnData + | None + ) = None, + inplace: bool = False, + bk: Backend | None = None, + ) -> None: + # Backend for GPU dispatch in the iteration-loop hot spots + # (dispersion SpMM, sparse PCA, kNN). Lazy-construct Backend("cpu") + # if not provided so importing SAM doesn't force a dependency on + # samap.core._backend at module-load time. 
+ if bk is None: + from samap.core._backend import Backend as _Backend + + bk = _Backend("cpu") + self._bk = bk + + self.run_args: dict[str, Any] = {} + self.preprocess_args: dict[str, Any] = {} + + if isinstance(counts, (tuple, list)): + raw_data, all_gene_names, all_cell_names = counts + if isinstance(raw_data, np.ndarray): + raw_data = sp.csr_matrix(raw_data) + + self.adata_raw = AnnData( + X=raw_data, + obs={"obs_names": all_cell_names}, + var={"var_names": all_gene_names}, + ) + + elif isinstance(counts, pd.DataFrame): + raw_data = sp.csr_matrix(counts.values) + all_gene_names = np.array(list(counts.columns.values)) + all_cell_names = np.array(list(counts.index.values)) + + self.adata_raw = AnnData( + X=raw_data, + obs={"obs_names": all_cell_names}, + var={"var_names": all_gene_names}, + ) + + elif isinstance(counts, AnnData): + all_cell_names = np.array(list(counts.obs_names)) + all_gene_names = np.array(list(counts.var_names)) + if counts.is_view: + counts = counts.copy() + + if inplace: + self.adata_raw = counts + else: + self.adata_raw = counts.copy() + + elif counts is not None: + raise TypeError( + "'counts' must be either a tuple/list of " + "(data, gene IDs, cell IDs), a Pandas DataFrame of " + "cells x genes, or an AnnData object." 
+ ) + + if counts is not None: + if np.unique(all_gene_names).size != all_gene_names.size: + self.adata_raw.var_names_make_unique() + if np.unique(all_cell_names).size != all_cell_names.size: + self.adata_raw.obs_names_make_unique() + + if inplace: + self.adata = self.adata_raw + else: + self.adata = self.adata_raw.copy() + + if "X_disp" not in self.adata_raw.layers: + self.adata.layers["X_disp"] = self.adata.X + + def preprocess_data( + self, + div: float = 1, + downsample: float = 0, + sum_norm: str | float | None = "cell_median", + norm: str | None = "log", + min_expression: float = 1, + thresh_low: float = 0.0, + thresh_high: float = 0.96, + thresh: float | None = None, + filter_genes: bool = True, + ) -> None: + """Log-normalize and filter the expression data. + + Parameters + ---------- + div : float, optional + The factor by which the gene expression will be divided prior to + normalization. Default is 1. + downsample : float, optional + The factor by which to randomly downsample the data. If 0, the + data will not be downsampled. Default is 0. + sum_norm : str | float | None, optional + Library normalization method. Options: + - float: Normalize each cell to this total count + - 'cell_median': Normalize to median total count per cell + - 'gene_median': Normalize genes to median total count per gene + - None: No normalization + Default is 'cell_median'. + norm : str | None, optional + Data transformation method. Options: + - 'log': log2(x + 1) transformation + - 'ftt': Freeman-Tukey variance-stabilizing transformation + - 'asin': arcsinh transformation + - 'multinomial': Pearson residual transformation (experimental) + - None: No transformation + Default is 'log'. + min_expression : float, optional + Threshold above which a gene is considered expressed. Values below + this are set to zero. Default is 1. + thresh_low : float, optional + Keep genes expressed in greater than thresh_low*100% of cells. + Default is 0.0. 
+ thresh_high : float, optional + Keep genes expressed in less than thresh_high*100% of cells. + Default is 0.96. + thresh : float | None, optional + If provided, sets thresh_low=thresh and thresh_high=1-thresh. + filter_genes : bool, optional + Whether to apply gene filtering. Default is True. + """ + if thresh is not None: + thresh_low = thresh + thresh_high = 1 - thresh + + if not hasattr(self, "adata_raw"): + raise DataNotLoadedError() + + self.preprocess_args = { + "div": div, + "sum_norm": sum_norm, + "norm": norm, + "min_expression": min_expression, + "thresh_low": thresh_low, + "thresh_high": thresh_high, + "filter_genes": filter_genes, + } + + self.run_args = self.adata.uns.get("run_args", {}) + + D = self.adata_raw.X + self.adata = self.adata_raw.copy() + + D = self.adata.X + if isinstance(D, np.ndarray): + D = sp.csr_matrix(D, dtype="float32") + else: + if str(D.dtype) != "float32": + D = D.astype("float32") + D.sort_indices() + + if D.getformat() == "csc": + D = D.tocsr() + + # Sum-normalize + if sum_norm == "cell_median" and norm != "multinomial": + s = np.asarray(D.sum(1)).flatten() + sum_norm_val = np.median(s) + D = D.multiply(1 / s[:, None] * sum_norm_val).tocsr() + elif sum_norm == "gene_median" and norm != "multinomial": + s = np.asarray(D.sum(0)).flatten() + sum_norm_val = np.median(s[s > 0]) + s[s == 0] = 1 + D = D.multiply(1 / s[None, :] * sum_norm_val).tocsr() + elif sum_norm is not None and norm != "multinomial": + D = D.multiply(1 / np.asarray(D.sum(1)).flatten()[:, None] * sum_norm).tocsr() + + # Normalize + self.adata.X = D + if norm is None: + D.data[:] = D.data / div + + elif norm.lower() == "log": + D.data[:] = np.log2(D.data / div + 1) + + elif norm.lower() == "ftt": + D.data[:] = np.sqrt(D.data / div) + np.sqrt(D.data / div + 1) - 1 + + elif norm.lower() == "asin": + D.data[:] = np.arcsinh(D.data / div) + + elif norm.lower() == "multinomial": + ni = np.asarray(D.sum(1)).flatten() # cells + pj = np.asarray(D.sum(0) / 
D.sum()).flatten() # genes + col = D.indices + row = [] + for i in range(D.shape[0]): + row.append(i * np.ones(D.indptr[i + 1] - D.indptr[i])) + row = np.concatenate(row).astype("int32") + mu = sp.coo_matrix((ni[row] * pj[col], (row, col))).tocsr() + mu2 = mu.copy() + mu2.data[:] = mu2.data**2 + mu2 = mu2.multiply(1 / ni[:, None]) + mu.data[:] = (D.data - mu.data) / np.sqrt(mu.data - mu2.data) + + self.adata.X = mu + if sum_norm is None: + sum_norm = np.median(ni) + D = D.multiply(1 / ni[:, None] * sum_norm).tocsr() + D.data[:] = np.log2(D.data / div + 1) + + else: + D.data[:] = D.data / div + + # Zero-out low-expressed genes + idx = np.where(D.data <= min_expression)[0] + D.data[idx] = 0 + + # Filter genes + idx_genes = np.arange(D.shape[1]) + if filter_genes: + a, ct = np.unique(D.indices, return_counts=True) + c = np.zeros(D.shape[1]) + c[a] = ct + + keep = np.where( + np.logical_and(c / D.shape[0] > thresh_low, c / D.shape[0] <= thresh_high) + )[0] + + idx_genes = np.array(list(set(keep) & set(idx_genes)), dtype=np.intp) + + mask_genes = np.zeros(D.shape[1], dtype="bool") + mask_genes[idx_genes] = True + + self.adata.X = self.adata.X.multiply(mask_genes[None, :]).tocsr() + self.adata.X.eliminate_zeros() + self.adata.var["mask_genes"] = mask_genes + + if norm == "multinomial": + self.adata.layers["X_disp"] = D.multiply(mask_genes[None, :]).tocsr() + self.adata.layers["X_disp"].eliminate_zeros() + else: + self.adata.layers["X_disp"] = self.adata.X + + self.calculate_mean_var() + + self.adata.uns["preprocess_args"] = self.preprocess_args + self.adata.uns["run_args"] = self.run_args + + def calculate_mean_var(self, adata: AnnData | None = None) -> None: + """Calculate mean and variance for each gene. + + Parameters + ---------- + adata : AnnData | None, optional + The AnnData object to calculate statistics for. + If None, uses self.adata. 
+ """ + if adata is None: + adata = self.adata + + if sp.issparse(adata.X): + mu, var = sf.mean_variance_axis(adata.X, axis=0) + else: + mu = adata.X.mean(0) + var = adata.X.var(0) + + adata.var["means"] = mu + adata.var["variances"] = var + + def get_labels(self, key: str) -> NDArray[Any]: + """Get labels from obs. + + Parameters + ---------- + key : str + Key in adata.obs. + + Returns + ------- + NDArray + Array of labels. + """ + if key not in list(self.adata.obs.keys()): + logger.warning("Key '%s' does not exist in `obs`.", key) + return np.array([]) + return np.array(list(self.adata.obs[key])) + + def load_data( + self, + filename: str, + transpose: bool = True, + sep: str = ",", + calculate_avg: bool = False, + **kwargs: Any, + ) -> None: + """Load expression data from file. + + Parameters + ---------- + filename : str + Path to the data file. Supported formats: + - .csv/.txt: Tabular format (genes x cells by default) + - .h5ad: AnnData format + transpose : bool, optional + If True (default), assumes file is genes x cells. + Set to False if file is cells x genes. + sep : str, optional + Delimiter for CSV/TXT files. Default is ','. + calculate_avg : bool, optional + If True and loading .h5ad with existing neighbors, perform + kNN averaging. Default is False. + **kwargs + Additional arguments passed to file loading functions. 
+ """ + ext = filename.split(".")[-1] + + if ext != "h5ad": + df = pd.read_csv(filename, sep=sep, index_col=0, **kwargs) + dataset = df.T if transpose else df + + raw_data = sp.csr_matrix(dataset.values) + all_cell_names = np.array(list(dataset.index.values)) + all_gene_names = np.array(list(dataset.columns.values)) + + self.adata_raw = AnnData( + X=raw_data, + obs={"obs_names": all_cell_names}, + var={"var_names": all_gene_names}, + ) + + if np.unique(all_gene_names).size != all_gene_names.size: + self.adata_raw.var_names_make_unique() + if np.unique(all_cell_names).size != all_cell_names.size: + self.adata_raw.obs_names_make_unique() + + self.adata = self.adata_raw.copy() + self.adata.layers["X_disp"] = raw_data + + else: + self.adata = anndata.read_h5ad(filename, **kwargs) + if self.adata.raw is not None: + self.adata_raw = AnnData(X=self.adata.raw.X) + self.adata_raw.var_names = self.adata.var_names + self.adata_raw.obs_names = self.adata.obs_names + self.adata_raw.obs = self.adata.obs + + del self.adata.raw + + if ( + "X_knn_avg" not in self.adata.layers + and "connectivities" in self.adata.obsp + and calculate_avg + ): + self.dispersion_ranking_NN(save_avgs=True) + else: + self.adata_raw = self.adata + + if "X_disp" not in list(self.adata.layers.keys()): + self.adata.layers["X_disp"] = self.adata.X + + filename = ".".join(filename.split(".")[:-1]) + ".h5ad" + self.adata.uns["path_to_file"] = filename + self.adata_raw.uns["path_to_file"] = filename + + def save_anndata(self, fname: str = "", save_knn: bool = False, **kwargs: Any) -> None: + """Save adata to an h5ad file. + + Parameters + ---------- + fname : str, optional + Output file path. If empty, uses path from adata.uns['path_to_file']. + save_knn : bool, optional + If True, include X_knn_avg layer. Default is False (layer can be large). + **kwargs + Additional arguments passed to AnnData.write_h5ad(). 
+ """ + Xknn = None + if not save_knn and "X_knn_avg" in self.adata.layers: + Xknn = self.adata.layers["X_knn_avg"] + del self.adata.layers["X_knn_avg"] + + if fname == "": + if "path_to_file" not in self.adata.uns: + raise KeyError("Path to file not known.") + fname = self.adata.uns["path_to_file"] + + x = self.adata + x.raw = self.adata_raw + + # Fix weird issues when index name is an integer + for y in [ + x.obs.columns, + x.var.columns, + x.obs.index, + x.var.index, + x.raw.var.index, + x.raw.var.columns, + ]: + y.name = str(y.name) if y.name is not None else None + + x.write_h5ad(fname, **kwargs) + del x.raw + + if Xknn is not None: + self.adata.layers["X_knn_avg"] = Xknn + + def dispersion_ranking_NN( + self, + nnm: sp.spmatrix | None = None, + num_norm_avg: int = 50, + weight_mode: Literal["dispersion", "variance", "rms", "combined"] = "combined", + save_avgs: bool = False, + adata: AnnData | None = None, + ) -> NDArray[np.float64]: + """Compute spatial dispersion factors for each gene. + + Parameters + ---------- + nnm : scipy.sparse.spmatrix | None, optional + Cell-to-cell nearest-neighbor matrix. If None, uses + adata.obsp['connectivities']. + num_norm_avg : int, optional + Number of top dispersions to average for normalization. Default is 50. + weight_mode : str, optional + Weight calculation method. One of 'dispersion', 'variance', 'rms', + 'combined'. Default is 'combined'. + save_avgs : bool, optional + If True, save kNN-averaged values to layers['X_knn_avg']. Default is False. + adata : AnnData | None, optional + AnnData object to use. If None, uses self.adata. + + Returns + ------- + NDArray[np.float64] + Vector of gene weights. + """ + if adata is None: + adata = self.adata + + if nnm is None: + nnm = adata.obsp["connectivities"] + f = np.asarray(nnm.sum(1)) + f[f == 0] = 1 + + bk = self._bk + + # --- SpMM: D_avg = (nnm / row_sums) @ X_disp ---------------------- + # This is the dominant cost of the SAM iteration. 
On GPU we upload + # both sparse operands once, do a cuSPARSE SpGEMM, compute column + # mean/var on the device, and pull back only the (n_genes,) vectors. + # The rest of the dispersion arithmetic is cheap numpy on host. + if bk.gpu: + # Row-normalise nnm before upload so we do one SpGEMM on device. + nnm_norm = nnm.multiply(1.0 / f) + nnm_g = bk.to_device(nnm_norm.tocsr()) + Xd_g = bk.to_device(adata.layers["X_disp"].tocsr()) + D_avg_g = nnm_g @ Xd_g # cuSPARSE sparse-sparse → sparse (n_cells, n_genes) + + xp = bk.xp + n = D_avg_g.shape[0] + # Mean: column sums / n. Var: E[x²] - E[x]² (population variance, + # matching sklearn.sparsefuncs.mean_variance_axis axis=0 ddof=0). + col_sum = xp.asarray(D_avg_g.sum(axis=0)).ravel() + mu = col_sum / n + # E[x²] via squaring the data buffer. We need D_avg_g.data² summed + # per column — reuse the matrix structure. + D_sq_g = D_avg_g.copy() + D_sq_g.data = D_sq_g.data**2 + ex2 = xp.asarray(D_sq_g.sum(axis=0)).ravel() / n + var = ex2 - mu**2 + + mu2 = None + if weight_mode in ("rms", "combined"): + # RMS = sqrt(E[x²]) — we already have ex2. 
+ mu_rms = xp.sqrt(ex2) + if weight_mode == "rms": + mu = mu_rms + else: # combined + mu2 = mu_rms + + mu = bk.to_host(mu) + var = bk.to_host(var) + if mu2 is not None: + mu2 = bk.to_host(mu2) + + if save_avgs: + adata.layers["X_knn_avg"] = bk.to_host(D_avg_g) + del nnm_g, Xd_g, D_avg_g, D_sq_g + bk.free_pool() + + else: + # --- CPU path (original) -------------------------------------- + D_avg = (nnm.multiply(1 / f)).dot(adata.layers["X_disp"]) + + if save_avgs: + adata.layers["X_knn_avg"] = D_avg.copy() + + if sp.issparse(D_avg): + mu, var = sf.mean_variance_axis(D_avg, axis=0) + if weight_mode == "rms": + D_avg.data[:] = D_avg.data**2 + mu, _ = sf.mean_variance_axis(D_avg, axis=0) + mu = mu**0.5 + + if weight_mode == "combined": + D_avg.data[:] = D_avg.data**2 + mu2, _ = sf.mean_variance_axis(D_avg, axis=0) + mu2 = mu2**0.5 + else: + mu = D_avg.mean(0) + var = D_avg.var(0) + if weight_mode == "rms": + mu = (D_avg**2).mean(0) ** 0.5 + if weight_mode == "combined": + mu2 = (D_avg**2).mean(0) ** 0.5 + + if not save_avgs: + del D_avg + gc.collect() + + if weight_mode in ("dispersion", "rms", "combined"): + dispersions = np.zeros(var.size) + dispersions[mu > 0] = var[mu > 0] / mu[mu > 0] + adata.var["spatial_dispersions"] = dispersions.copy() + + if weight_mode == "combined": + dispersions2 = np.zeros(var.size) + dispersions2[mu2 > 0] = var[mu2 > 0] / mu2[mu2 > 0] + + elif weight_mode == "variance": + dispersions = var + adata.var["spatial_variances"] = dispersions.copy() + else: + raise InvalidParameterError( + "weight_mode", + weight_mode, + valid_values=["dispersion", "variance", "rms", "combined"], + ) + + ma = np.sort(dispersions)[-num_norm_avg:].mean() + dispersions[dispersions >= ma] = ma + + weights = ((dispersions / dispersions.max()) ** 0.5).flatten() + + if weight_mode == "combined": + ma = np.sort(dispersions2)[-num_norm_avg:].mean() + dispersions2[dispersions2 >= ma] = ma + + weights2 = ((dispersions2 / dispersions2.max()) ** 0.5).flatten() + weights = 
np.vstack((weights, weights2)).max(0) + + return weights + + def run( + self, + max_iter: int = 10, + verbose: bool = True, + projection: str | None = None, + stopping_condition: float = 1e-2, + num_norm_avg: int = 50, + k: int = 20, + distance: Literal["correlation", "euclidean", "cosine"] = "cosine", + preprocessing: Literal["StandardScaler", "Normalizer"] | None = "StandardScaler", + npcs: int = 150, + n_genes: int | None = 3000, + weight_PCs: bool = False, + sparse_pca: bool = False, + proj_kwargs: dict[str, Any] | None = None, + seed: int = 0, + weight_mode: Literal["dispersion", "variance", "rms", "combined"] = "rms", + components: NDArray[np.floating[Any]] | None = None, + batch_key: str | None = None, + ) -> None: + """Run the Self-Assembling Manifold algorithm. + + Parameters + ---------- + max_iter : int, optional + Maximum number of iterations. Default is 10. + verbose : bool, optional + If True, print progress. Default is True. + projection : str | None, optional + Projection method. In this vendored version, projections are not + computed; pass None (default). A non-None value logs a warning. + stopping_condition : float, optional + RMSE threshold for convergence. Default is 1e-2. + num_norm_avg : int, optional + Top dispersions to average for normalization. Default is 50. + k : int, optional + Number of nearest neighbors. Default is 20. + distance : str, optional + Distance metric: 'correlation', 'euclidean', 'cosine'. Default is 'cosine'. + preprocessing : str | None, optional + Preprocessing method: 'StandardScaler', 'Normalizer', None. + Default is 'StandardScaler'. + npcs : int, optional + Number of principal components. Default is 150. + n_genes : int | None, optional + Number of genes to use. Default is 3000. If None, uses all genes. + weight_PCs : bool, optional + Weight PCs by eigenvalues. Default is False. + sparse_pca : bool, optional + Use sparse PCA implementation. Default is False. 
+ proj_kwargs : dict | None, optional + Unused in vendored version. Kept for signature compatibility. + seed : int, optional + Random seed. Default is 0. + weight_mode : str, optional + Weight calculation mode. Default is 'rms'. + components : NDArray | None, optional + Pre-computed PCA components. Default is None. + batch_key : str | None, optional + Key in obs for batch correction with Harmony. Default is None. + """ + if proj_kwargs is None: + proj_kwargs = {} + + D = self.adata.X + if k < 5: + k = 5 + if k > D.shape[0] - 1: + k = D.shape[0] - 2 + + if preprocessing not in ("StandardScaler", "Normalizer", None, "None"): + raise InvalidParameterError( + "preprocessing", + preprocessing, + valid_values=["StandardScaler", "Normalizer", None], + ) + if weight_mode not in ("dispersion", "variance", "rms", "combined"): + raise InvalidParameterError( + "weight_mode", + weight_mode, + valid_values=["dispersion", "variance", "rms", "combined"], + ) + + if self.adata.layers["X_disp"].min() < 0 and weight_mode == "dispersion": + logger.warning( + "`X_disp` layer contains negative values. Setting `weight_mode` to 'rms'." + ) + weight_mode = "rms" + + numcells = D.shape[0] + + if n_genes is None: + n_genes = self.adata.shape[1] + if not sparse_pca and numcells > 10000: + warnings.warn( + "All genes are being used. It is recommended " + "to set `sparse_pca=True` to satisfy memory " + "constraints for datasets with more than " + "10,000 cells. 
Setting `sparse_pca` to True.", + stacklevel=2, + ) + sparse_pca = True + + if not sparse_pca: + n_genes = min(n_genes, (D.sum(0) > 0).sum()) + + self.run_args = { + "max_iter": max_iter, + "verbose": verbose, + "projection": projection, + "stopping_condition": stopping_condition, + "num_norm_avg": num_norm_avg, + "k": k, + "distance": distance, + "preprocessing": preprocessing, + "npcs": npcs, + "n_genes": n_genes, + "weight_PCs": weight_PCs, + "proj_kwargs": proj_kwargs, + "sparse_pca": sparse_pca, + "weight_mode": weight_mode, + "seed": seed, + "components": components, + } + self.adata.uns["run_args"] = self.run_args + + tinit = time.time() + np.random.seed(seed) + + if verbose: + logger.info("Running SAM algorithm") + + W = np.ones(D.shape[1]) + self.adata.var["weights"] = W + + old = np.zeros(W.size) + new = W + + i = 0 + err = ((new - old) ** 2).mean() ** 0.5 + + if max_iter < 5: + max_iter = 5 + + nnas = num_norm_avg + + while i < max_iter and err > stopping_condition: + conv = err + if verbose: + logger.info("Iteration: %d, Convergence: %.6f", i, conv) + + i += 1 + old = new + first = i == 1 + + W = self.calculate_nnm( + batch_key=batch_key, + n_genes=n_genes, + preprocessing=preprocessing, + npcs=npcs, + num_norm_avg=nnas, + weight_PCs=weight_PCs, + sparse_pca=sparse_pca, + weight_mode=weight_mode, + seed=seed, + components=components, + first=first, + ) + gc.collect() + new = W + err = ((new - old) ** 2).mean() ** 0.5 + self.adata.var["weights"] = W + + all_gene_names = np.array(list(self.adata.var_names)) + indices = np.argsort(-W) + ranked_genes = all_gene_names[indices] + + self.adata.uns["ranked_genes"] = ranked_genes + + # Projections (umap/tsne/diff_umap) stripped in vendored version. + # SAMap computes its own projections on the combined manifold. + if projection is not None: + logger.warning( + "projection=%r requested but projection methods are not included " + "in the vendored SAM. 
Compute projections separately if needed.", + projection, + ) + + elapsed = time.time() - tinit + if verbose: + logger.info("Elapsed time: %.2f seconds", elapsed) + + def calculate_nnm( + self, + adata: AnnData | None = None, + batch_key: str | None = None, + g_weighted: NDArray[np.floating[Any]] | None = None, + n_genes: int = 3000, + preprocessing: str | None = "StandardScaler", + npcs: int = 150, + num_norm_avg: int = 50, + weight_PCs: bool = False, + sparse_pca: bool = False, + update_manifold: bool = True, + weight_mode: str = "dispersion", + seed: int = 0, + components: NDArray[np.floating[Any]] | None = None, + first: bool = False, + ) -> NDArray[np.float64] | tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]: + """Calculate nearest neighbor matrix and update weights. + + This is the core iteration step of the SAM algorithm. + + Parameters + ---------- + adata : AnnData | None + AnnData object to use. + batch_key : str | None + Key for batch correction. + g_weighted : NDArray | None + Pre-computed weighted coordinates. + n_genes : int + Number of genes to use. + preprocessing : str | None + Preprocessing method. + npcs : int + Number of PCs. + num_norm_avg : int + Normalization averaging. + weight_PCs : bool + Weight by eigenvalues. + sparse_pca : bool + Use sparse PCA. + update_manifold : bool + Update manifold structure. + weight_mode : str + Weight calculation mode. + seed : int + Random seed. + components : NDArray | None + Pre-computed components. + first : bool + Is this the first iteration. + + Returns + ------- + NDArray | tuple + Gene weights, or (PCs, weighted_coords) if not updating manifold. 
+ """ + if adata is None: + adata = self.adata + + numcells = adata.shape[0] + k = adata.uns["run_args"].get("k", 20) + distance = adata.uns["run_args"].get("distance", "correlation") + + D = adata.X + W = adata.var["weights"].values + + if "means" not in adata.var or "variances" not in adata.var: + self.calculate_mean_var(adata) + + if n_genes is None: + gkeep = np.arange(W.size) + else: + if first: + mu = np.array(list(adata.var["means"])) + var = np.array(list(adata.var["variances"])) + mu[mu == 0] = 1 + dispersions = var / mu + gkeep = np.sort(np.argsort(-dispersions)[:n_genes]) + else: + gkeep = np.sort(np.argsort(-W)[:n_genes]) + + if g_weighted is None: + if preprocessing == "Normalizer": + Ds = D[:, gkeep] + if sp.issparse(Ds) and not sparse_pca: + Ds = Ds.toarray() + + Ds = Normalizer().fit_transform(Ds) + + elif preprocessing == "StandardScaler": + if not sparse_pca: + Ds = D[:, gkeep] + if sp.issparse(Ds): + Ds = Ds.toarray() + + v = adata.var["variances"].values[gkeep] + m = adata.var["means"].values[gkeep] + v[v == 0] = 1 + Ds = (Ds - m) / v**0.5 + + Ds[Ds > 10] = 10 + Ds[Ds < -10] = -10 + else: + Ds = D[:, gkeep] + v = adata.var["variances"].values[gkeep] + v[v == 0] = 1 + Ds = Ds.multiply(1 / v**0.5).tocsr() + + else: + Ds = D[:, gkeep].toarray() + + D_sub = Ds.multiply(W[gkeep]).tocsr() if sp.issparse(Ds) else Ds * W[gkeep] + + if components is None: + if not sparse_pca: + npcs = min(npcs, min((D.shape[0], gkeep.size))) + if numcells > 500: + g_weighted, pca = weighted_PCA( + D_sub, + npcs=npcs, + do_weight=weight_PCs, + solver="auto", + seed=seed, + ) + else: + g_weighted, pca = weighted_PCA( + D_sub, + npcs=npcs, + do_weight=weight_PCs, + solver="full", + seed=seed, + ) + components = pca.components_ + + else: + npcs = min(npcs, min((D.shape[0], gkeep.size)) - 1) + v = adata.var["variances"].values[gkeep] + v[v == 0] = 1 + m = adata.var["means"].values[gkeep] * W[gkeep] + if preprocessing == "StandardScaler": + no = m / v**0.5 + else: + no = 
np.asarray(D_sub.mean(0)).flatten() + mean_correction = no + output = _pca_with_sparse(D_sub, npcs, mu=(no)[None, :], seed=seed, bk=self._bk) + components = output["components"] + g_weighted = output["X_pca"] + + if weight_PCs: + ev = output["variance"] + ev = ev / ev.max() + g_weighted = g_weighted * (ev**0.5) + else: + components = components[:, gkeep] + v = adata.var["variances"].values[gkeep] + v[v == 0] = 1 + m = adata.var["means"].values[gkeep] * W[gkeep] + if preprocessing == "StandardScaler": + ns = m / v**0.5 + else: + ns = np.asarray(D_sub.mean(0)).flatten() + mean_correction = ns + + if sp.issparse(D_sub): + g_weighted = D_sub.dot(components.T) - ns.flatten().dot(components.T) + else: + g_weighted = (D_sub - ns).dot(components.T) + if weight_PCs: + ev = g_weighted.var(0) + ev = ev / ev.max() + g_weighted = g_weighted * (ev**0.5) + + adata.varm["PCs"] = np.zeros(shape=(adata.n_vars, npcs)) + adata.varm["PCs"][gkeep] = components.T + # NOTE: original SAM stored D_sub in obsm["X_processed"] here — + # an (n_cells x n_genes) matrix written every iteration and never + # read back. Dropped during vendoring. + adata.uns["dimred_indices"] = gkeep + if sparse_pca: + mc = np.zeros(adata.shape[1]) + mc[gkeep] = mean_correction + adata.var["mean_correction"] = mc + + if batch_key is not None: + try: + import harmonypy + + harmony_out = harmonypy.run_harmony(g_weighted, adata.obs, batch_key, verbose=False) + g_weighted = harmony_out.Z_corr.T + except ImportError as err: + raise ImportError( + "harmonypy is required for batch correction. " + "Install it with: pip install harmonypy" + ) from err + + if update_manifold: + edm = calc_nnm(g_weighted, k, distance, bk=self._bk) + + # Distances matrix: zero out self-distances on the diagonal. + edm_dist = edm.copy() + _csr_setdiag(edm_dist, 0) + adata.obsp["distances"] = edm_dist + + # Connectivities: binary adjacency with self-loops. 
+ EDM = edm.copy() + EDM.data[:] = 1 + _csr_setdiag(EDM, 1) + adata.obsp["connectivities"] = EDM + + # nnm: similarity-weighted adjacency for correlation/cosine. + if distance in ("correlation", "cosine"): + edm.data[:] = 1 - edm.data + _csr_setdiag(edm, 1) + edm.data[edm.data < 0] = 0.001 + adata.obsp["nnm"] = edm + else: + adata.obsp["nnm"] = EDM + + W = self.dispersion_ranking_NN( + EDM, weight_mode=weight_mode, num_norm_avg=num_norm_avg, adata=adata + ) + adata.obsm["X_pca"] = g_weighted + return W + else: + logger.info("Not updating the manifold...") + PCs = np.zeros(shape=(adata.n_vars, npcs)) + PCs[gkeep] = components.T + return PCs, g_weighted + + def leiden_clustering( + self, + X: sp.spmatrix | None = None, + res: float = 1, + method: Literal["modularity", "significance"] = "modularity", + seed: int = 0, + ) -> NDArray[np.int64] | None: + """Perform Leiden clustering. + + On a CUDA backend with rapids-singlecell installed, dispatches to + the cugraph-backed GPU implementation for the common case + (``X=None``, ``method='modularity'``). Otherwise uses CPU + leidenalg/igraph — which also handles custom adjacency matrices + and the significance-based partition. + + Parameters + ---------- + X : sparse matrix | None, optional + Adjacency matrix. If None, uses connectivities. + res : float, optional + Resolution parameter. Default is 1. + method : str, optional + Optimization method. Default is 'modularity'. + seed : int, optional + Random seed. Default is 0. + + Returns + ------- + NDArray | None + Cluster labels if X provided, None otherwise. + """ + # --- GPU fast path ---------------------------------------------------- + # rsc.tl.leiden only handles the modularity partition on an AnnData + # with pre-computed neighbors. That covers SAM's default invocation + # (X=None, method='modularity'). Custom-X and significance paths fall + # through to CPU leidenalg below, which is the only implementation + # that supports them. 
+ if X is None and method == "modularity" and self._bk.gpu: + from samap import _rsc_compat + + if _rsc_compat.HAS_RSC: + _rsc_compat.leiden( + self.adata, + self._bk, + resolution=res, + key_added="leiden_clusters", + random_state=seed, + ) + return None + + if X is None: + X = self.adata.obsp["connectivities"] + save = True + else: + if not sp.isspmatrix_csr(X): + X = sp.csr_matrix(X) + save = False + + import igraph as ig + import leidenalg + + adjacency = X + sources, targets = adjacency.nonzero() + weights = adjacency[sources, targets] + if isinstance(weights, np.matrix): + weights = np.asarray(weights).flatten() + g = ig.Graph(directed=True) + g.add_vertices(adjacency.shape[0]) + g.add_edges(list(zip(sources, targets, strict=False))) + with contextlib.suppress(ValueError, TypeError): + g.es["weight"] = weights + + if method == "significance": + cl = leidenalg.find_partition(g, leidenalg.SignificanceVertexPartition, seed=seed) + else: + cl = leidenalg.find_partition( + g, leidenalg.RBConfigurationVertexPartition, resolution_parameter=res, seed=seed + ) + + if save: + if method == "modularity": + self.adata.obs["leiden_clusters"] = pd.Categorical(np.array(cl.membership)) + elif method == "significance": + self.adata.obs["leiden_sig_clusters"] = pd.Categorical(np.array(cl.membership)) + return None + return np.array(cl.membership) + + def scatter( + self, + projection: str | NDArray[np.floating[Any]] | None = None, + c: str | NDArray[Any] | None = None, + colorspec: str | NDArray[Any] | None = None, + cmap: str = "rainbow", + linewidth: float = 0.0, + edgecolor: str = "k", + axes: Any | None = None, + colorbar: bool = True, + s: float = 10, + **kwargs: Any, + ) -> Any: + """Display a scatter plot. + + Parameters + ---------- + projection : str | NDArray | None, optional + Key in adata.obsm or 2D coordinates array. Default is UMAP. + c : str | NDArray | None, optional + Color data - key in adata.obs or array. 
+ colorspec : str | NDArray | None, optional + Direct color specification. + cmap : str, optional + Colormap name. Default is 'rainbow'. + linewidth : float, optional + Marker edge width. Default is 0.0. + edgecolor : str, optional + Marker edge color. Default is 'k'. + axes : matplotlib.axes.Axes | None, optional + Existing axes to plot on. + colorbar : bool, optional + Whether to show colorbar. Default is True. + s : float, optional + Marker size. Default is 10. + **kwargs + Additional arguments passed to matplotlib.pyplot.scatter. + + Returns + ------- + matplotlib.axes.Axes + The axes object. + """ + try: + import matplotlib.pyplot as plt + except ImportError: + logger.error("matplotlib not installed!") + return None + + if isinstance(projection, str): + if projection not in self.adata.obsm: + logger.error("Projection %r not found in adata.obsm", projection) + return None + dt = self.adata.obsm[projection] + + elif projection is None: + if "X_umap" in self.adata.obsm: + dt = self.adata.obsm["X_umap"] + elif "X_tsne" in self.adata.obsm: + dt = self.adata.obsm["X_tsne"] + else: + logger.error("No projection found. 
Pass one via `projection=`.")
+                return None
+        else:
+            dt = projection
+
+        if axes is None:
+            plt.figure()
+            axes = plt.gca()
+
+        if colorspec is not None:
+            axes.scatter(
+                dt[:, 0],
+                dt[:, 1],
+                s=s,
+                linewidth=linewidth,
+                edgecolor=edgecolor,
+                c=colorspec,
+                **kwargs,
+            )
+        elif c is None:
+            axes.scatter(
+                dt[:, 0],
+                dt[:, 1],
+                s=s,
+                linewidth=linewidth,
+                edgecolor=edgecolor,
+                **kwargs,
+            )
+        else:
+            if isinstance(c, str):
+                with contextlib.suppress(KeyError):
+                    c = self.get_labels(c)
+
+            if (isinstance(c[0], (str, np.str_))) and (isinstance(c, (np.ndarray, list))):
+                i = convert_annotations(c)
+                ui, ai = np.unique(i, return_index=True)
+                cax = axes.scatter(
+                    dt[:, 0],
+                    dt[:, 1],
+                    c=i,
+                    cmap=cmap,
+                    s=s,
+                    linewidth=linewidth,
+                    edgecolor=edgecolor,
+                    **kwargs,
+                )
+
+                if colorbar:
+                    cbar = plt.colorbar(cax, ax=axes, ticks=ui)
+                    cbar.ax.set_yticklabels(c[ai])
+            else:
+                if not isinstance(c, (np.ndarray, list)):
+                    colorbar = False
+                i = c  # bind unconditionally: numeric ndarray/list `c` previously left `i` unset -> NameError below
+
+                scatter_kwargs: dict[str, Any] = {
+                    "c": i,
+                    "s": s,
+                    "linewidth": linewidth,
+                    "edgecolor": edgecolor,
+                    **kwargs,
+                }
+                if isinstance(i, np.ndarray) and np.issubdtype(i.dtype, np.number):
+                    scatter_kwargs["cmap"] = cmap
+
+                cax = axes.scatter(dt[:, 0], dt[:, 1], **scatter_kwargs)
+
+                if colorbar:
+                    plt.colorbar(cax, ax=axes)
+        return axes
diff --git a/src/samap/sam/knn.py b/src/samap/sam/knn.py
new file mode 100644
index 0000000..58e4e76
--- /dev/null
+++ b/src/samap/sam/knn.py
@@ -0,0 +1,140 @@
+"""k-NN graph construction for the vendored SAM algorithm.
+
+Vendored from samalg.utilities with `gen_sparse_knn` rewritten for
+direct CSR construction (no lil_matrix scatter). 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +from scipy import sparse + +if TYPE_CHECKING: + from numpy.typing import NDArray + + from samap.core._backend import Backend + + +def gen_sparse_knn( + knni: NDArray[np.int64], + knnd: NDArray[np.floating[Any]], + shape: tuple[int, int] | None = None, +) -> sparse.csr_matrix: + """Generate sparse k-NN matrix from indices and distances. + + Direct CSR construction via COO. Replaces the original lil_matrix + + fancy-index scatter, which was O(n*k) Python-loop overhead inside + scipy's lil assignment path. + + Parameters + ---------- + knni : NDArray + k-NN indices (n x k). + knnd : NDArray + k-NN distances (n x k). + shape : tuple | None, optional + Output shape. If None, uses (n, n). + + Returns + ------- + sparse.csr_matrix + Sparse k-NN matrix. + """ + n, k = knni.shape + if shape is None: + shape = (n, n) + rows = np.repeat(np.arange(n, dtype=np.int32), k) + cols = knni.ravel().astype(np.int32, copy=False) + data = knnd.ravel() + # COO -> CSR handles duplicate (row, col) pairs by summing, and sorts + # column indices within each row automatically. + return sparse.csr_matrix((data, (rows, cols)), shape=shape) + + +def nearest_neighbors_hnsw( + x: NDArray[np.floating[Any]], + ef: int = 200, + M: int = 48, + n_neighbors: int = 100, +) -> tuple[NDArray[np.int64], NDArray[np.floating[Any]]]: + """Compute approximate nearest neighbors using HNSW algorithm. + + Parameters + ---------- + x : NDArray + Input data matrix. + ef : int, optional + HNSW ef parameter (search quality). Default is 200. + M : int, optional + HNSW M parameter (graph connectivity). Default is 48. + n_neighbors : int, optional + Number of neighbors. Default is 100. + + Returns + ------- + tuple + (indices, distances) arrays of shape (n, k). 
+ """ + import hnswlib + + labels = np.arange(x.shape[0]) + p = hnswlib.Index(space="cosine", dim=x.shape[1]) + p.init_index(max_elements=x.shape[0], ef_construction=ef, M=M) + p.add_items(x, labels) + p.set_ef(ef) + idx, dist = p.knn_query(x, k=n_neighbors) + return idx, dist + + +def _nearest_neighbors_umap( + X: NDArray[np.floating[Any]], + n_neighbors: int = 15, + metric: str = "correlation", + random_state: int = 0, +) -> tuple[NDArray[np.int64], NDArray[np.floating[Any]]]: + """Fallback k-NN via UMAP's nearest_neighbors (pynndescent).""" + from umap.umap_ import nearest_neighbors + + rs = np.random.RandomState(random_state) + return nearest_neighbors(X, n_neighbors, metric, {}, True, rs)[:2] + + +def calc_nnm( + g_weighted: NDArray[np.floating[Any]], + k: int, + distance: str | None = None, + bk: Backend | None = None, +) -> sparse.csr_matrix: + """Calculate k-nearest neighbor matrix. + + Parameters + ---------- + g_weighted : NDArray + Input coordinates (typically PCA-reduced). + k : int + Number of neighbors. + distance : str | None, optional + Distance metric. If 'cosine', dispatches to + :func:`samap.core.knn.approximate_knn` (FAISS-GPU on a CUDA + backend, hnswlib otherwise). For other metrics falls back to + UMAP's pynndescent — CPU only. + bk : Backend, optional + GPU/CPU dispatch for the cosine path. Ignored for non-cosine + metrics (FAISS-GPU is cosine-only). + + Returns + ------- + sparse.csr_matrix + Sparse k-NN matrix with distances as values. + """ + if distance == "cosine": + # approximate_knn dispatches FAISS-GPU ↔ hnswlib. SAM's kNN is + # symmetric (self-query) so queries == database. 
+ from samap.core.knn import approximate_knn + + nnm, dists = approximate_knn(g_weighted, g_weighted, k, metric="cosine", bk=bk) + else: + nnm, dists = _nearest_neighbors_umap(g_weighted, n_neighbors=k, metric=distance) + return gen_sparse_knn(nnm, dists) diff --git a/src/samap/sam/pca.py b/src/samap/sam/pca.py new file mode 100644 index 0000000..7881f7c --- /dev/null +++ b/src/samap/sam/pca.py @@ -0,0 +1,369 @@ +"""PCA implementations for the vendored SAM algorithm. + +Vendored from samalg.utilities. Provides: +- `_pca_with_sparse`: Implicit-centering sparse PCA via LinearOperator + svds. + Avoids densifying a sparse matrix for mean-subtraction. Now dispatches to + either ARPACK (serial Lanczos, original) or randomized SVD (block SpMM, + GPU-friendly) via the `svd_solver` parameter. +- `randomized_svd_implicit_center`: fbpca-style randomized SVD with implicit + mean-centering. O(k) SpMM passes instead of hundreds of serial matvecs. +- `weighted_PCA`: Dense PCA with optional eigenvalue weighting. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal + +import numpy as np +import scipy.sparse as sp +import scipy.sparse.linalg as spla +import sklearn.utils.sparsefuncs as sf +from sklearn.decomposition import PCA +from sklearn.utils import check_array, check_random_state +from sklearn.utils.extmath import svd_flip + +if TYPE_CHECKING: + from numpy.typing import NDArray + + from samap.core._backend import Backend + + +# --------------------------------------------------------------------------- +# Randomized SVD with implicit centering +# --------------------------------------------------------------------------- + + +def randomized_svd_implicit_center( + X: Any, + k: int, + mu: NDArray[np.floating[Any]] | None = None, + n_oversamples: int = 10, + n_power: int = 4, + seed: int = 0, + bk: Backend | None = None, +) -> dict[str, NDArray[np.floating[Any]]]: + r"""Top-k SVD of ``(X - 1·μ)`` via randomized range-finding, never densifying. + + Every occurrence of ``(X - 1·μ) @ M`` is computed as + ``X @ M - 1 · (μ @ M)`` — one SpMM plus a rank-1 broadcast correction. + Symmetrically, ``(X - 1·μ).T @ N = X.T @ N - μ.T · (1.T @ N)``. + This replaces ARPACK's hundreds of serial matvecs with O(n_power) block + SpMM calls, which is dramatically better for GPU utilisation + (cuSPARSE SpMM >> repeated SpMV). + + Algorithm (Halko–Martinsson–Tropp, adapted for implicit centering):: + + Ω = random(m, k + p) # Gaussian sketch + Y = (X - 1μ) @ Ω # range sketch + for q power iterations: + Y = qr(Y).Q # stabilise + Z = (X - 1μ).T @ Y # one SpMM + correction + Z = qr(Z).Q # stabilise + Y = (X - 1μ) @ Z # one SpMM + correction + Q = qr(Y).Q # (n, k+p) orthonormal + B = Q.T @ (X - 1μ) # small: (k+p, m) + U_b, s, Vt = svd(B) # dense, fits in memory + U = Q @ U_b + + Parameters + ---------- + X : sparse matrix + Input (n_cells × n_genes), CSR or CSC. On GPU, a cupyx sparse matrix. + k : int + Number of singular components to return. 
+ mu : array-like, optional + Column means (1 × n_genes or flat). If None, computed from X. + n_oversamples : int, default 10 + Extra sketch dimensions beyond k. More improves accuracy at the cost + of wider SpMM blocks. + n_power : int, default 4 + Power iterations. Each adds two SpMM passes. 4 is conservative for + SAM's conditioning (single-cell data has a long singular-value tail). + seed : int, default 0 + Seeds the random sketch matrix for reproducibility. + bk : Backend, optional + CPU/GPU dispatch. If None, a CPU backend is created. + + Returns + ------- + dict with keys: + - ``X_pca``: (n, k) projected scores, i.e. ``(X - 1μ) @ V.T`` == ``U @ diag(s)`` + - ``components``: (k, m) right singular vectors ``Vt`` + - ``variance``: (k,) explained variances ``s**2 / (n-1)`` + - ``variance_ratio``: variances normalised by total column variance + """ + if bk is None: + from samap.core._backend import Backend as _Backend + + bk = _Backend("cpu") + + xp = bk.xp + n, m = X.shape + + # --- Mean vector as a flat (m,) array on the active backend --- + if mu is None: + mu_flat = xp.asarray(X.mean(axis=0)).ravel() + else: + mu_flat = bk.to_device(np.asarray(mu).ravel()) + mu_row = mu_flat[None, :] # (1, m) for broadcasting + mu_col = mu_flat[:, None] # (m, 1) + + # --- Implicit-centering SpMM helpers ------------------------------------- + # (X - 1μ) @ M == X @ M - 1 · (μ @ M) + # The second term is a (1, r) row vector, broadcast-subtracted from + # every row of the (n, r) SpMM result. No (n, m) dense materialised. + def A_matmul(M: Any) -> Any: + M = bk.asfortran_if_gpu(M) # cuSPARSE SpMM prefers F-order RHS + return X @ M - mu_row @ M # (n, r) - (1, r) broadcast + + # (X - 1μ).T @ N == X.T @ N - μ.T · (1.T @ N) + # (1.T @ N) is just the column sums of N → (1, r). Correction is a + # rank-1 outer product μ.T @ col_sums. 
+ def At_matmul(N: Any) -> Any: + N = bk.asfortran_if_gpu(N) + col_sums = N.sum(axis=0, keepdims=True) # (1, r) + return X.T @ N - mu_col @ col_sums # (m, r) - (m, 1)@(1, r) + + # --- Range finder -------------------------------------------------------- + k_os = k + n_oversamples + # numpy Generator path works on both backends; cupy mirrors the API. + rng = xp.random.default_rng(seed) + Omega = rng.standard_normal((m, k_os)).astype(X.dtype, copy=False) + + Y = A_matmul(Omega) # (n, k_os) + + # Power iterations with QR stabilisation on both sides. Without the + # intermediate re-orthonormalisation the smaller singular directions wash + # out in floating point after ~2 iterations. + for _ in range(n_power): + Y, _ = xp.linalg.qr(Y) + Z = At_matmul(Y) # (m, k_os) + Z, _ = xp.linalg.qr(Z) + Y = A_matmul(Z) # (n, k_os) + + Q, _ = xp.linalg.qr(Y) # (n, k_os), orthonormal columns + + # --- Project and small SVD ---------------------------------------------- + # B = Q.T @ (X - 1μ). This is the *transpose* of (X - 1μ).T @ Q, which + # our At_matmul helper already computes as an SpMM. + B = At_matmul(Q).T # (k_os, m), dense — small by construction + + U_b, s, Vt = xp.linalg.svd(B, full_matrices=False) + U = Q @ U_b # (n, k_os) + + # Truncate to k, sign-fix for determinism. svd_flip is numpy-side, so + # pull back to host if on GPU. SVD output is already sorted descending. + U_h = bk.to_host(U[:, :k]) + s_h = bk.to_host(s[:k]) + Vt_h = bk.to_host(Vt[:k, :]) + U_h, Vt_h = svd_flip(U_h, Vt_h) + + X_pca = U_h * s_h # (n, k) — scores == U @ diag(s) + ev = s_h**2 / (n - 1) + + # Total variance of the *uncentered* sparse X — matches `_pca_with_sparse` + # which uses sf.mean_variance_axis on X (not on X - μ). This is the + # correct denominator: sum of column variances. 
+ X_host = bk.to_host(X) + if sp.issparse(X_host): + total_var = sf.mean_variance_axis(X_host, axis=0)[1].sum() + else: + total_var = X_host.var(axis=0, ddof=0).sum() + ev_ratio = ev / total_var + + return { + "X_pca": X_pca, + "variance": ev, + "variance_ratio": ev_ratio, + "components": Vt_h, + } + + +# --------------------------------------------------------------------------- +# ARPACK path (original) + dispatch +# --------------------------------------------------------------------------- + + +def _pca_with_sparse( + X: sp.spmatrix, + npcs: int, + solver: str = "arpack", + mu: NDArray[np.floating[Any]] | None = None, + seed: int = 0, + mu_axis: int = 0, + svd_solver: Literal["arpack", "randomized"] = "arpack", + bk: Backend | None = None, +) -> dict[str, NDArray[np.floating[Any]]]: + """Perform PCA on sparse matrices using iterative SVD with implicit centering. + + Uses a LinearOperator to represent (X - mu) without ever materializing the + dense centered matrix. The matvec/rmatvec closures compute Xv - mu·v on the + fly, keeping memory at O(nnz) instead of O(n*m). + + Parameters + ---------- + X : sparse.spmatrix + Input sparse matrix. + npcs : int + Number of principal components. + solver : str, optional + SVD solver passed to scipy's ``svds`` (ARPACK path only). Default is 'arpack'. + mu : NDArray | None, optional + Pre-computed mean. If None, computed from X. + seed : int, optional + Random seed. Default is 0. + mu_axis : int, optional + Axis along which mean was computed. Default is 0. + svd_solver : {"arpack", "randomized"}, optional + Algorithm dispatch. ``"arpack"`` (default) uses the original + LinearOperator + iterative Lanczos path — hundreds of serial matvecs, + tight convergence, but poor GPU utilisation. ``"randomized"`` uses + block SpMM (see :func:`randomized_svd_implicit_center`) — O(k) passes, + each a wide sparse-dense matmul, typically 5-10× faster and far + better suited to GPU. Randomized requires ``mu_axis=0``. 
+    bk : Backend, optional
+        Backend for the randomized path. Ignored for ARPACK.
+
+    Returns
+    -------
+    dict
+        Dictionary with keys 'X_pca', 'variance', 'variance_ratio', 'components'.
+    """
+    # --- Randomized solver dispatch ------------------------------------------
+    # The randomized path only supports column centering (mu_axis=0) — the
+    # only case SAM actually uses; any other mu_axis raises ValueError below.
+    if svd_solver == "randomized":
+        if mu_axis != 0:
+            raise ValueError("randomized svd_solver only supports mu_axis=0 (column centering)")
+        return randomized_svd_implicit_center(X, npcs, mu=mu, seed=seed, bk=bk)
+
+    # --- ARPACK path (original) ----------------------------------------------
+    random_state = check_random_state(seed)
+    np.random.set_state(random_state.get_state())
+    random_init = np.random.rand(np.min(X.shape))
+    X = check_array(X, accept_sparse=["csr", "csc"])
+
+    if mu is None:
+        if mu_axis == 0:
+            mu = np.asarray(X.mean(0)).flatten()[None, :]
+        else:
+            mu = np.asarray(X.mean(1)).flatten()[:, None]
+
+    if mu_axis == 0:
+        mdot = mu.dot
+        mmat = mdot
+        mhdot = mu.T.dot
+        mhmat = mu.T.dot
+        Xdot = X.dot
+        Xmat = Xdot
+        XHdot = X.T.conj().dot
+        XHmat = XHdot
+        ones = np.ones(X.shape[0])[None, :].dot
+
+        def matvec(x: NDArray[Any]) -> NDArray[Any]:
+            return Xdot(x) - mdot(x)
+
+        def matmat(x: NDArray[Any]) -> NDArray[Any]:
+            return Xmat(x) - mmat(x)
+
+        def rmatvec(x: NDArray[Any]) -> NDArray[Any]:
+            return XHdot(x) - mhdot(ones(x))
+
+        def rmatmat(x: NDArray[Any]) -> NDArray[Any]:
+            return XHmat(x) - mhmat(ones(x))
+
+    else:
+        mdot = mu.dot
+        mmat = mdot
+        mhdot = mu.T.dot
+        mhmat = mu.T.dot
+        Xdot = X.dot
+        Xmat = Xdot
+        XHdot = X.T.conj().dot
+        XHmat = XHdot
+        ones = np.ones(X.shape[1])[None, :].dot
+
+        def matvec(x: NDArray[Any]) -> NDArray[Any]:
+            return Xdot(x) - mdot(ones(x))
+
+        def matmat(x: NDArray[Any]) -> NDArray[Any]:
+            return Xmat(x) - mmat(ones(x))
+
+        def rmatvec(x: NDArray[Any]) -> NDArray[Any]:
+            return XHdot(x) - mhdot(x)
+
+        def rmatmat(x: NDArray[Any]) -> NDArray[Any]:
+            return XHmat(x) - mhmat(x)
+
+    XL = spla.LinearOperator(
+        matvec=matvec,
+        dtype=X.dtype,
+        matmat=matmat,
+        shape=X.shape,
+        rmatvec=rmatvec,
+        rmatmat=rmatmat,
+    )
+
+    u, s, v = spla.svds(XL, solver=solver, k=npcs, v0=random_init)
+    u, v = svd_flip(u, v)
+    idx = np.argsort(-s)
+    v = v[idx, :]
+
+    X_pca = (u * s)[:, idx]
+    ev = s[idx] ** 2 / (X.shape[0] - 1)
+
+    total_var = sf.mean_variance_axis(X, axis=0)[1].sum()
+    ev_ratio = ev / total_var
+
+    return {
+        "X_pca": X_pca,
+        "variance": ev,
+        "variance_ratio": ev_ratio,
+        "components": v,
+    }
+
+
+def weighted_PCA(
+    mat: NDArray[np.floating[Any]],
+    do_weight: bool = True,
+    npcs: int | None = None,
+    solver: str = "auto",
+    seed: int = 0,
+) -> tuple[NDArray[np.floating[Any]], PCA]:
+    """Perform PCA with optional eigenvalue weighting.
+
+    Parameters
+    ----------
+    mat : NDArray
+        Input data matrix.
+    do_weight : bool, optional
+        If True, weight PCs by eigenvalues. Default is True.
+    npcs : int | None, optional
+        Number of components. If None, uses min(mat.shape).
+    solver : str, optional
+        SVD solver. Default is 'auto'.
+    seed : int, optional
+        Random seed. Default is 0. 
+ + Returns + ------- + tuple + (reduced_weighted, pca_object) + """ + if do_weight: + ncom = min(mat.shape) if npcs is None else min((min(mat.shape), npcs)) + + pca = PCA(svd_solver=solver, n_components=ncom, random_state=check_random_state(seed)) + reduced = pca.fit_transform(mat) + scaled_eigenvalues = pca.explained_variance_ + scaled_eigenvalues = scaled_eigenvalues / scaled_eigenvalues.max() + reduced_weighted = reduced * scaled_eigenvalues[None, :] ** 0.5 + else: + pca = PCA(n_components=npcs, svd_solver=solver, random_state=check_random_state(seed)) + reduced = pca.fit_transform(mat) + if reduced.shape[1] == 1: + pca = PCA(n_components=2, svd_solver=solver, random_state=check_random_state(seed)) + reduced = pca.fit_transform(mat) + reduced_weighted = reduced + + return reduced_weighted, pca diff --git a/src/samap/sam/utils.py b/src/samap/sam/utils.py new file mode 100644 index 0000000..9026dc4 --- /dev/null +++ b/src/samap/sam/utils.py @@ -0,0 +1,83 @@ +"""Small helpers for the vendored SAM algorithm. + +Vendored from samalg.utilities. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np + +if TYPE_CHECKING: + from numpy.typing import NDArray + + +def convert_annotations(A: NDArray[Any]) -> NDArray[np.int64]: + """Convert categorical annotations to integer codes. + + Parameters + ---------- + A : NDArray + Array of categorical values. + + Returns + ------- + NDArray + Integer codes. + """ + x = np.unique(A) + y = np.zeros(A.size) + for z, i in enumerate(x): + y[i == A] = z + return y.astype("int") + + +def extract_annotation( + cn: NDArray[Any], + x: int | None, + c: str = "_", +) -> NDArray[Any] | list[NDArray[Any]]: + """Extract annotations from cell names by splitting on delimiter. + + Parameters + ---------- + cn : NDArray + Array of cell names. + x : int | None + Index of annotation field to extract. If None, returns all fields. + c : str, optional + Delimiter character. Default is '_'. 
+ + Returns + ------- + NDArray | list + Extracted annotations. + """ + m = [] + if x is not None: + for i in range(cn.size): + f = cn[i].split(c) + x = min(len(f) - 1, x) + m.append(f[x]) + return np.array(m) + else: + ms: list[list[str]] = [] + ls = [] + for i in range(cn.size): + f = cn[i].split(c) + m_inner = [] + for field_x in range(len(f)): + m_inner.append(f[field_x]) + ms.append(m_inner) + ls.append(len(m_inner)) + ml = max(ls) + for i in range(len(ms)): + ms[i].extend([""] * (ml - len(ms[i]))) + if ml - len(ms[i]) > 0: + ms[i] = list(np.concatenate(ms[i])) + ms_arr = np.vstack(ms) + MS = [] + for i in range(ms_arr.shape[1]): + MS.append(ms_arr[:, i]) + return MS diff --git a/src/samap/utils.py b/src/samap/utils.py index 13172fa..8365b9d 100644 --- a/src/samap/utils.py +++ b/src/samap/utils.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd +import scipy.sparse as sparse from samap._logging import logger @@ -14,7 +15,36 @@ import scipy.sparse as sp from numpy.typing import NDArray - from samalg import SAM + + from samap.sam import SAM + + +def q(x: Any) -> NDArray[Any]: + """Convert an iterable to a numpy array via list().""" + return np.array(list(x)) + + +def coo_to_csr_overwrite( + row: Any, col: Any, data: Any, shape: tuple[int, int] +) -> sparse.csr_matrix: + """Build a CSR matrix from COO triplets with last-write-wins semantics. + + Equivalent to ``M = lil_matrix(shape); M[row, col] = data; M.tocsr()`` but + without LIL (which cupy lacks). Standard COO construction sums duplicate + (row, col) entries; this keeps only the *last* value for each duplicate, + matching LIL fancy-index assignment. + """ + row = np.asarray(row, dtype=np.int64) + col = np.asarray(col, dtype=np.int64) + data = np.asarray(data) + if row.size == 0: + return sparse.csr_matrix(shape) + # Linearise indices to a single key; find last occurrence of each key by + # reversing, taking unique first-occurrence indices, mapping back. 
+ lin = row * shape[1] + col + _, idx = np.unique(lin[::-1], return_index=True) + keep = lin.size - 1 - idx + return sparse.coo_matrix((data[keep], (row[keep], col[keep])), shape=shape).tocsr() def save_samap(sm: Any, fn: str) -> None: @@ -184,9 +214,9 @@ def to_vo(op: NDArray[Any]) -> NDArray[Any]: ndarray Nx2 array of pairs. """ - import samalg.utilities as ut + from samap.sam.utils import extract_annotation - return np.vstack(ut.extract_annotation(op, None, ";")).T + return np.vstack(extract_annotation(op, None, ";")).T def substr( diff --git a/tests/integration/test_samap_integration.py b/tests/integration/test_samap_integration.py index c105b23..252c95c 100644 --- a/tests/integration/test_samap_integration.py +++ b/tests/integration/test_samap_integration.py @@ -3,10 +3,10 @@ from pathlib import Path import pytest -from samalg import SAM from samap import SAMAP from samap.analysis import GenePairFinder, get_mapping_scores, sankey_plot +from samap.sam import SAM # Path to example data relative to repo root EXAMPLE_DATA = Path(__file__).parent.parent.parent / "example_data" diff --git a/tests/regression/__init__.py b/tests/regression/__init__.py new file mode 100644 index 0000000..fb148ca --- /dev/null +++ b/tests/regression/__init__.py @@ -0,0 +1 @@ +"""Golden regression tests for SAMap.""" diff --git a/tests/regression/conftest.py b/tests/regression/conftest.py new file mode 100644 index 0000000..92b4b73 --- /dev/null +++ b/tests/regression/conftest.py @@ -0,0 +1,21 @@ +"""Pytest config for regression tests.""" + +from __future__ import annotations + +import pytest + + +def pytest_addoption(parser: pytest.Parser) -> None: + """Register the --regenerate-golden CLI flag.""" + parser.addoption( + "--regenerate-golden", + action="store_true", + default=False, + help="Regenerate golden fixture(s) instead of comparing against them.", + ) + + +@pytest.fixture +def regenerate_golden(request: pytest.FixtureRequest) -> bool: + """True when the user asked to regenerate 
golden fixtures.""" + return request.config.getoption("--regenerate-golden") diff --git a/tests/regression/fixtures/golden_3species.npz b/tests/regression/fixtures/golden_3species.npz new file mode 100644 index 0000000..82e15ab Binary files /dev/null and b/tests/regression/fixtures/golden_3species.npz differ diff --git a/tests/regression/test_golden_output.py b/tests/regression/test_golden_output.py new file mode 100644 index 0000000..e0b92b8 --- /dev/null +++ b/tests/regression/test_golden_output.py @@ -0,0 +1,326 @@ +"""Golden regression test: full 3-species SAMap pipeline. + +Captures the numeric outputs of the current SAMap implementation so that +subsequent optimization/refactoring can be verified to produce identical +results (to floating-point tolerance). + +Regenerate the fixture with: + pytest tests/regression/test_golden_output.py -m slow --regenerate-golden + +Run the comparison with: + pytest tests/regression/test_golden_output.py -m slow +""" + +from __future__ import annotations + +import os +import random +import types +from pathlib import Path +from typing import Any + +# Nudge thread counts toward determinism. These must be set before numpy/ +# numba import, but since conftest.py imports numpy first there's no hard +# guarantee. The numba-parallel functions in SAMap write to distinct array +# indices (no reductions), so thread count there is benign — these env vars +# are belt-and-suspenders. 
+os.environ.setdefault("OMP_NUM_THREADS", "1") +os.environ.setdefault("OPENBLAS_NUM_THREADS", "1") +os.environ.setdefault("MKL_NUM_THREADS", "1") +os.environ.setdefault("NUMBA_NUM_THREADS", "1") + +import numpy as np +import pytest +import scipy.sparse as sp + +# --------------------------------------------------------------------------- +# Paths & constants +# --------------------------------------------------------------------------- + +_HERE = Path(__file__).parent +_FIXTURES = _HERE / "fixtures" +_GOLDEN = _FIXTURES / "golden_3species.npz" +_EXAMPLE_DATA = _HERE.parent.parent / "example_data" + +_SEED = 42 +_RTOL = 1e-4 +_ATOL = 1e-6 + +_SPECIES = { + "pl": "planarian.h5ad", + "sc": "schistosome.h5ad", + "hy": "hydra.h5ad", +} + + +# --------------------------------------------------------------------------- +# Deterministic hnswlib wrapper +# --------------------------------------------------------------------------- + + +class _DeterministicHNSWIndex: + """Thin wrapper around hnswlib.Index that forces deterministic behaviour. + + hnswlib's multi-threaded ``add_items`` produces a non-deterministic index + (insertion order races). We force single-threaded insertion and query, and + pin the construction seed. The wrapper delegates everything else. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + import hnswlib as _real_hnswlib + + self._index = _real_hnswlib.Index(*args, **kwargs) + + def init_index(self, *args: Any, **kwargs: Any) -> Any: + # hnswlib defaults random_seed to 100 already; make it explicit. 
+ kwargs.setdefault("random_seed", _SEED) + return self._index.init_index(*args, **kwargs) + + def add_items(self, *args: Any, **kwargs: Any) -> Any: + kwargs["num_threads"] = 1 + return self._index.add_items(*args, **kwargs) + + def knn_query(self, *args: Any, **kwargs: Any) -> Any: + kwargs["num_threads"] = 1 + return self._index.knn_query(*args, **kwargs) + + def __getattr__(self, name: str) -> Any: + return getattr(self._index, name) + + +def _patch_hnswlib(monkeypatch: pytest.MonkeyPatch) -> None: + """Replace the hnswlib module seen by samap with a deterministic shim. + + ``samap.core.knn`` does a top-level ``import hnswlib`` — that's the + only call site reached during the golden pipeline (samap.sam's hnswlib + usage is gated behind ``SAM.run()``/``calculate_nnm``, which the pipeline + does not call — input SAMs are pre-computed and loaded from h5ad). + """ + import samap.core.knn as knn_mod + + fake = types.SimpleNamespace(Index=_DeterministicHNSWIndex) + monkeypatch.setattr(knn_mod, "hnswlib", fake) + + +# --------------------------------------------------------------------------- +# Pipeline runner +# --------------------------------------------------------------------------- + + +def _fix_seeds() -> None: + np.random.seed(_SEED) + random.seed(_SEED) + + +def _run_pipeline(monkeypatch: pytest.MonkeyPatch) -> Any: + """Run the full 3-species SAMap pipeline with determinism controls.""" + _patch_hnswlib(monkeypatch) + _fix_seeds() + + from samap import SAMAP + from samap.sam import SAM + + sams: dict[str, Any] = {} + for sid, fname in _SPECIES.items(): + sam = SAM() + sam.load_data(str(_EXAMPLE_DATA / fname)) + sams[sid] = sam + + sm = SAMAP(sams, f_maps=str(_EXAMPLE_DATA / "maps") + os.sep) + # umap=False: UMAP is stochastic and we don't pin its output anyway. 
+ sm.run(n_iterations=3, umap=False) + return sm + + +def _pack_sparse(prefix: str, mat: sp.spmatrix, out: dict[str, np.ndarray]) -> None: + csr = sp.csr_matrix(mat) + csr.sort_indices() + # Drop explicit zeros so structural comparison is stable. + csr.eliminate_zeros() + out[f"{prefix}_data"] = np.ascontiguousarray(csr.data, dtype=np.float64) + out[f"{prefix}_indices"] = np.ascontiguousarray(csr.indices, dtype=np.int64) + out[f"{prefix}_indptr"] = np.ascontiguousarray(csr.indptr, dtype=np.int64) + out[f"{prefix}_shape"] = np.asarray(csr.shape, dtype=np.int64) + + +def _extract_outputs(sm: Any) -> dict[str, np.ndarray]: + """Capture all numeric outputs we want to pin.""" + out: dict[str, np.ndarray] = {} + + # Stitched cross-species kNN graph (feeds UMAP / downstream analysis). + _pack_sparse("conn", sm.samap.adata.obsp["connectivities"], out) + + # Refined gene-homology graph (correlation-reweighted). + _pack_sparse("gnnm_refined", sm.gnnm_refined, out) + + # Original BLAST homology graph (reindexed post-run). Should be fully + # deterministic irrespective of hnswlib — good sanity anchor. + _pack_sparse("gnnm", sm.gnnm, out) + + # Per-species SAM gene weights. These come from the input h5ad (SAMap + # does not modify them) and serve as a load-integrity check. 
+    for sid in _SPECIES:
+        w = sm.sams[sid].adata.var["weights"].to_numpy()
+        out[f"weights_{sid}"] = np.ascontiguousarray(w, dtype=np.float64)
+
+    return out
+
+
+# ---------------------------------------------------------------------------
+# Comparison helpers
+# ---------------------------------------------------------------------------
+
+
+def _compare_sparse_strict(prefix: str, golden: Any, actual: dict[str, np.ndarray]) -> None:
+    """Assert sparse-matrix equality: exact structure + allclose data."""
+    np.testing.assert_array_equal(
+        golden[f"{prefix}_shape"], actual[f"{prefix}_shape"],
+        err_msg=f"{prefix}: shape mismatch",
+    )
+    np.testing.assert_array_equal(
+        golden[f"{prefix}_indptr"], actual[f"{prefix}_indptr"],
+        err_msg=f"{prefix}: indptr mismatch (different nnz pattern)",
+    )
+    np.testing.assert_array_equal(
+        golden[f"{prefix}_indices"], actual[f"{prefix}_indices"],
+        err_msg=f"{prefix}: indices mismatch (different sparsity pattern)",
+    )
+    np.testing.assert_allclose(
+        golden[f"{prefix}_data"],
+        actual[f"{prefix}_data"],
+        rtol=_RTOL,
+        atol=_ATOL,
+        err_msg=f"{prefix}: nonzero values diverge",
+    )
+
+
+def _compare_sparse_as_dense(prefix: str, golden: Any, actual: dict[str, np.ndarray]) -> None:
+    """Fallback: compare sparse matrices as dense, elementwise allclose.
+
+    Used if structural comparison fails due to tiny values crossing the
+    zero threshold. Reconstructs both CSR matrices and compares the
+    elementwise difference without materializing a huge dense array.
+ """ + g_shape = tuple(golden[f"{prefix}_shape"]) + a_shape = tuple(actual[f"{prefix}_shape"]) + assert g_shape == a_shape, f"{prefix}: shape mismatch {g_shape} vs {a_shape}" + + g = sp.csr_matrix( + (golden[f"{prefix}_data"], golden[f"{prefix}_indices"], golden[f"{prefix}_indptr"]), + shape=g_shape, + ) + a = sp.csr_matrix( + (actual[f"{prefix}_data"], actual[f"{prefix}_indices"], actual[f"{prefix}_indptr"]), + shape=a_shape, + ) + diff = (g - a).tocsr() + # Absolute-difference check is enough here — relative tolerance on a + # sparse graph with many tiny edges is noisy. + max_abs = np.abs(diff.data).max() if diff.nnz else 0.0 + assert max_abs <= max(_ATOL, _RTOL), ( + f"{prefix}: max abs elementwise difference {max_abs:.3e} exceeds " + f"tolerance (checked {diff.nnz} differing entries)" + ) + + +def _report_sparse_mismatch(prefix: str, golden: Any, actual: dict[str, np.ndarray]) -> str: + """Produce a diagnostic string when structural comparison fails.""" + g_nnz = len(golden[f"{prefix}_data"]) + a_nnz = len(actual[f"{prefix}_data"]) + g_sum = float(golden[f"{prefix}_data"].sum()) + a_sum = float(actual[f"{prefix}_data"].sum()) + return ( + f" {prefix}: nnz golden={g_nnz} actual={a_nnz} " + f"(Δ={a_nnz - g_nnz}), sum golden={g_sum:.6g} actual={a_sum:.6g}" + ) + + +# --------------------------------------------------------------------------- +# The test +# --------------------------------------------------------------------------- + + +@pytest.mark.slow +def test_golden_3species(regenerate_golden: bool, monkeypatch: pytest.MonkeyPatch) -> None: + """Pin full 3-species SAMap pipeline output against a golden fixture. + + Runs the pipeline end-to-end on the example hydra/planarian/schistosome + data and asserts that the stitched kNN graph, the refined gene-homology + graph, and the per-species gene weights match the stored golden within + rtol=1e-4. 
+
+    Determinism notes
+    -----------------
+    The only known source of run-to-run variation in the core algorithm is
+    hnswlib's multi-threaded ``add_items`` in ``_united_proj``. We patch the
+    ``hnswlib`` module reference inside ``samap.core.knn`` with a wrapper
+    that forces single-threaded index construction and a fixed seed.
+
+    If this test fails after a refactor with *structural* differences in the
+    kNN graph (different nnz pattern) but numerically similar overall weight
+    distribution, it likely means the refactor changed hnswlib invocation in
+    a way the shim no longer covers. Either extend the shim or loosen to a
+    top-k-overlap comparison.
+    """
+    if not _EXAMPLE_DATA.exists():
+        pytest.skip("Example data not available at example_data/")
+
+    sm = _run_pipeline(monkeypatch)
+    actual = _extract_outputs(sm)
+
+    if regenerate_golden:
+        _FIXTURES.mkdir(parents=True, exist_ok=True)
+        np.savez_compressed(_GOLDEN, **actual)
+        pytest.skip(f"Regenerated golden fixture → {_GOLDEN}")
+
+    if not _GOLDEN.exists():
+        pytest.fail(
+            f"Golden fixture missing: {_GOLDEN}\n"
+            "Generate it with: pytest tests/regression/test_golden_output.py "
+            "-m slow --regenerate-golden"
+        )
+
+    golden = np.load(_GOLDEN)
+
+    # --- Dense vectors: per-species gene weights ------------------------
+    for sid in _SPECIES:
+        key = f"weights_{sid}"
+        np.testing.assert_allclose(
+            golden[key],
+            actual[key],
+            rtol=_RTOL,
+            atol=_ATOL,
+            err_msg=f"Gene weights for species '{sid}' diverged from golden",
+        )
+
+    # --- Sparse graphs ---------------------------------------------------
+    # gnnm (BLAST homology, reindexed) should be bit-for-bit deterministic
+    # — it has no stochastic inputs. Hard structural check.
+    _compare_sparse_strict("gnnm", golden, actual)
+
+    # gnnm_refined and conn depend on hnswlib. Try strict first; if the
+    # sparsity pattern shifts (e.g. a few edges crossing the zero threshold)
+    # fall back to elementwise comparison.
+ failures: list[str] = [] + for prefix in ("gnnm_refined", "conn"): + try: + _compare_sparse_strict(prefix, golden, actual) + except AssertionError as e_strict: + try: + _compare_sparse_as_dense(prefix, golden, actual) + except AssertionError as e_dense: + failures.append( + f"{prefix} failed both strict and dense comparison.\n" + f" strict: {e_strict}\n" + f" dense: {e_dense}\n" + f"{_report_sparse_mismatch(prefix, golden, actual)}" + ) + + if failures: + pytest.fail( + "Golden regression mismatch:\n" + + "\n".join(failures) + + "\n\nIf this divergence is expected (e.g. after an intentional " + "algorithmic change), regenerate the fixture with " + "--regenerate-golden." + ) diff --git a/tests/unit/test_backend.py b/tests/unit/test_backend.py new file mode 100644 index 0000000..2e45ada --- /dev/null +++ b/tests/unit/test_backend.py @@ -0,0 +1,510 @@ +"""Unit tests for samap.core._backend — the CPU/GPU dispatch layer. + +CPU tests run unconditionally. GPU tests are skipped unless cupy is installed +*and* a CUDA device is visible (they provide coverage on CUDA CI only). 
+""" + +from __future__ import annotations + +import numpy as np +import pytest +import scipy.sparse as sp +from scipy.sparse.linalg import LinearOperator as ScipyLinearOperator + +from samap.core._backend import HAS_CUPY, Backend, COOBuilder + +if HAS_CUPY: + import cupy as cp + import cupyx.scipy.sparse as cpx_sparse + + _CUDA = cp.is_available() +else: + _CUDA = False + +gpu_only = pytest.mark.skipif(not _CUDA, reason="requires cupy + a CUDA device") + + +# --------------------------------------------------------------------------- +# Backend construction / device selection +# --------------------------------------------------------------------------- + + +class TestBackendInit: + def test_cpu_backend(self) -> None: + bk = Backend("cpu") + assert bk.device == "cpu" + assert bk.gpu is False + assert bk.xp is np + assert bk.sp is sp + # spla should expose svds/LinearOperator + assert hasattr(bk.spla, "svds") + assert hasattr(bk.spla, "LinearOperator") + + def test_auto_without_cuda(self) -> None: + """On a machine without cupy/CUDA, auto → cpu silently.""" + bk = Backend("auto") + # Either we have a GPU (CI) or we don't (dev laptop); either is valid. 
+ assert bk.device in ("cpu", "cuda") + assert bk.gpu == (bk.device == "cuda") + + @pytest.mark.skipif(HAS_CUPY, reason="tests the no-cupy error path") + def test_cuda_without_cupy_raises(self) -> None: + with pytest.raises(RuntimeError, match="cupy is not installed"): + Backend("cuda") + + def test_invalid_device_raises(self) -> None: + with pytest.raises(ValueError, match="must be 'cpu', 'cuda', or 'auto'"): + Backend("tpu") # type: ignore[arg-type] + + def test_repr(self) -> None: + bk = Backend("cpu") + assert "cpu" in repr(bk) + assert "gpu=False" in repr(bk) + + @gpu_only + def test_cuda_backend(self) -> None: + bk = Backend("cuda") + assert bk.device == "cuda" + assert bk.gpu is True + assert bk.xp is cp + assert bk.sp is cpx_sparse + + +# --------------------------------------------------------------------------- +# Compat shims — CPU +# --------------------------------------------------------------------------- + + +@pytest.fixture +def bk_cpu() -> Backend: + return Backend("cpu") + + +@pytest.fixture +def small_csr() -> sp.csr_matrix: + """A fixed 3x3 CSR with known structure.""" + data = np.array([1.0, 2.0, 3.0, 4.0]) + row = np.array([0, 0, 1, 2]) + col = np.array([0, 2, 1, 0]) + return sp.csr_matrix((data, (row, col)), shape=(3, 3)) + + +class TestNonzeroCPU: + def test_nonzero_on_csr(self, bk_cpu: Backend, small_csr: sp.csr_matrix) -> None: + rows, cols = bk_cpu.nonzero(small_csr) + # scipy returns sorted (row-major) for CSR + np.testing.assert_array_equal(rows, [0, 0, 1, 2]) + np.testing.assert_array_equal(cols, [0, 2, 1, 0]) + + def test_nonzero_on_dense(self, bk_cpu: Backend) -> None: + A = np.array([[0, 5, 0], [0, 0, 7]]) + rows, cols = bk_cpu.nonzero(A) + np.testing.assert_array_equal(rows, [0, 1]) + np.testing.assert_array_equal(cols, [1, 2]) + + +class TestSparseFromCoo: + def test_basic_csr(self, bk_cpu: Backend) -> None: + data = [1.0, 2.0, 3.0] + row = [0, 1, 2] + col = [2, 1, 0] + A = bk_cpu.sparse_from_coo(data, row, col, shape=(3, 3), 
fmt="csr") + assert A.format == "csr" + expected = np.array([[0, 0, 1], [0, 2, 0], [3, 0, 0]], dtype=float) + np.testing.assert_array_equal(A.toarray(), expected) + + def test_csc_format(self, bk_cpu: Backend) -> None: + A = bk_cpu.sparse_from_coo([5.0], [1], [1], shape=(2, 2), fmt="csc") + assert A.format == "csc" + assert A[1, 1] == 5.0 + + def test_duplicates_are_summed(self, bk_cpu: Backend) -> None: + # Two entries at (0, 0) → summed + A = bk_cpu.sparse_from_coo([1.0, 2.0, 10.0], [0, 0, 1], [0, 0, 1], shape=(2, 2)) + assert A[0, 0] == 3.0 + assert A[1, 1] == 10.0 + + +class TestSetdiag: + def test_setdiag_zero_eliminates(self, bk_cpu: Backend) -> None: + A = sp.csr_matrix(np.array([[1, 2, 0], [0, 3, 4], [5, 0, 6]], dtype=float)) + nnz_before = A.nnz + bk_cpu.setdiag(A, 0) + np.testing.assert_array_equal(A.diagonal(), [0, 0, 0]) + # Diagonal entries removed from structure + assert A.nnz == nnz_before - 3 + + def test_setdiag_scalar_nonzero(self, bk_cpu: Backend) -> None: + A = sp.csr_matrix(np.eye(3)) + bk_cpu.setdiag(A, 7.0) + np.testing.assert_array_equal(A.diagonal(), [7, 7, 7]) + # Off-diagonal untouched + assert A[0, 1] == 0.0 + + def test_setdiag_array(self, bk_cpu: Backend) -> None: + A = sp.csr_matrix(np.zeros((3, 3))) + bk_cpu.setdiag(A, np.array([10.0, 20.0, 30.0])) + np.testing.assert_array_equal(A.diagonal(), [10, 20, 30]) + + def test_setdiag_converts_to_csr(self, bk_cpu: Backend) -> None: + A = sp.csc_matrix(np.eye(3)) + out = bk_cpu.setdiag(A, 0) + assert out.format == "csr" + + def test_setdiag_no_warning(self, bk_cpu: Backend) -> None: + """The shim suppresses SparseEfficiencyWarning on structural change.""" + A = sp.csr_matrix((3, 3)) # all-zero, so setdiag changes structure + with warnings_error(sp.SparseEfficiencyWarning): + bk_cpu.setdiag(A, 1.0) + + +def warnings_error(*categories): + """Context manager: turn given warning categories into errors.""" + import warnings + + class _Ctx: + def __enter__(self): + self._mgr = 
warnings.catch_warnings() + self._mgr.__enter__() + for cat in categories: + warnings.simplefilter("error", cat) + + def __exit__(self, *exc): + return self._mgr.__exit__(*exc) + + return _Ctx() + + +class TestSvds: + def test_cpu_passes_solver(self, bk_cpu: Backend) -> None: + """On CPU, solver kwarg passes through to scipy.""" + # Build a rank-2 5x4 matrix + rng = np.random.default_rng(42) + A = sp.random(5, 4, density=0.6, random_state=rng) + u, s, vt = bk_cpu.svds(A, k=2, solver="arpack") + assert s.shape == (2,) + assert u.shape == (5, 2) + assert vt.shape == (2, 4) + # Singular values are non-negative + assert (s >= 0).all() + + def test_cpu_svds_numerics(self, bk_cpu: Backend) -> None: + """Sanity check: recovered singular values match numpy.linalg.svd.""" + rng = np.random.default_rng(0) + M = rng.standard_normal((6, 4)) + A = sp.csr_matrix(M) + _, s, _ = bk_cpu.svds(A, k=3) + s_true = np.linalg.svd(M, compute_uv=False) + # svds returns ascending; np.linalg.svd descending + np.testing.assert_allclose(np.sort(s)[::-1], s_true[:3], rtol=1e-6) + + +class TestLinearOperator: + def test_dispatch_to_scipy(self, bk_cpu: Backend) -> None: + # 3x3 identity via matvec + lo = bk_cpu.LinearOperator( + shape=(3, 3), matvec=lambda x: x, rmatvec=lambda x: x, dtype=np.float64 + ) + assert isinstance(lo, ScipyLinearOperator) + x = np.array([1.0, 2.0, 3.0]) + np.testing.assert_array_equal(lo @ x, x) + + def test_in_svds(self, bk_cpu: Backend) -> None: + """LinearOperator can feed into svds (implicit-matrix SVD).""" + M = np.diag([5.0, 3.0, 1.0, 0.1]) + lo = bk_cpu.LinearOperator( + shape=(4, 4), + matvec=lambda x: M @ x, + rmatvec=lambda x: M.T @ x, + dtype=np.float64, + ) + _, s, _ = bk_cpu.svds(lo, k=2) + np.testing.assert_allclose(sorted(s), [3.0, 5.0], rtol=1e-6) + + +# --------------------------------------------------------------------------- +# Data movement — CPU backend (identity ops) +# --------------------------------------------------------------------------- + + 
+class TestDataMovementCPU: + def test_to_device_noop_dense(self, bk_cpu: Backend) -> None: + x = np.arange(5.0) + assert bk_cpu.to_device(x) is x + + def test_to_device_noop_sparse(self, bk_cpu: Backend, small_csr) -> None: + assert bk_cpu.to_device(small_csr) is small_csr + + def test_to_host_noop(self, bk_cpu: Backend) -> None: + x = np.arange(5.0) + assert bk_cpu.to_host(x) is x + + def test_asfortran_noop_cpu(self, bk_cpu: Backend) -> None: + x = np.zeros((3, 4), order="C") + out = bk_cpu.asfortran_if_gpu(x) + assert out is x + assert out.flags.c_contiguous # unchanged + + def test_free_pool_noop(self, bk_cpu: Backend) -> None: + # Should not raise + bk_cpu.free_pool() + + +# --------------------------------------------------------------------------- +# COOBuilder +# --------------------------------------------------------------------------- + + +class TestCOOBuilder: + def test_single_adds(self, bk_cpu: Backend) -> None: + b = COOBuilder(bk_cpu, shape=(3, 3)) + b.add(0, 1, 5.0) + b.add(1, 2, 3.0) + b.add(2, 0, 7.0) + A = b.finalize(fmt="csr") + expected = np.array([[0, 5, 0], [0, 0, 3], [7, 0, 0]], dtype=float) + np.testing.assert_array_equal(A.toarray(), expected) + assert A.format == "csr" + + def test_batch_adds(self, bk_cpu: Backend) -> None: + b = COOBuilder(bk_cpu, shape=(2, 2)) + b.add_batch([0, 1], [0, 1], [1.0, 2.0]) + A = b.finalize() + np.testing.assert_array_equal(A.toarray(), [[1, 0], [0, 2]]) + + def test_mixed_adds(self, bk_cpu: Backend) -> None: + b = COOBuilder(bk_cpu, shape=(3, 3)) + b.add(0, 0, 1.0) + b.add_batch([1, 2], [1, 2], [2.0, 3.0]) + b.add(0, 2, 4.0) + A = b.finalize() + expected = np.array([[1, 0, 4], [0, 2, 0], [0, 0, 3]], dtype=float) + np.testing.assert_array_equal(A.toarray(), expected) + + def test_empty_builder(self, bk_cpu: Backend) -> None: + b = COOBuilder(bk_cpu, shape=(4, 4)) + A = b.finalize() + assert A.shape == (4, 4) + assert A.nnz == 0 + + def test_custom_dtype(self, bk_cpu: Backend) -> None: + b = 
COOBuilder(bk_cpu, shape=(2, 2), dtype=np.float32) + b.add(0, 0, 1.0) + A = b.finalize() + assert A.dtype == np.float32 + + def test_duplicates_summed(self, bk_cpu: Backend) -> None: + b = COOBuilder(bk_cpu, shape=(2, 2)) + b.add(0, 0, 1.0) + b.add(0, 0, 2.0) + b.add_batch([0], [0], [3.0]) + A = b.finalize() + assert A[0, 0] == 6.0 + + def test_csc_finalize(self, bk_cpu: Backend) -> None: + b = COOBuilder(bk_cpu, shape=(3, 3)) + b.add(1, 1, 9.0) + A = b.finalize(fmt="csc") + assert A.format == "csc" + assert A[1, 1] == 9.0 + + def test_matches_lil_fancy_assign(self, bk_cpu: Backend) -> None: + """COOBuilder should produce the same matrix as the lil_matrix + fancy-index pattern it replaces.""" + rows = np.array([0, 2, 1, 3, 2]) + cols = np.array([1, 0, 3, 2, 2]) + vals = np.array([1.5, 2.5, 3.5, 4.5, 5.5]) + + # Reference: lil fancy assignment (note: lil OVERWRITES on duplicate + # keys, not sums — so we use unique keys here to keep semantics + # aligned with COO summing). + lil = sp.lil_matrix((4, 4)) + lil[rows, cols] = vals + ref = lil.tocsr() + + # COOBuilder path + b = COOBuilder(bk_cpu, shape=(4, 4)) + b.add_batch(rows, cols, vals) + ours = b.finalize(fmt="csr") + + np.testing.assert_array_equal(ours.toarray(), ref.toarray()) + + def test_end_to_end_dot(self, bk_cpu: Backend) -> None: + """Build two matrices with COOBuilder and multiply them.""" + b1 = COOBuilder(bk_cpu, shape=(2, 3)) + b1.add_batch([0, 0, 1], [0, 2, 1], [1.0, 2.0, 3.0]) + A = b1.finalize() + + b2 = COOBuilder(bk_cpu, shape=(3, 2)) + b2.add_batch([0, 1, 2], [0, 1, 0], [4.0, 5.0, 6.0]) + B = b2.finalize() + + C = A.dot(B) + # Manually: A = [[1,0,2],[0,3,0]], B = [[4,0],[0,5],[6,0]] + # C = [[1*4+2*6, 0], [0, 3*5]] = [[16, 0], [0, 15]] + np.testing.assert_array_equal(C.toarray(), [[16, 0], [0, 15]]) + + +# --------------------------------------------------------------------------- +# GPU tests — skipped on machines without CUDA +# 
--------------------------------------------------------------------------- + + +@gpu_only +class TestBackendGPU: + @pytest.fixture + def bk_gpu(self) -> Backend: + return Backend("cuda") + + def test_nonzero_via_coo(self, bk_gpu: Backend) -> None: + # Build a small cupy CSR directly + data = cp.array([1.0, 2.0, 3.0]) + row = cp.array([0, 1, 2]) + col = cp.array([2, 1, 0]) + A = cpx_sparse.csr_matrix((data, (row, col)), shape=(3, 3)) + r, c = bk_gpu.nonzero(A) + # Row indices may come back in sorted order depending on cupy's COO + # internals — check as a set of pairs. + pairs = set(zip(cp.asnumpy(r).tolist(), cp.asnumpy(c).tolist())) + assert pairs == {(0, 2), (1, 1), (2, 0)} + + def test_sparse_from_coo_on_gpu(self, bk_gpu: Backend) -> None: + A = bk_gpu.sparse_from_coo([1.0, 2.0], [0, 1], [1, 0], shape=(2, 2)) + assert cpx_sparse.issparse(A) + dense = cp.asnumpy(A.toarray()) + np.testing.assert_array_equal(dense, [[0, 1], [2, 0]]) + + def test_svds_strips_solver(self, bk_gpu: Backend) -> None: + """On GPU, solver= should be silently dropped, not passed to cupy.""" + M = cp.asarray(np.diag([5.0, 3.0, 1.0]).astype(np.float64)) + A = cpx_sparse.csr_matrix(M) + # If solver/v0 were NOT stripped, cupy would raise TypeError. 
+ _, s, _ = bk_gpu.svds(A, k=2, solver="arpack", v0=np.ones(3)) + s_host = cp.asnumpy(s) + np.testing.assert_allclose(sorted(s_host), [3.0, 5.0], rtol=1e-5) + + def test_linear_operator_dispatch(self, bk_gpu: Backend) -> None: + lo = bk_gpu.LinearOperator( + shape=(3, 3), + matvec=lambda x: 2 * x, + rmatvec=lambda x: 2 * x, + dtype=np.float64, + ) + # Should be a cupy LinearOperator + from cupyx.scipy.sparse.linalg import LinearOperator as CupyLO + + assert isinstance(lo, CupyLO) + + def test_free_pool(self, bk_gpu: Backend) -> None: + # Allocate something, free, check no error + _ = cp.zeros(1000) + bk_gpu.free_pool() + + +@gpu_only +class TestDataMovementGPU: + @pytest.fixture + def bk_gpu(self) -> Backend: + return Backend("cuda") + + def test_dense_roundtrip(self, bk_gpu: Backend) -> None: + x_host = np.arange(6.0).reshape(2, 3) + x_dev = bk_gpu.to_device(x_host) + assert isinstance(x_dev, cp.ndarray) + x_back = bk_gpu.to_host(x_dev) + assert isinstance(x_back, np.ndarray) + np.testing.assert_array_equal(x_back, x_host) + + def test_sparse_csr_roundtrip(self, bk_gpu: Backend) -> None: + A_host = sp.random(5, 5, density=0.3, format="csr", random_state=0) + A_dev = bk_gpu.to_device(A_host) + assert cpx_sparse.issparse(A_dev) + assert A_dev.format == "csr" + A_back = bk_gpu.to_host(A_dev) + assert sp.issparse(A_back) + np.testing.assert_allclose(A_back.toarray(), A_host.toarray()) + + def test_sparse_csc_roundtrip(self, bk_gpu: Backend) -> None: + A_host = sp.random(4, 4, density=0.4, format="csc", random_state=1) + A_dev = bk_gpu.to_device(A_host) + assert A_dev.format == "csc" + A_back = bk_gpu.to_host(A_dev) + np.testing.assert_allclose(A_back.toarray(), A_host.toarray()) + + def test_sparse_coo_roundtrip(self, bk_gpu: Backend) -> None: + A_host = sp.random(4, 4, density=0.4, format="coo", random_state=2) + A_dev = bk_gpu.to_device(A_host) + assert A_dev.format == "coo" + A_back = bk_gpu.to_host(A_dev) + np.testing.assert_allclose(A_back.toarray(), 
A_host.toarray()) + + def test_lil_converted_to_csr(self, bk_gpu: Backend) -> None: + """LIL (no GPU equivalent) should be silently routed through CSR.""" + A_host = sp.lil_matrix((3, 3)) + A_host[0, 1] = 5.0 + A_host[2, 2] = 7.0 + A_dev = bk_gpu.to_device(A_host) + assert A_dev.format == "csr" + np.testing.assert_allclose(cp.asnumpy(A_dev.toarray()), A_host.toarray()) + + def test_to_device_idempotent(self, bk_gpu: Backend) -> None: + """Calling to_device on already-device data should be a no-op.""" + x_dev = cp.arange(5.0) + assert bk_gpu.to_device(x_dev) is x_dev + + def test_to_host_on_host_is_noop(self, bk_gpu: Backend) -> None: + x = np.arange(5.0) + assert bk_gpu.to_host(x) is x + + def test_asfortran_converts(self, bk_gpu: Backend) -> None: + x = cp.zeros((3, 4), order="C") + assert x.flags.c_contiguous + out = bk_gpu.asfortran_if_gpu(x) + assert out.flags.f_contiguous + + def test_asfortran_noop_if_already_f(self, bk_gpu: Backend) -> None: + x = cp.zeros((3, 4), order="F") + assert bk_gpu.asfortran_if_gpu(x) is x + + +@gpu_only +class TestCOOBuilderGPU: + @pytest.fixture + def bk_gpu(self) -> Backend: + return Backend("cuda") + + def test_result_on_device(self, bk_gpu: Backend) -> None: + b = COOBuilder(bk_gpu, shape=(3, 3)) + b.add(0, 1, 5.0) + b.add_batch([1, 2], [2, 0], [3.0, 7.0]) + A = b.finalize() + assert cpx_sparse.issparse(A) + dense = cp.asnumpy(A.toarray()) + np.testing.assert_array_equal(dense, [[0, 5, 0], [0, 0, 3], [7, 0, 0]]) + + def test_cpu_gpu_builders_agree(self, bk_gpu: Backend) -> None: + """Same triplets → same dense matrix regardless of backend.""" + bk_cpu = Backend("cpu") + rng = np.random.default_rng(42) + n = 10 + rows = rng.integers(0, n, size=30) + cols = rng.integers(0, n, size=30) + vals = rng.standard_normal(30) + + b_cpu = COOBuilder(bk_cpu, shape=(n, n)) + b_cpu.add_batch(rows, cols, vals) + A_cpu = b_cpu.finalize().toarray() + + b_gpu = COOBuilder(bk_gpu, shape=(n, n)) + b_gpu.add_batch(rows, cols, vals) + A_gpu = 
cp.asnumpy(b_gpu.finalize().toarray()) + + np.testing.assert_allclose(A_gpu, A_cpu, rtol=1e-12) + + def test_accepts_device_arrays_in_batch(self, bk_gpu: Backend) -> None: + """add_batch should accept cupy arrays (brought to host internally).""" + b = COOBuilder(bk_gpu, shape=(3, 3)) + b.add_batch(cp.array([0, 1]), cp.array([1, 2]), cp.array([4.0, 6.0])) + A = b.finalize() + dense = cp.asnumpy(A.toarray()) + np.testing.assert_array_equal(dense, [[0, 4, 0], [0, 0, 6], [0, 0, 0]]) diff --git a/tests/unit/test_coarsening.py b/tests/unit/test_coarsening.py new file mode 100644 index 0000000..038c2e5 --- /dev/null +++ b/tests/unit/test_coarsening.py @@ -0,0 +1,560 @@ +"""Equivalence tests for streaming mutual-NN construction in coarsening. + +The streaming implementation in ``_compute_mutual_graph`` replaces the +original monolithic ``D = B @ nnm_internal.T`` with per-species-pair block +computation. These tests verify the output is numerically identical. + +The reference implementation here (``_reference_mutual_graph``) reproduces the +*original* code path verbatim: build full block matrices, materialise D, +mutualise, scale, top-k. This is the spec we test against. 
+""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest +import scipy.sparse as spp + +from samap.core.coarsening import _compute_mutual_graph +from samap.core.correlation import _replace +from samap.core.homology import _tanh_scale +from samap.utils import sparse_knn + +# --------------------------------------------------------------------------- +# Reference implementation (reproduces the original _mapper body) +# --------------------------------------------------------------------------- + + +def _reference_mutual_graph( + nnms_in: dict[str, Any], + neigh_from_keys: dict[str, bool], + B: spp.csr_matrix, + offsets: dict[str, int], + n_cells: dict[str, int], + sids: list[str], + k1: int, + N: int, + *, + pairwise: bool, + threshold: float, + scale_edges_by_corr: bool, + wPCA: Any, +) -> spp.csr_matrix: + """Original monolithic D = B @ nnm_internal.T path, for comparison.""" + any_nfk = any(neigh_from_keys[sid] for sid in sids) + + # Build block-diag nnm_internal. For nfk species, the effective block + # is M @ M.T (co-clustering), otherwise the expanded kNN directly. 
+ eff_blocks: list[Any] = [] + for sid in sids: + blk = nnms_in[sid] + if neigh_from_keys[sid]: + eff_blocks.append(blk.dot(blk.T)) + else: + eff_blocks.append(blk) + nnm_internal = spp.block_diag(eff_blocks).tocsr() + + D = B.dot(nnm_internal.T).tocsr() + if not any_nfk and threshold > 0: + D.data[D.data < threshold] = 0 + D.eliminate_zeros() + + D = D.multiply(D.T).tocsr() + D.data[:] = D.data**0.5 + + if scale_edges_by_corr: + x, y = D.nonzero() + vals = _replace(wPCA, x, y) + vals[vals < 1e-3] = 1e-3 + F = D.copy() + F.data[:] = vals + ma = np.asarray(F.max(1).todense()) + ma[ma == 0] = 1 + F = F.multiply(1 / ma).tocsr() + F.data[:] = _tanh_scale(F.data, center=0.7, scale=10) + ma = np.asarray(D.max(1).todense()) + ma[ma == 0] = 1 + D = F.multiply(D).tocsr() + D.data[:] = np.sqrt(D.data) + ma2 = np.asarray(D.max(1).todense()) + ma2[ma2 == 0] = 1 + D = D.multiply(ma / ma2).tocsr() + + if not pairwise or len(sids) == 2: + return sparse_knn(D, k1).tocsr() + + # pairwise top-k per species pair + row = np.array([], dtype="int64") + col = np.array([], dtype="int64") + data = np.array([], dtype="float64") + for a in sids: + ra = np.arange(offsets[a], offsets[a] + n_cells[a]) + for b in sids: + if a == b: + continue + rb = np.arange(offsets[b], offsets[b] + n_cells[b]) + Dsub = sparse_knn(D[ra][:, rb], k1).tocoo() + row = np.append(row, ra[Dsub.row]) + col = np.append(col, rb[Dsub.col]) + data = np.append(data, Dsub.data) + return spp.coo_matrix((data, (row, col)), shape=(N, N)).tocsr() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_sym_knn(rng: np.random.Generator, n: int, k: int, values: bool = False) -> spp.csr_matrix: + """Build a symmetric kNN-ish sparse matrix with positive entries. + + If ``values`` is False the matrix is binary {0,1}; if True, nonzeros are + drawn uniform in (0.5, 1.5) to exercise the threshold floor. 
+ """ + rows: list[int] = [] + cols: list[int] = [] + vals: list[float] = [] + for i in range(n): + nbrs = rng.choice(n, size=min(k, n), replace=False) + nbrs = nbrs[nbrs != i][: max(k - 1, 1)] + for j in nbrs: + v = float(rng.uniform(0.5, 1.5)) if values else 1.0 + rows.extend([i, j]) + cols.extend([j, i]) + vals.extend([v, v]) + M = spp.csr_matrix((vals, (rows, cols)), shape=(n, n)) + M.sum_duplicates() + if not values: + M.data[:] = 1.0 + M.setdiag(0) + M.eliminate_zeros() + return M + + +def _make_cross_B( + rng: np.random.Generator, + sids: list[str], + n_cells: dict[str, int], + offsets: dict[str, int], + N: int, + k: int, +) -> spp.csr_matrix: + """Build a block-off-diagonal cross-species kNN (like mdata['knn']).""" + rows: list[int] = [] + cols: list[int] = [] + vals: list[float] = [] + for a in sids: + na = n_cells[a] + off_a = offsets[a] + for b in sids: + if a == b: + continue + nb = n_cells[b] + off_b = offsets[b] + for i in range(na): + nbrs = rng.choice(nb, size=min(k, nb), replace=False) + for j in nbrs: + rows.append(off_a + i) + cols.append(off_b + int(j)) + vals.append(float(rng.uniform(0.2, 1.0))) + return spp.csr_matrix((vals, (rows, cols)), shape=(N, N)) + + +@pytest.fixture +def two_species_inputs( + rng: np.random.Generator, +) -> dict[str, Any]: + """Synthetic 2-species input: ~80 and ~120 cells, random kNN structure.""" + sids = ["spA", "spB"] + n_cells = {"spA": 80, "spB": 120} + offsets = {"spA": 0, "spB": 80} + N = 200 + + nnms_in = { + "spA": _make_sym_knn(rng, 80, k=8, values=False), + "spB": _make_sym_knn(rng, 120, k=8, values=False), + } + neigh_from_keys = {"spA": False, "spB": False} + B = _make_cross_B(rng, sids, n_cells, offsets, N, k=10) + + return { + "sids": sids, + "n_cells": n_cells, + "offsets": offsets, + "N": N, + "nnms_in": nnms_in, + "neigh_from_keys": neigh_from_keys, + "B": B, + } + + +@pytest.fixture +def three_species_inputs( + rng: np.random.Generator, +) -> dict[str, Any]: + """Synthetic 3-species input for 
pairwise top-k testing.""" + sids = ["x", "y", "z"] + n_cells = {"x": 60, "y": 70, "z": 50} + offsets = {"x": 0, "y": 60, "z": 130} + N = 180 + + nnms_in = {sid: _make_sym_knn(rng, n_cells[sid], k=6, values=False) for sid in sids} + neigh_from_keys = dict.fromkeys(sids, False) + B = _make_cross_B(rng, sids, n_cells, offsets, N, k=8) + + return { + "sids": sids, + "n_cells": n_cells, + "offsets": offsets, + "N": N, + "nnms_in": nnms_in, + "neigh_from_keys": neigh_from_keys, + "B": B, + } + + +@pytest.fixture +def nfk_inputs(rng: np.random.Generator) -> dict[str, Any]: + """2-species input where one species uses the coclustering path.""" + sids = ["a", "b"] + n_cells = {"a": 90, "b": 70} + offsets = {"a": 0, "b": 90} + N = 160 + + # species a: 4 clusters, one-hot membership + cl_a = rng.integers(0, 4, size=90) + M_a = np.zeros((90, 4)) + M_a[np.arange(90), cl_a] = 1 + M_a = spp.csr_matrix(M_a) + + nnms_in = { + "a": M_a, + "b": _make_sym_knn(rng, 70, k=6, values=False), + } + neigh_from_keys = {"a": True, "b": False} + B = _make_cross_B(rng, sids, n_cells, offsets, N, k=8) + + return { + "sids": sids, + "n_cells": n_cells, + "offsets": offsets, + "N": N, + "nnms_in": nnms_in, + "neigh_from_keys": neigh_from_keys, + "B": B, + } + + +# --------------------------------------------------------------------------- +# Equivalence tests +# --------------------------------------------------------------------------- + + +def _assert_sparse_equal(A: spp.spmatrix, B: spp.spmatrix, atol: float = 1e-12) -> None: + """Assert two sparse matrices are numerically identical.""" + assert A.shape == B.shape, f"shape mismatch: {A.shape} vs {B.shape}" + diff = (A - B).tocoo() + if diff.nnz: + max_abs = float(np.abs(diff.data).max()) + assert max_abs <= atol, ( + f"max abs diff {max_abs} > atol {atol}; " + f"{diff.nnz} entries differ; " + f"nnz(A)={A.nnz}, nnz(B)={B.nnz}" + ) + + +class TestTwoSpecies: + """Equivalence tests for the common 2-species case.""" + + 
@pytest.mark.parametrize("chunksize", [10_000, 30, 7]) + def test_basic(self, two_species_inputs: dict[str, Any], chunksize: int) -> None: + """Streaming == reference, 2 species, no scaling, various chunk sizes. + + chunksize=7 forces many small chunks within each species to exercise + the chunk-boundary index bookkeeping. + """ + inp = two_species_inputs + k1 = 15 + + ref = _reference_mutual_graph( + inp["nnms_in"], + inp["neigh_from_keys"], + inp["B"], + inp["offsets"], + inp["n_cells"], + inp["sids"], + k1, + inp["N"], + pairwise=True, + threshold=0.1, + scale_edges_by_corr=False, + wPCA=None, + ) + got = _compute_mutual_graph( + inp["nnms_in"], + inp["neigh_from_keys"], + inp["B"], + inp["offsets"], + inp["n_cells"], + inp["sids"], + k1, + inp["N"], + pairwise=True, + chunksize=chunksize, + threshold=0.1, + scale_edges_by_corr=False, + wPCA=None, + ) + _assert_sparse_equal(got, ref) + + def test_with_scale_edges_by_corr( + self, two_species_inputs: dict[str, Any], rng: np.random.Generator + ) -> None: + """Streaming == reference with correlation-based edge rescaling. + + wPCA rows must correlate with the kNN structure for nonzero effect; + a pure random wPCA exercises the path regardless. 
+ """ + inp = two_species_inputs + k1 = 12 + wPCA = rng.standard_normal((inp["N"], 50)) + + ref = _reference_mutual_graph( + inp["nnms_in"], + inp["neigh_from_keys"], + inp["B"], + inp["offsets"], + inp["n_cells"], + inp["sids"], + k1, + inp["N"], + pairwise=True, + threshold=0.1, + scale_edges_by_corr=True, + wPCA=wPCA, + ) + got = _compute_mutual_graph( + inp["nnms_in"], + inp["neigh_from_keys"], + inp["B"], + inp["offsets"], + inp["n_cells"], + inp["sids"], + k1, + inp["N"], + pairwise=True, + chunksize=25, + threshold=0.1, + scale_edges_by_corr=True, + wPCA=wPCA, + ) + _assert_sparse_equal(got, ref) + + def test_non_pairwise(self, two_species_inputs: dict[str, Any]) -> None: + """pairwise=False (global per-row top-k) matches reference.""" + inp = two_species_inputs + k1 = 10 + + ref = _reference_mutual_graph( + inp["nnms_in"], + inp["neigh_from_keys"], + inp["B"], + inp["offsets"], + inp["n_cells"], + inp["sids"], + k1, + inp["N"], + pairwise=False, + threshold=0.1, + scale_edges_by_corr=False, + wPCA=None, + ) + got = _compute_mutual_graph( + inp["nnms_in"], + inp["neigh_from_keys"], + inp["B"], + inp["offsets"], + inp["n_cells"], + inp["sids"], + k1, + inp["N"], + pairwise=False, + chunksize=1000, + threshold=0.1, + scale_edges_by_corr=False, + wPCA=None, + ) + _assert_sparse_equal(got, ref) + + +class TestThreeSpecies: + """Multi-species with pairwise per-block top-k.""" + + def test_pairwise_topk(self, three_species_inputs: dict[str, Any]) -> None: + inp = three_species_inputs + k1 = 10 + + ref = _reference_mutual_graph( + inp["nnms_in"], + inp["neigh_from_keys"], + inp["B"], + inp["offsets"], + inp["n_cells"], + inp["sids"], + k1, + inp["N"], + pairwise=True, + threshold=0.1, + scale_edges_by_corr=False, + wPCA=None, + ) + got = _compute_mutual_graph( + inp["nnms_in"], + inp["neigh_from_keys"], + inp["B"], + inp["offsets"], + inp["n_cells"], + inp["sids"], + k1, + inp["N"], + pairwise=True, + chunksize=20, + threshold=0.1, + scale_edges_by_corr=False, + 
wPCA=None, + ) + _assert_sparse_equal(got, ref) + + def test_with_scale_edges_by_corr( + self, three_species_inputs: dict[str, Any], rng: np.random.Generator + ) -> None: + inp = three_species_inputs + k1 = 8 + wPCA = rng.standard_normal((inp["N"], 40)) + + ref = _reference_mutual_graph( + inp["nnms_in"], + inp["neigh_from_keys"], + inp["B"], + inp["offsets"], + inp["n_cells"], + inp["sids"], + k1, + inp["N"], + pairwise=True, + threshold=0.1, + scale_edges_by_corr=True, + wPCA=wPCA, + ) + got = _compute_mutual_graph( + inp["nnms_in"], + inp["neigh_from_keys"], + inp["B"], + inp["offsets"], + inp["n_cells"], + inp["sids"], + k1, + inp["N"], + pairwise=True, + chunksize=15, + threshold=0.1, + scale_edges_by_corr=True, + wPCA=wPCA, + ) + _assert_sparse_equal(got, ref) + + +class TestCoclustering: + """The neigh_from_keys (nfk) coclustering path.""" + + def test_nfk_one_species(self, nfk_inputs: dict[str, Any]) -> None: + """One species uses coclustering; threshold is disabled (matches original).""" + inp = nfk_inputs + k1 = 12 + + ref = _reference_mutual_graph( + inp["nnms_in"], + inp["neigh_from_keys"], + inp["B"], + inp["offsets"], + inp["n_cells"], + inp["sids"], + k1, + inp["N"], + pairwise=True, + threshold=0.0, + scale_edges_by_corr=False, + wPCA=None, + ) + got = _compute_mutual_graph( + inp["nnms_in"], + inp["neigh_from_keys"], + inp["B"], + inp["offsets"], + inp["n_cells"], + inp["sids"], + k1, + inp["N"], + pairwise=True, + chunksize=30, + threshold=0.0, + scale_edges_by_corr=False, + wPCA=None, + ) + _assert_sparse_equal(got, ref) + + +class TestEdgeCases: + """Degenerate inputs.""" + + def test_single_species_no_cross(self, rng: np.random.Generator) -> None: + """One species → no cross-species pairs → empty Dk.""" + sids = ["only"] + n_cells = {"only": 50} + offsets = {"only": 0} + N = 50 + nnms_in = {"only": _make_sym_knn(rng, 50, k=5)} + neigh_from_keys = {"only": False} + B = spp.csr_matrix((N, N)) # empty + + got = _compute_mutual_graph( + nnms_in, + 
neigh_from_keys, + B, + offsets, + n_cells, + sids, + 10, + N, + pairwise=True, + chunksize=1000, + threshold=0.1, + scale_edges_by_corr=False, + wPCA=None, + ) + assert got.shape == (N, N) + assert got.nnz == 0 + + def test_empty_cross_species(self, two_species_inputs: dict[str, Any]) -> None: + """B has no entries → Dk should be empty.""" + inp = two_species_inputs + B_empty = spp.csr_matrix((inp["N"], inp["N"])) + + got = _compute_mutual_graph( + inp["nnms_in"], + inp["neigh_from_keys"], + B_empty, + inp["offsets"], + inp["n_cells"], + inp["sids"], + 10, + inp["N"], + pairwise=True, + chunksize=1000, + threshold=0.1, + scale_edges_by_corr=False, + wPCA=None, + ) + assert got.nnz == 0 diff --git a/tests/unit/test_correlation.py b/tests/unit/test_correlation.py new file mode 100644 index 0000000..8158853 --- /dev/null +++ b/tests/unit/test_correlation.py @@ -0,0 +1,581 @@ +"""Equivalence tests for batched correlation computation in correlation.py. + +The streaming path in ``_compute_pair_corrs`` computes ``Xavg`` in +per-pair-batch tiles instead of materialising the full N × G matrix. These +tests verify the output matches the materialised path bit-identically (or +to machine precision — the arithmetic is identical, order is the same). + +They also verify the dict-free kernel against a reference Pearson/Xi +implementation built directly from NumPy. 
+""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest +import scipy.sparse as spp + +from samap.core._backend import Backend +from samap.core.correlation import ( + _compute_pair_corrs, + _corr_kernel, + _replace, + _replace_vectorized, + _resolve_batch_size, + _xicorr, + replace_corr, +) + +# --------------------------------------------------------------------------- +# Reference (pure-NumPy) correlation for a single pair +# --------------------------------------------------------------------------- + + +def _pearson_np(x: np.ndarray, y: np.ndarray) -> float: + """Textbook Pearson — matches the kernel's formula exactly.""" + return float(((x - x.mean()) * (y - y.mean()) / x.std() / y.std()).sum() / x.size) + + +def _ref_corr( + nnms: spp.csr_matrix, + Xs: spp.csc_matrix, + p: np.ndarray, + ps_int: np.ndarray, + sp_starts: np.ndarray, + sp_lens: np.ndarray, + mode: str, +) -> np.ndarray: + """Reference: materialise Xavg, loop in pure Python, correlate.""" + Xavg = np.asarray(nnms.dot(Xs).todense()) + n_pairs = p.shape[0] + res = np.zeros(n_pairs) + for j in range(n_pairs): + g1, g2 = p[j] + s1, s2 = ps_int[j] + st1, ln1 = sp_starts[s1], sp_lens[s1] + st2, ln2 = sp_starts[s2], sp_lens[s2] + + xcol = Xavg[:, g1] + ycol = Xavg[:, g2] + xx = np.concatenate((xcol[st1 : st1 + ln1], xcol[st2 : st2 + ln2])) + yy = np.concatenate((ycol[st1 : st1 + ln1], ycol[st2 : st2 + ln2])) + + if mode == "pearson": + res[j] = _pearson_np(xx, yy) + else: + res[j] = _xicorr(xx, yy) + return res + + +# --------------------------------------------------------------------------- +# Fixtures: synthetic 2-species inputs +# --------------------------------------------------------------------------- + + +def _make_knn(rng: np.random.Generator, n: int, k: int) -> spp.csr_matrix: + """Sparse symmetric kNN with self-loops (averaging operator).""" + rows, cols = [], [] + for i in range(n): + nbrs = rng.choice(n, size=min(k, n), replace=False) + if 
i not in nbrs: + nbrs[0] = i + for j in nbrs: + rows.append(i) + cols.append(j) + M = spp.csr_matrix((np.ones(len(rows)), (rows, cols)), shape=(n, n)) + M.sum_duplicates() + M.data[:] = 1.0 + return M + + +@pytest.fixture +def corr_inputs(rng: np.random.Generator) -> dict[str, Any]: + """Synthetic 2-species input for correlation testing. + + ~500 cells split 300/200, 200 genes split 120/80, ~500 gene pairs. + Expression is sparse random (20% density) → smoothed Xavg is moderately + dense (realistic). + """ + n_a, n_b = 300, 200 + g_a, g_b = 120, 80 + N = n_a + n_b + G = g_a + g_b + + # Row-normalised averaging operator over full manifold + knn = _make_knn(rng, N, k=15) + rs = np.asarray(knn.sum(1)).flatten() + rs[rs == 0] = 1 + nnms = knn.multiply(1.0 / rs[:, None]).tocsr() + + # Block-diagonal expression: species A uses genes [0, g_a), B uses [g_a, G) + Xa = spp.random(n_a, g_a, density=0.2, format="csr", random_state=1, dtype=np.float32) + Xb = spp.random(n_b, g_b, density=0.2, format="csr", random_state=2, dtype=np.float32) + Xs = spp.block_diag([Xa, Xb]).tocsc() + + # Species layout + sp_starts = np.array([0, n_a], dtype=np.int64) + sp_lens = np.array([n_a, n_b], dtype=np.int64) + + # Gene pairs: each is cross-species (gene from [0,g_a) × gene from [g_a,G)) + n_pairs = 500 + p1 = rng.integers(0, g_a, size=n_pairs) + p2 = rng.integers(g_a, G, size=n_pairs) + p = np.column_stack((p1, p2)).astype(np.int64) + # species IDs: gene < g_a → species 0, else species 1 + ps_int = np.column_stack((np.zeros(n_pairs, dtype=np.int64), np.ones(n_pairs, dtype=np.int64))) + + return { + "nnms": nnms, + "Xs": Xs, + "p": p, + "ps_int": ps_int, + "sp_starts": sp_starts, + "sp_lens": sp_lens, + "N": N, + } + + +@pytest.fixture +def corr_inputs_3sp(rng: np.random.Generator) -> dict[str, Any]: + """3-species variant: pairs span all three species combinations.""" + n = [150, 100, 120] + g = [60, 50, 40] + N = sum(n) + + knn = _make_knn(rng, N, k=12) + rs = 
np.asarray(knn.sum(1)).flatten() + rs[rs == 0] = 1 + nnms = knn.multiply(1.0 / rs[:, None]).tocsr() + + X_blocks = [ + spp.random(n[i], g[i], density=0.2, format="csr", random_state=i + 10, dtype=np.float32) + for i in range(3) + ] + Xs = spp.block_diag(X_blocks).tocsc() + + n_off = np.cumsum([0, *n]) + g_off = np.cumsum([0, *g]) + sp_starts = n_off[:-1].astype(np.int64) + sp_lens = np.array(n, dtype=np.int64) + + # Generate pairs across all three species combinations + n_pairs_per = 120 + p_list, ps_list = [], [] + combos = [(0, 1), (0, 2), (1, 2)] + for s1, s2 in combos: + p1 = rng.integers(g_off[s1], g_off[s1 + 1], size=n_pairs_per) + p2 = rng.integers(g_off[s2], g_off[s2 + 1], size=n_pairs_per) + p_list.append(np.column_stack((p1, p2))) + ps_list.append(np.column_stack((np.full(n_pairs_per, s1), np.full(n_pairs_per, s2)))) + p = np.vstack(p_list).astype(np.int64) + ps_int = np.vstack(ps_list).astype(np.int64) + # shuffle so batches don't align with species combos + perm = rng.permutation(p.shape[0]) + p, ps_int = p[perm], ps_int[perm] + + return { + "nnms": nnms, + "Xs": Xs, + "p": p, + "ps_int": ps_int, + "sp_starts": sp_starts, + "sp_lens": sp_lens, + "N": N, + } + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestKernelAgainstReference: + """Dict-free kernel matches a pure-NumPy reference.""" + + def test_pearson_vs_numpy(self, corr_inputs: dict[str, Any]) -> None: + inp = corr_inputs + ref = _ref_corr( + inp["nnms"], + inp["Xs"], + inp["p"], + inp["ps_int"], + inp["sp_starts"], + inp["sp_lens"], + "pearson", + ) + got = _compute_pair_corrs( + inp["nnms"], + inp["Xs"], + inp["p"], + inp["ps_int"], + inp["sp_starts"], + inp["sp_lens"], + inp["N"], + "pearson", + None, + ) + np.testing.assert_allclose(got, ref, rtol=1e-12, atol=1e-14) + + def test_xi_vs_numpy(self, corr_inputs: dict[str, Any]) -> None: + inp = corr_inputs + ref 
= _ref_corr( + inp["nnms"], + inp["Xs"], + inp["p"], + inp["ps_int"], + inp["sp_starts"], + inp["sp_lens"], + "xi", + ) + got = _compute_pair_corrs( + inp["nnms"], + inp["Xs"], + inp["p"], + inp["ps_int"], + inp["sp_starts"], + inp["sp_lens"], + inp["N"], + "xi", + None, + ) + np.testing.assert_allclose(got, ref, rtol=1e-12, atol=1e-14) + + +class TestStreamingEquivalence: + """Streaming path (batch_size=int) matches materialised (batch_size=None).""" + + @pytest.mark.parametrize("batch_size", [1, 7, 64, 256, 10_000]) + def test_pearson_batched(self, corr_inputs: dict[str, Any], batch_size: int) -> None: + """Streaming Pearson == materialised, across a range of batch sizes. + + batch_size=1 is the strictest correctness check (every pair isolated); + batch_size=10_000 exercises the single-batch fallthrough. + """ + inp = corr_inputs + ref = _compute_pair_corrs( + inp["nnms"], + inp["Xs"], + inp["p"], + inp["ps_int"], + inp["sp_starts"], + inp["sp_lens"], + inp["N"], + "pearson", + None, + ) + got = _compute_pair_corrs( + inp["nnms"], + inp["Xs"], + inp["p"], + inp["ps_int"], + inp["sp_starts"], + inp["sp_lens"], + inp["N"], + "pearson", + batch_size, + ) + np.testing.assert_allclose(got, ref, rtol=1e-10, atol=1e-12) + + @pytest.mark.parametrize("batch_size", [1, 32, 500]) + def test_xi_batched(self, corr_inputs: dict[str, Any], batch_size: int) -> None: + inp = corr_inputs + ref = _compute_pair_corrs( + inp["nnms"], + inp["Xs"], + inp["p"], + inp["ps_int"], + inp["sp_starts"], + inp["sp_lens"], + inp["N"], + "xi", + None, + ) + got = _compute_pair_corrs( + inp["nnms"], + inp["Xs"], + inp["p"], + inp["ps_int"], + inp["sp_starts"], + inp["sp_lens"], + inp["N"], + "xi", + batch_size, + ) + np.testing.assert_allclose(got, ref, rtol=1e-10, atol=1e-12) + + @pytest.mark.parametrize("batch_size", [1, 50, 200]) + def test_three_species_pearson(self, corr_inputs_3sp: dict[str, Any], batch_size: int) -> None: + """3-species, shuffled pairs across all combos — exercises 
mixed-batch + species indexing and gene-overlap between batches.""" + inp = corr_inputs_3sp + ref = _compute_pair_corrs( + inp["nnms"], + inp["Xs"], + inp["p"], + inp["ps_int"], + inp["sp_starts"], + inp["sp_lens"], + inp["N"], + "pearson", + None, + ) + got = _compute_pair_corrs( + inp["nnms"], + inp["Xs"], + inp["p"], + inp["ps_int"], + inp["sp_starts"], + inp["sp_lens"], + inp["N"], + "pearson", + batch_size, + ) + np.testing.assert_allclose(got, ref, rtol=1e-10, atol=1e-12) + + +class TestKernelDirect: + """Low-level kernel sanity checks.""" + + def test_kernel_empty_pairs(self) -> None: + """Zero pairs → zero-length result.""" + sp_starts = np.array([0, 10], dtype=np.int64) + sp_lens = np.array([10, 10], dtype=np.int64) + # dummy CSC + M = spp.csc_matrix((20, 5)) + res = _corr_kernel( + np.empty(0, dtype=np.int64), + np.empty(0, dtype=np.int64), + np.empty(0, dtype=np.int64), + np.empty(0, dtype=np.int64), + sp_starts, + sp_lens, + M.indptr, + M.indices, + M.data, + 20, + True, + ) + assert res.size == 0 + + +# --------------------------------------------------------------------------- +# _replace (per-pair Pearson over dense wPCA rows) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def replace_inputs(rng: np.random.Generator) -> dict[str, Any]: + """Dense embedding + random index pairs for _replace tests.""" + n, d = 800, 50 + X = rng.standard_normal((n, d)).astype(np.float64) + n_pairs = 1000 + xi = rng.integers(0, n, size=n_pairs).astype(np.int64) + yi = rng.integers(0, n, size=n_pairs).astype(np.int64) + return {"X": X, "xi": xi, "yi": yi, "n": n, "d": d, "n_pairs": n_pairs} + + +class TestReplaceVectorized: + """_replace_vectorized matches numba _replace and pure-numpy reference.""" + + def test_against_numpy_corrcoef(self, replace_inputs: dict[str, Any]) -> None: + """Vectorised form matches np.corrcoef pairwise (rtol=1e-12).""" + inp = replace_inputs + bk = Backend("cpu") + + got = 
_replace_vectorized(inp["X"], inp["xi"], inp["yi"], bk) + + # Reference via np.corrcoef — O(n_pairs) loop, but authoritative + ref = np.array( + [np.corrcoef(inp["X"][i], inp["X"][j])[0, 1] for i, j in zip(inp["xi"], inp["yi"])] + ) + np.testing.assert_allclose(got, ref, rtol=1e-12, atol=1e-14) + + def test_against_numba(self, replace_inputs: dict[str, Any]) -> None: + """Vectorised form matches numba _replace (the CPU fast path).""" + inp = replace_inputs + bk = Backend("cpu") + + numba_res = _replace(inp["X"], inp["xi"], inp["yi"]) + vec_res = _replace_vectorized(inp["X"], inp["xi"], inp["yi"], bk) + + np.testing.assert_allclose(vec_res, numba_res, rtol=1e-12, atol=1e-14) + + @pytest.mark.parametrize("batch_size", [1, 7, 100, 500, 2000]) + def test_batched_matches_full(self, replace_inputs: dict[str, Any], batch_size: int) -> None: + """Chunked vectorised == single-shot vectorised (all batch sizes). + + batch_size=1 is the tightest correctness probe; 2000 > n_pairs + exercises the fallthrough. + """ + inp = replace_inputs + bk = Backend("cpu") + + full = _replace_vectorized(inp["X"], inp["xi"], inp["yi"], bk, batch_size=None) + chunked = _replace_vectorized(inp["X"], inp["xi"], inp["yi"], bk, batch_size=batch_size) + np.testing.assert_allclose(chunked, full, rtol=0, atol=0) + + def test_float32_input(self, replace_inputs: dict[str, Any]) -> None: + """float32 input → float64 output, matches float64 input path. + + wPCA is often stored float32 for memory; the vectorised form + upcasts internally to match _replace's float64 arithmetic. 
+ """ + inp = replace_inputs + bk = Backend("cpu") + X32 = inp["X"].astype(np.float32) + + res32 = _replace_vectorized(X32, inp["xi"], inp["yi"], bk) + res64 = _replace_vectorized(inp["X"], inp["xi"], inp["yi"], bk) + + assert res32.dtype == np.float64 + # float32 input has less precision → looser tolerance + np.testing.assert_allclose(res32, res64, rtol=1e-5, atol=1e-7) + + def test_zero_variance_row(self, rng: np.random.Generator) -> None: + """Constant row → std=0 → nan (matches _replace behaviour).""" + bk = Backend("cpu") + X = rng.standard_normal((10, 20)) + X[3, :] = 5.0 # constant → zero variance + + xi = np.array([3, 0], dtype=np.int64) + yi = np.array([1, 2], dtype=np.int64) + + vec = _replace_vectorized(X, xi, yi, bk) + numba = _replace(X, xi, yi) + + assert np.isnan(vec[0]) + assert np.isnan(numba[0]) + np.testing.assert_allclose(vec[1], numba[1], rtol=1e-12) + + +class TestReplaceCorrDispatcher: + """replace_corr routes to numba on CPU, vectorised on GPU.""" + + def test_cpu_backend_uses_numba(self, replace_inputs: dict[str, Any]) -> None: + """CPU backend → numba path; result matches _replace directly.""" + inp = replace_inputs + bk = Backend("cpu") + + disp = replace_corr(inp["X"], inp["xi"], inp["yi"], bk) + numba = _replace(inp["X"], inp["xi"], inp["yi"]) + # CPU dispatch IS the numba path → bit-identical + np.testing.assert_array_equal(disp, numba) + + def test_bk_none_defaults_to_numba(self, replace_inputs: dict[str, Any]) -> None: + """bk=None (backward-compat) → numba path.""" + inp = replace_inputs + disp = replace_corr(inp["X"], inp["xi"], inp["yi"], bk=None) + numba = _replace(inp["X"], inp["xi"], inp["yi"]) + np.testing.assert_array_equal(disp, numba) + + def test_mock_gpu_uses_vectorized(self, replace_inputs: dict[str, Any]) -> None: + """Mock Backend with gpu=True → vectorised path. + + We can't test a real GPU path on CI; this verifies the dispatch + logic by constructing a duck-typed backend with gpu=True + xp=numpy. 
+ """ + + class _MockGPUBackend: + gpu = True + xp = np # numpy stands in for cupy here + + inp = replace_inputs + bk = _MockGPUBackend() + + disp = replace_corr(inp["X"], inp["xi"], inp["yi"], bk, batch_size=100) + ref = _replace_vectorized(inp["X"], inp["xi"], inp["yi"], bk, batch_size=100) + np.testing.assert_array_equal(disp, ref) + # and that it matches numba to fp tolerance + numba = _replace(inp["X"], inp["xi"], inp["yi"]) + np.testing.assert_allclose(disp, numba, rtol=1e-12, atol=1e-14) + + +# --------------------------------------------------------------------------- +# _resolve_batch_size auto-selection heuristic +# --------------------------------------------------------------------------- + + +class TestResolveBatchSize: + """Auto-selection of materialised vs streaming based on estimated memory.""" + + def test_explicit_passthrough(self, rng: np.random.Generator) -> None: + """Explicit batch_size (non-'auto') is returned unchanged.""" + nnms = spp.eye(100, format="csr") + Xs = spp.random(100, 50, density=0.1, format="csc") + + # All explicit values pass through untouched — including None. + assert _resolve_batch_size(None, nnms, Xs) is None + assert _resolve_batch_size(32, nnms, Xs) == 32 + assert _resolve_batch_size(9999, nnms, Xs) == 9999 + + def test_tiny_data_materialises(self, rng: np.random.Generator) -> None: + """Toy-scale data (hundreds of cells) → auto picks materialised. + + At 500 cells × 200 genes, even 100% density is 800 KB — far under + the 2 GB default threshold. + """ + # Realistic toy: 500 cells, avg 15 neighbours, 20% expression density + nnms = spp.random(500, 500, density=15 / 500, format="csr") + Xs = spp.random(500, 200, density=0.2, format="csc") + + got = _resolve_batch_size("auto", nnms, Xs, mem_threshold_gb=2.0) + assert got is None + + def test_million_cell_streams(self) -> None: + """Million-cell scale → auto picks streaming. 
+ + 1M cells × 10k genes × ~50% density (after kNN fill-in from 5% + input density with k~20) ≈ 60 GB CSC. Well over any threshold. + We mock shapes/nnz rather than allocating a real million-entry + matrix — _resolve_batch_size only reads .shape and .nnz. + """ + + class _MockSparse: + def __init__(self, shape: tuple[int, int], nnz: int) -> None: + self.shape = shape + self.nnz = nnz + + n_cells, n_genes = 1_000_000, 10_000 + k = 20 + expr_nnz = int(n_cells * n_genes * 0.05) # 5% expression density + + nnms = _MockSparse(shape=(n_cells, n_cells), nnz=n_cells * k) + Xs = _MockSparse(shape=(n_cells, n_genes), nnz=expr_nnz) + + got = _resolve_batch_size("auto", nnms, Xs, mem_threshold_gb=2.0) + assert got == 512 + + def test_threshold_boundary(self) -> None: + """Crossing the threshold flips the decision. + + Fixed shapes → fixed estimate. Vary mem_threshold_gb above and + below the estimate to verify the boundary logic. + """ + + class _MockSparse: + def __init__(self, shape: tuple[int, int], nnz: int) -> None: + self.shape = shape + self.nnz = nnz + + # 100k cells × 5k genes, k=15, density 10% → output density ~80% + # → est = 100k * 5k * 0.8 * 12 bytes ≈ 4.8 GB + n_cells, n_genes = 100_000, 5_000 + nnms = _MockSparse(shape=(n_cells, n_cells), nnz=n_cells * 15) + Xs = _MockSparse(shape=(n_cells, n_genes), nnz=int(n_cells * n_genes * 0.10)) + + # threshold above estimate → materialise + assert _resolve_batch_size("auto", nnms, Xs, mem_threshold_gb=10.0) is None + # threshold below estimate → stream + assert _resolve_batch_size("auto", nnms, Xs, mem_threshold_gb=1.0) == 512 + + def test_zero_sized_inputs(self) -> None: + """Degenerate 0-cell / 0-gene inputs → materialise (trivial).""" + nnms = spp.csr_matrix((0, 0)) + Xs = spp.csc_matrix((0, 0)) + assert _resolve_batch_size("auto", nnms, Xs) is None + + nnms = spp.eye(50, format="csr") + Xs = spp.csc_matrix((50, 0)) + assert _resolve_batch_size("auto", nnms, Xs) is None diff --git a/tests/unit/test_expand.py 
b/tests/unit/test_expand.py
new file mode 100644
index 0000000..185e6e0
--- /dev/null
+++ b/tests/unit/test_expand.py
+"""Unit tests for :mod:`samap.core.expand`.
+
+Compares the BFS neighbourhood expansion against the original matrix-power
+implementation. The two are exactly equivalent when every cell's budget is
+large enough to absorb its full reachable-within-NH-hops set. When budgets
+truncate, they may pick different marginal neighbours (see module docstring).
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+import scipy.sparse as sp
+from sklearn.neighbors import kneighbors_graph
+
+from samap.core.expand import (
+    _smart_expand,
+    _smart_expand_bfs,
+    _smart_expand_matpow,
+)
+
+
+@pytest.fixture
+def rng() -> np.random.Generator:
+    return np.random.default_rng(42)
+
+
+def _make_knn_graph(
+    n_cells: int, k: int, n_clusters: int, rng: np.random.Generator
+) -> tuple[sp.csr_matrix, np.ndarray]:
+    """Build a weighted kNN graph over synthetic blob data.
+
+    Returns the (symmetrised, row-normalised) connectivity matrix and an
+    integer cluster-label array. Row normalisation mimics what scanpy's
+    connectivities look like (weights in (0, 1], diagonal absent).
+    """
+    # Place cluster centres on a circle so they're well-separated.
+    centres = (
+        np.stack(
+            [
+                np.cos(2 * np.pi * np.arange(n_clusters) / n_clusters),
+                np.sin(2 * np.pi * np.arange(n_clusters) / n_clusters),
+            ],
+            axis=1,
+        )
+        * 10.0
+    )
+    labels = rng.integers(0, n_clusters, size=n_cells)
+    pts = centres[labels] + rng.normal(scale=1.0, size=(n_cells, 2))
+
+    # Distance-weighted kNN, exclude self.
+    A = kneighbors_graph(pts, n_neighbors=k, mode="distance", include_self=False)
+    # Convert distances → similarities (Gaussian-ish), symmetrise, drop diag.
+    A.data = np.exp(-A.data / A.data.mean())
+    A = A.maximum(A.T).tocsr()
+    A.setdiag(0.0)
+    A.eliminate_zeros()
+    return A, labels
+
+
+def _edge_set(A: sp.spmatrix) -> set[tuple[int, int]]:
+    A = A.tocoo()
+    return set(zip(A.row.tolist(), A.col.tolist(), strict=True))
+
+
+# ---------------------------------------------------------------------------
+# Exact-equivalence regime: budget ≥ reachable set.
+# ---------------------------------------------------------------------------
+
+
+def test_bfs_matches_matpow_when_budget_covers_reachable(rng: np.random.Generator) -> None:
+    """With budget ≥ reachable-within-NH-hops, BFS and matpow must agree exactly.
+
+    Here NH=1 (two hops total) on a small sparse graph; the reachable set per
+    cell is well under the budget, so both algorithms collect the full set
+    and the choice of in-ring ranking is irrelevant.
+    """
+    n, k = 120, 4
+    nnm, _ = _make_knn_graph(n, k, n_clusters=6, rng=rng)
+
+    # Generous budget — well above the 2-hop reachable count for k=4.
+    K = np.full(n, 60, dtype=np.int64)
+
+    out_old = _smart_expand_matpow(nnm, K.copy(), NH=1)
+    out_new = _smart_expand_bfs(nnm, K.copy(), NH=1)
+
+    old_edges = _edge_set(out_old)
+    new_edges = _edge_set(out_new)
+
+    # matpow can include self-loops at even hops (an nnm^2 diagonal entry
+    # survives the ring subtraction if nnm itself has no diagonal). BFS
+    # never collects self. Strip self-loops from the matpow output before
+    # comparing.
+    old_edges = {(r, c) for (r, c) in old_edges if r != c}
+
+    assert old_edges == new_edges, (
+        f"edge-set mismatch: "
+        f"{len(old_edges - new_edges)} matpow-only, "
+        f"{len(new_edges - old_edges)} bfs-only"
+    )
+
+
+def test_bfs_matches_matpow_single_hop(rng: np.random.Generator) -> None:
+    """With NH=0 (direct neighbours only), both algorithms are pure top-k.
+
+    This is the trivial case — no multi-hop expansion, so the in-ring
+    ranking is identical (both use the edge weights directly).
+    """
+    n, k = 200, 10
+    nnm, labels = _make_knn_graph(n, k, n_clusters=5, rng=rng)
+
+    # Per-cell budget = cluster size (typical usage).
+    _, ix, counts = np.unique(labels, return_inverse=True, return_counts=True)
+    K = counts[ix].astype(np.int64)
+
+    out_old = _smart_expand_matpow(nnm, K.copy(), NH=0)
+    out_new = _smart_expand_bfs(nnm, K.copy(), NH=0)
+
+    assert _edge_set(out_old) == _edge_set(out_new)
+
+
+# ---------------------------------------------------------------------------
+# Truncation regime: budget < reachable set. Characterise divergence.
+# ---------------------------------------------------------------------------
+
+
+def test_bfs_near_matpow_when_budget_truncates(rng: np.random.Generator) -> None:
+    """With tight budgets, BFS and matpow may differ at the margin.
+
+    Both prioritise by hop distance, so divergence is confined to the *last*
+    ring a cell draws from and only when that ring overflows the remaining
+    budget. We assert high Jaccard similarity and that per-cell output
+    sizes match exactly (both fill the same budget).
+    """
+    n, k = 500, 20
+    nnm, labels = _make_knn_graph(n, k, n_clusters=8, rng=rng)
+
+    _, ix, counts = np.unique(labels, return_inverse=True, return_counts=True)
+    K = counts[ix].astype(np.int64)
+
+    out_old = _smart_expand_matpow(nnm, K.copy(), NH=3)
+    out_new = _smart_expand_bfs(nnm, K.copy(), NH=3)
+
+    # Strip self-loops from matpow (see exact-equivalence test).
+    old_edges = {(r, c) for (r, c) in _edge_set(out_old) if r != c}
+    new_edges = _edge_set(out_new)
+
+    inter = len(old_edges & new_edges)
+    union = len(old_edges | new_edges)
+    jaccard = inter / union if union else 1.0
+
+    # Per-cell output cardinality should be very close (both fill budget or
+    # exhaust reachable set; matpow may have +1 from a self-loop).
+    old_nnz = np.asarray((out_old != 0).sum(axis=1)).ravel()
+    new_nnz = np.asarray((out_new != 0).sum(axis=1)).ravel()
+    # Allow matpow up to +1 per cell (the self-loop).
+    assert np.all(new_nnz <= old_nnz)
+    assert np.all(old_nnz - new_nnz <= 1)
+
+    # In practice Jaccard is ~0.95+ here. 0.9 is a conservative floor —
+    # if this fails the algorithms have diverged meaningfully and should be
+    # investigated, not just have the threshold lowered.
+    assert jaccard > 0.9, f"Jaccard={jaccard:.3f} — BFS diverged from matpow"
+
+
+# ---------------------------------------------------------------------------
+# Structural invariants
+# ---------------------------------------------------------------------------
+
+
+def test_bfs_output_binarized(rng: np.random.Generator) -> None:
+    """BFS output data is all 1.0 — structure is the only signal."""
+    n, k = 100, 8
+    nnm, _ = _make_knn_graph(n, k, n_clusters=4, rng=rng)
+    K = np.full(n, 30, dtype=np.int64)
+    out = _smart_expand_bfs(nnm, K, NH=2)
+    assert out.nnz > 0
+    np.testing.assert_array_equal(out.data, np.ones(out.nnz))
+
+
+def test_bfs_no_self_loops(rng: np.random.Generator) -> None:
+    """BFS never collects a cell as its own neighbour."""
+    n, k = 100, 8
+    nnm, _ = _make_knn_graph(n, k, n_clusters=4, rng=rng)
+    K = np.full(n, 50, dtype=np.int64)
+    out = _smart_expand_bfs(nnm, K, NH=3)
+    assert out.diagonal().sum() == 0
+
+
+def test_bfs_respects_budget(rng: np.random.Generator) -> None:
+    """No cell collects more neighbours than its budget."""
+    n, k = 200, 12
+    nnm, labels = _make_knn_graph(n, k, n_clusters=5, rng=rng)
+    _, ix, counts = np.unique(labels, return_inverse=True, return_counts=True)
+    K = counts[ix].astype(np.int64)
+    out = _smart_expand_bfs(nnm, K, NH=3)
+    nnz_per_row = np.asarray((out != 0).sum(axis=1)).ravel()
+    assert np.all(nnz_per_row <= K)
+
+
+def test_bfs_zero_budget(rng: np.random.Generator) -> None:
+    """Cells with K=0 contribute nothing."""
+    n, k = 50, 5
+    nnm, _ = _make_knn_graph(n, k, n_clusters=3, rng=rng)
+    K = np.zeros(n, dtype=np.int64)
+    K[::3] = 10  # only every third cell gets a budget
+    out = _smart_expand_bfs(nnm, K, NH=2)
+    nnz_per_row = np.asarray((out != 0).sum(axis=1)).ravel()
+    assert np.all(nnz_per_row[K == 0] == 0)
+    assert np.any(nnz_per_row[K > 0] > 0)
+
+
+def test_bfs_disconnected_component(rng: np.random.Generator) -> None:
+    """BFS on a disconnected node collects nothing (gracefully)."""
+    n = 50
+    nnm, _ = _make_knn_graph(n, k=5, n_clusters=3, rng=rng)
+    # Isolate the last cell.
+    nnm = nnm.tolil()
+    nnm[n - 1, :] = 0
+    nnm[:, n - 1] = 0
+    nnm = nnm.tocsr()
+    nnm.eliminate_zeros()
+    K = np.full(n, 20, dtype=np.int64)
+    out = _smart_expand_bfs(nnm, K, NH=3)
+    # Isolated cell has no neighbours to collect.
+    assert out.getrow(n - 1).nnz == 0
+    # Other cells still work.
+    assert out.getrow(0).nnz > 0
+
+
+def test_bfs_empty_graph() -> None:
+    """Zero-cell input → zero-cell output."""
+    nnm = sp.csr_matrix((0, 0), dtype=np.float64)
+    K = np.array([], dtype=np.int64)
+    out = _smart_expand_bfs(nnm, K, NH=3)
+    assert out.shape == (0, 0)
+    assert out.nnz == 0
+
+
+# ---------------------------------------------------------------------------
+# Dispatch
+# ---------------------------------------------------------------------------
+
+
+def test_dispatch_legacy_true_calls_matpow(rng: np.random.Generator) -> None:
+    """_smart_expand(legacy=True) delegates to matpow."""
+    n, k = 80, 6
+    nnm, _ = _make_knn_graph(n, k, n_clusters=4, rng=rng)
+    K = np.full(n, 30, dtype=np.int64)
+    out_dispatch = _smart_expand(nnm, K.copy(), NH=2, legacy=True)
+    out_direct = _smart_expand_matpow(nnm, K.copy(), NH=2)
+    assert _edge_set(out_dispatch) == _edge_set(out_direct)
+
+
+def test_dispatch_legacy_false_calls_bfs(rng: np.random.Generator) -> None:
+    """_smart_expand(legacy=False) delegates to BFS."""
+    n, k = 80, 6
+    nnm, _ = _make_knn_graph(n, k, n_clusters=4, rng=rng)
+    K = np.full(n, 30, dtype=np.int64)
+    out_dispatch = _smart_expand(nnm, K.copy(), NH=2, legacy=False)
+    out_direct = _smart_expand_bfs(nnm, K.copy(), NH=2)
+    assert _edge_set(out_dispatch) == _edge_set(out_direct)
diff --git a/tests/unit/test_knn.py
b/tests/unit/test_knn.py
new file mode 100644
index 0000000..1ab27ed
--- /dev/null
+++ b/tests/unit/test_knn.py
+"""Unit tests for samap.core.knn — CPU/GPU kNN dispatch."""
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+
+from samap.core._backend import HAS_CUPY, Backend
+from samap.core.knn import HAS_FAISS, _hnswlib_knn, approximate_knn
+
+if HAS_CUPY:
+    import cupy as cp
+
+    _CUDA = cp.is_available()
+else:
+    _CUDA = False
+
+# The FAISS GPU path needs cupy (for device arrays), cuda, and a GPU-enabled
+# faiss build. On the macOS dev machine none of these hold.
+_FAISS_GPU_AVAILABLE = _CUDA and HAS_FAISS
+if _FAISS_GPU_AVAILABLE:
+    import faiss
+
+    _FAISS_GPU_AVAILABLE = hasattr(faiss, "StandardGpuResources")
+
+gpu_only = pytest.mark.skipif(not _FAISS_GPU_AVAILABLE, reason="requires cupy + CUDA + faiss-gpu")
+
+
+# ---------------------------------------------------------------------------
+# Reference exact cosine kNN (numpy brute-force)
+# ---------------------------------------------------------------------------
+
+
+def _brute_force_cosine_knn(
+    queries: np.ndarray, database: np.ndarray, k: int
+) -> tuple[np.ndarray, np.ndarray]:
+    """Exact cosine kNN by dense matmul — reference for recall / exactness."""
+    qn = queries / np.linalg.norm(queries, axis=1, keepdims=True)
+    dn = database / np.linalg.norm(database, axis=1, keepdims=True)
+    sims = qn @ dn.T  # (n_q, n_d) cosine similarities
+    # Top-k by similarity = smallest-k by distance
+    idx = np.argpartition(-sims, kth=k - 1, axis=1)[:, :k]
+    # Sort within each row so neighbours are ordered near → far
+    row_sims = np.take_along_axis(sims, idx, axis=1)
+    order = np.argsort(-row_sims, axis=1)
+    idx_sorted = np.take_along_axis(idx, order, axis=1)
+    dist_sorted = 1.0 - np.take_along_axis(row_sims, order, axis=1)
+    return idx_sorted, dist_sorted
+
+
+@pytest.fixture
+def bk_cpu() -> Backend:
+    return Backend("cpu")
+
+
+@pytest.fixture
+def small_data() -> tuple[np.ndarray, np.ndarray]:
+    """~500 database points, 50 queries, 16 dims — small enough for exact ref."""
+    rng = np.random.default_rng(42)
+    n_db, n_q, dim = 500, 50, 16
+    db = rng.standard_normal((n_db, dim)).astype(np.float32)
+    q = rng.standard_normal((n_q, dim)).astype(np.float32)
+    return q, db
+
+
+# ---------------------------------------------------------------------------
+# Output-format contracts
+# ---------------------------------------------------------------------------
+
+
+class TestOutputFormat:
+    def test_shapes(self, bk_cpu: Backend, small_data) -> None:
+        q, db = small_data
+        k = 5
+        idx, dist = approximate_knn(q, db, k=k, metric="cosine", bk=bk_cpu)
+        assert idx.shape == (q.shape[0], k)
+        assert dist.shape == (q.shape[0], k)
+
+    def test_indices_are_int(self, bk_cpu: Backend, small_data) -> None:
+        q, db = small_data
+        idx, _ = approximate_knn(q, db, k=5, metric="cosine", bk=bk_cpu)
+        assert np.issubdtype(idx.dtype, np.integer)
+        # All in bounds
+        assert idx.min() >= 0
+        assert idx.max() < db.shape[0]
+
+    def test_cosine_distances_in_range(self, bk_cpu: Backend, small_data) -> None:
+        q, db = small_data
+        _, dist = approximate_knn(q, db, k=5, metric="cosine", bk=bk_cpu)
+        # cosine distance = 1 - cos_sim, always in [0, 2]
+        assert dist.min() >= 0.0 - 1e-6
+        assert dist.max() <= 2.0 + 1e-6
+
+    def test_distances_sorted_ascending(self, bk_cpu: Backend, small_data) -> None:
+        q, db = small_data
+        _, dist = approximate_knn(q, db, k=10, metric="cosine", bk=bk_cpu)
+        # Each row should be non-decreasing (closest first)
+        assert (np.diff(dist, axis=1) >= -1e-6).all()
+
+    def test_default_bk_is_cpu(self, small_data) -> None:
+        """bk=None should create a CPU backend silently."""
+        q, db = small_data
+        idx, _dist = approximate_knn(q, db, k=3)
+        assert idx.shape == (q.shape[0], 3)
+
+
+# ---------------------------------------------------------------------------
+# CPU HNSW recall vs brute force
+# ---------------------------------------------------------------------------
+
+
+class TestHnswRecall:
+    def test_recall_above_95pct(self, small_data) -> None:
+        """With ef=200, M=48 and k=10 on 500 points, HNSW recall is near-perfect."""
+        q, db = small_data
+        k = 10
+        idx_hnsw, _ = _hnswlib_knn(q, db, k=k, metric="cosine")
+        idx_exact, _ = _brute_force_cosine_knn(q, db, k=k)
+
+        # Recall@k: for each query, fraction of HNSW neighbours that appear
+        # in the exact top-k.
+        hits = 0
+        for row_h, row_e in zip(idx_hnsw, idx_exact):
+            hits += len(set(row_h.tolist()) & set(row_e.tolist()))
+        recall = hits / (q.shape[0] * k)
+        assert recall > 0.95, f"HNSW recall too low: {recall:.3f}"
+
+    def test_distances_close_to_exact(self, small_data) -> None:
+        """HNSW distances should match brute-force to float32 precision when
+        the same neighbour is found."""
+        q, db = small_data
+        _idx_hnsw, dist_hnsw = _hnswlib_knn(q, db, k=1, metric="cosine")
+        _, dist_exact = _brute_force_cosine_knn(q, db, k=1)
+
+        # For queries where HNSW found the true nearest neighbour, the
+        # distance should match.
+        np.testing.assert_allclose(dist_hnsw, dist_exact, atol=1e-5)
+
+    def test_deterministic_single_thread(self, small_data) -> None:
+        """num_threads=1 + fixed seed gives reproducible index → reproducible
+        results. (Proxy for golden-test determinism.)"""
+        q, db = small_data
+        idx_a, dist_a = _hnswlib_knn(q, db, k=5, num_threads=1)
+        idx_b, dist_b = _hnswlib_knn(q, db, k=5, num_threads=1)
+        np.testing.assert_array_equal(idx_a, idx_b)
+        np.testing.assert_array_equal(dist_a, dist_b)
+
+
+# ---------------------------------------------------------------------------
+# GPU brute-force (FAISS) — exact
+# ---------------------------------------------------------------------------
+
+
+@gpu_only
+class TestFaissGPU:
+    @pytest.fixture
+    def bk_gpu(self) -> Backend:
+        return Backend("cuda")
+
+    def test_exact_vs_brute_force(self, bk_gpu: Backend, small_data) -> None:
+        """GpuIndexFlatIP is exact — neighbour sets must match brute-force."""
+        q, db = small_data
+        k = 10
+
+        idx_faiss, dist_faiss = approximate_knn(q, db, k=k, metric="cosine", bk=bk_gpu)
+        idx_exact, dist_exact = _brute_force_cosine_knn(q, db, k=k)
+
+        # Neighbour sets should be identical (exact search). Use sets per row
+        # to ignore tie-breaking order differences.
+        for i in range(q.shape[0]):
+            assert set(idx_faiss[i].tolist()) == set(idx_exact[i].tolist()), (
+                f"row {i}: FAISS {idx_faiss[i]} != exact {idx_exact[i]}"
+            )
+        # Distances match to float32 precision.
+        np.testing.assert_allclose(
+            np.sort(dist_faiss, axis=1),
+            np.sort(dist_exact, axis=1),
+            atol=1e-5,
+        )
+
+    def test_accepts_cupy_arrays(self, bk_gpu: Backend) -> None:
+        rng = np.random.default_rng(0)
+        q = cp.asarray(rng.standard_normal((20, 8)).astype(np.float32))
+        db = cp.asarray(rng.standard_normal((100, 8)).astype(np.float32))
+        idx, dist = approximate_knn(q, db, k=5, metric="cosine", bk=bk_gpu)
+        # Results are returned on host
+        assert isinstance(idx, np.ndarray)
+        assert isinstance(dist, np.ndarray)
+        assert idx.shape == (20, 5)
+
+    def test_resources_cached_on_backend(self, bk_gpu: Backend) -> None:
+        res1 = bk_gpu.faiss_gpu_resources()
+        res2 = bk_gpu.faiss_gpu_resources()
+        assert res1 is res2
+        assert res1 is not None
+
+    def test_non_cosine_metric_raises(self, bk_gpu: Backend, small_data) -> None:
+        """The FAISS-GPU path is cosine-only; other metrics must fail loud.
+
+        (Dispatching via approximate_knn would work because the non-cosine
+        branch never reaches _faiss_gpu_knn — this tests the internal directly.)
+        """
+        from samap.core.knn import _faiss_gpu_knn
+
+        q, db = small_data
+        with pytest.raises(ValueError, match="only supports metric='cosine'"):
+            _faiss_gpu_knn(q, db, k=5, metric="l2", bk=bk_gpu)
+
+    def test_faiss_matches_hnsw_distances(self, bk_gpu: Backend, small_data) -> None:
+        """Sanity cross-check: FAISS and HNSW agree on nearest-neighbour distance."""
+        q, db = small_data
+        _, dist_gpu = approximate_knn(q, db, k=1, metric="cosine", bk=bk_gpu)
+        _, dist_cpu = _hnswlib_knn(q, db, k=1, metric="cosine")
+        np.testing.assert_allclose(dist_gpu, dist_cpu, atol=1e-5)
+
+
+# ---------------------------------------------------------------------------
+# Graceful fallback — GPU backend without faiss-gpu
+# ---------------------------------------------------------------------------
+
+
+class TestFallback:
+    def test_cpu_backend_always_uses_hnswlib(self, bk_cpu: Backend, small_data, caplog) -> None:
+        """On CPU backend, no fallback warning — hnswlib is the direct path."""
+        q, db = small_data
+        with caplog.at_level("WARNING"):
+            approximate_knn(q, db, k=3, metric="cosine", bk=bk_cpu)
+        assert "faiss" not in caplog.text.lower()
+
+    @pytest.mark.skipif(
+        not _CUDA or _FAISS_GPU_AVAILABLE,
+        reason="needs CUDA but *without* faiss-gpu to test the fallback",
+    )
+    def test_gpu_without_faiss_warns_and_falls_back(self, small_data, caplog) -> None:
+        """GPU backend + no faiss-gpu → warning + hnswlib path."""
+        q, db = small_data
+        bk = Backend("cuda")
+        with caplog.at_level("WARNING"):
+            idx, _ = approximate_knn(q, db, k=3, metric="cosine", bk=bk)
+        assert "faiss" in caplog.text.lower()
+        assert idx.shape == (q.shape[0], 3)
diff --git a/tests/unit/test_pca.py b/tests/unit/test_pca.py
new file mode 100644
index 0000000..3aae406
--- /dev/null
+++ b/tests/unit/test_pca.py
+"""Unit tests for samap.sam.pca — ARPACK vs randomized SVD with implicit centering.
+
+Randomized SVD and ARPACK find *different orthonormal bases* for the same
+leading singular subspace — so we compare subspace angles and reconstruction
+error, not raw component matrices.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+import scipy.sparse as sp
+from scipy.linalg import subspace_angles
+from sklearn.decomposition import PCA
+
+from samap.core._backend import HAS_CUPY, Backend
+from samap.sam.pca import _pca_with_sparse, randomized_svd_implicit_center
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture(scope="module")
+def low_rank_sparse():
+    """Sparse matrix with a strong low-rank signal.
+
+    Built as (scaled low-rank) + small noise, then sparsified to ~10% density.
+    The signal is boosted 100× relative to the noise so that the random
+    sparsification (which zeros 90% of entries — a *large* structured
+    perturbation) does not swamp the top singular directions. Without that
+    boost the post-sparsification spectrum collapses to an almost-flat noise
+    floor after the first 3-4 modes, making subspace comparison meaningless.
+    """
+    rng = np.random.default_rng(42)
+    n, m = 500, 200
+    rank = 15
+    # Exponential singular spectrum, scaled large so it survives sparsification.
+    svals = 100.0 * np.exp(-0.3 * np.arange(rank))
+    U = rng.standard_normal((n, rank))
+    V = rng.standard_normal((m, rank))
+    # Orthonormalise so svals are the actual pre-sparsification singular values
+    U, _ = np.linalg.qr(U)
+    V, _ = np.linalg.qr(V)
+    dense = (U * svals) @ V.T + 0.01 * rng.standard_normal((n, m))
+    mask = rng.random((n, m)) > 0.10
+    dense[mask] = 0.0
+    X = sp.csr_matrix(dense.astype(np.float64))
+    # rank=15 is the true signal rank — tests should compare subspaces up to
+    # ~10-12 PCs; beyond that the noise-floor PCs have no unique "right
+    # answer" and randomized vs ARPACK will legitimately differ.
+    return X, n, m
+
+
+# Within-signal-rank PC count for subspace comparisons. Beyond this the
+# singular spectrum flattens out (noise floor from sparsification) and there
+# is no unique correct basis — randomized and ARPACK will find different but
+# equally valid directions.
+_SIGNAL_PCS = 12
+
+
+# ---------------------------------------------------------------------------
+# Core equivalence — ARPACK vs randomized
+# ---------------------------------------------------------------------------
+
+
+class TestArpackVsRandomized:
+    """Both solvers should recover the same leading singular subspace."""
+
+    def test_explained_variance_close(self, low_rank_sparse):
+        """Top-PC explained variances agree within ~1%.
+
+        ARPACK converges tightly; randomized SVD is approximate but with 4
+        power iterations and 10 oversamples it should match the top modes
+        very well.
+        """
+        X, _, _ = low_rank_sparse
+        k = 50
+        out_arp = _pca_with_sparse(X, k, svd_solver="arpack", seed=0)
+        out_rnd = _pca_with_sparse(X, k, svd_solver="randomized", seed=0)
+
+        var_arp = out_arp["variance"]
+        var_rnd = out_rnd["variance"]
+
+        # Top 10 PCs: should be very close (< 1% relative error)
+        rel_err_top = np.abs(var_arp[:10] - var_rnd[:10]) / var_arp[:10]
+        assert rel_err_top.max() < 0.01, (
+            f"Top-10 variance relative error {rel_err_top.max():.4f} > 1%"
+        )
+
+        # All k PCs: bound the aggregate. The tail PCs can drift more
+        # (randomized SVD is less accurate there) but the *sum* of variances
+        # should be within a few percent — that's what matters for PCA.
+        var_sum_arp = var_arp.sum()
+        var_sum_rnd = var_rnd.sum()
+        assert abs(var_sum_arp - var_sum_rnd) / var_sum_arp < 0.05
+
+    def test_subspace_angle_small(self, low_rank_sparse):
+        """Principal angles between the two component bases should be near zero.
+
+        We compare only the top `_SIGNAL_PCS` — within the true signal rank.
+        Beyond that the noise floor has no unique basis (any orthonormal set
+        of noise directions is equally correct) so randomized and ARPACK will
+        find different but equally valid ones. The full k=50 subspaces would
+        show a ~1 rad max angle purely from that ambiguity.
+        """
+        X, _, _ = low_rank_sparse
+        k = 50  # oversampled — we still truncate the comparison below
+        out_arp = _pca_with_sparse(X, k, svd_solver="arpack", seed=0)
+        out_rnd = _pca_with_sparse(X, k, svd_solver="randomized", seed=0)
+
+        # subspace_angles takes column vectors; components are row vectors
+        angles = subspace_angles(
+            out_arp["components"][:_SIGNAL_PCS].T,
+            out_rnd["components"][:_SIGNAL_PCS].T,
+        )
+        # With n_power=4 and a clear spectral gap, the top subspace should
+        # align to well under 0.1 rad.
+        assert angles.max() < 0.1, (
+            f"Max principal angle {angles.max():.4f} rad — subspaces disagree"
+        )
+
+    def test_reconstruction_error_similar(self, low_rank_sparse):
+        """Low-rank reconstruction ``X_pca @ components`` should be near-identical.
+
+        This is basis-invariant: the product ``U·Σ·Vᵀ`` is the same for any
+        pair of orthonormal bases spanning the same singular subspace.
+        Restricted to the signal PCs only — noise-floor PCs reconstruct
+        different (but equally valid) noise approximations.
+        """
+        X, _, _ = low_rank_sparse
+        k = 50
+        out_arp = _pca_with_sparse(X, k, svd_solver="arpack", seed=0)
+        out_rnd = _pca_with_sparse(X, k, svd_solver="randomized", seed=0)
+
+        rec_arp = out_arp["X_pca"][:, :_SIGNAL_PCS] @ out_arp["components"][:_SIGNAL_PCS]
+        rec_rnd = out_rnd["X_pca"][:, :_SIGNAL_PCS] @ out_rnd["components"][:_SIGNAL_PCS]
+
+        diff = np.linalg.norm(rec_arp - rec_rnd)
+        ref = np.linalg.norm(rec_arp)
+        assert diff / ref < 0.05, f"Reconstruction differs by {diff / ref:.2%}"
+
+    def test_output_shapes_and_dtypes(self, low_rank_sparse):
+        """Both paths return the same dict schema."""
+        X, n, m = low_rank_sparse
+        k = 20
+        for solver in ("arpack", "randomized"):
+            out = _pca_with_sparse(X, k, svd_solver=solver, seed=0)
+            assert out["X_pca"].shape == (n, k)
+            assert out["components"].shape == (k, m)
+            assert out["variance"].shape == (k,)
+            assert out["variance_ratio"].shape == (k,)
+            # Variances should be descending
+            assert (np.diff(out["variance"]) <= 1e-10).all()
+
+    def test_variance_ratio_matches_variance(self, low_rank_sparse):
+        """variance_ratio should be variance / total_var, same total for both."""
+        X, _, _ = low_rank_sparse
+        k = 10
+        out_arp = _pca_with_sparse(X, k, svd_solver="arpack", seed=0)
+        out_rnd = _pca_with_sparse(X, k, svd_solver="randomized", seed=0)
+        # Both divide by the same total (sum of column variances of X)
+        total_arp = out_arp["variance"][0] / out_arp["variance_ratio"][0]
+        total_rnd = out_rnd["variance"][0] / out_rnd["variance_ratio"][0]
+        assert abs(total_arp - total_rnd) / total_arp < 1e-6
+
+
+# ---------------------------------------------------------------------------
+# Ground truth — compare against sklearn dense PCA
+# ---------------------------------------------------------------------------
+
+
+class TestVsSklearnDense:
+    """sklearn.PCA on the densified matrix is the ground-truth reference."""
+
+    @pytest.fixture(scope="class")
+    def sklearn_ref(self, low_rank_sparse):
+        X, _, _ = low_rank_sparse
+        pca = PCA(n_components=50, svd_solver="full")
+        X_pca = pca.fit_transform(X.toarray())
+        return {
+            "X_pca": X_pca,
+            "components": pca.components_,
+            "variance": pca.explained_variance_,
+        }
+
+    @pytest.mark.parametrize("svd_solver", ["arpack", "randomized"])
+    def test_variance_vs_sklearn(self, low_rank_sparse, sklearn_ref, svd_solver):
+        """Explained variances match sklearn within 1% on the top PCs."""
+        X, _, _ = low_rank_sparse
+        out = _pca_with_sparse(X, 50, svd_solver=svd_solver, seed=0)
+        var_ref = sklearn_ref["variance"][:10]
+        var_got = out["variance"][:10]
+        rel_err = np.abs(var_ref - var_got) / var_ref
+        assert rel_err.max() < 0.01, (
+            f"{svd_solver}: top-10 variance differs from sklearn by {rel_err.max():.4f}"
+        )
+
+    @pytest.mark.parametrize("svd_solver", ["arpack", "randomized"])
+    def test_subspace_vs_sklearn(self, low_rank_sparse, sklearn_ref, svd_solver):
+        """Component subspace aligns with sklearn's (within the signal rank)."""
+        X, _, _ = low_rank_sparse
+        out = _pca_with_sparse(X, 50, svd_solver=svd_solver, seed=0)
+        angles = subspace_angles(
+            sklearn_ref["components"][:_SIGNAL_PCS].T,
+            out["components"][:_SIGNAL_PCS].T,
+        )
+        assert angles.max() < 0.1
+
+
+# ---------------------------------------------------------------------------
+# API & edge cases
+# ---------------------------------------------------------------------------
+
+
+class TestAPIAndEdgeCases:
+    def test_arpack_is_default(self, low_rank_sparse):
+        """Omitting svd_solver picks ARPACK — no behaviour change for callers."""
+        X, _, _ = low_rank_sparse
+        out_default = _pca_with_sparse(X, 10, seed=0)
+        out_arpack = _pca_with_sparse(X, 10, svd_solver="arpack", seed=0)
+        np.testing.assert_array_equal(out_default["variance"], out_arpack["variance"])
+
+    def test_randomized_rejects_mu_axis_1(self, low_rank_sparse):
+        """Row-centering (mu_axis=1) is not supported on the randomized path."""
+        X, _, _ = low_rank_sparse
+        with pytest.raises(ValueError, match="mu_axis=0"):
+            _pca_with_sparse(X, 10, svd_solver="randomized", mu_axis=1)
+
+    def test_randomized_accepts_precomputed_mu(self, low_rank_sparse):
+        """Passing an explicit mean should give the same result as auto-computing it."""
+        X, _, _ = low_rank_sparse
+        mu = np.asarray(X.mean(axis=0))
+        out_auto = randomized_svd_implicit_center(X, 10, mu=None, seed=0)
+        out_manual = randomized_svd_implicit_center(X, 10, mu=mu, seed=0)
+        np.testing.assert_allclose(out_auto["variance"], out_manual["variance"], rtol=1e-10)
+
+    def test_randomized_seeded_determinism(self, low_rank_sparse):
+        """Same seed → identical output; different seed → different sketch."""
+        X, _, _ = low_rank_sparse
+        out1 = randomized_svd_implicit_center(X, 10, seed=7)
+        out2 = randomized_svd_implicit_center(X, 10, seed=7)
+        out3 = randomized_svd_implicit_center(X, 10, seed=8)
+        np.testing.assert_array_equal(out1["X_pca"], out2["X_pca"])
+        # With well-separated spectrum the *variances* converge regardless of
+        # seed, but X_pca bytes will differ through the random sketch.
+        assert not np.array_equal(out1["X_pca"], out3["X_pca"])
+
+    def test_n_power_improves_tail_accuracy(self, low_rank_sparse):
+        """More power iterations → tighter match to ARPACK on the tail PCs.
+
+        This is a sanity check on the power-iteration plumbing. With
+        n_power=0 the tail diverges; with n_power=4 it should tighten up.
+        """
+        X, _, _ = low_rank_sparse
+        out_ref = _pca_with_sparse(X, 30, svd_solver="arpack", seed=0)
+
+        out_p0 = randomized_svd_implicit_center(X, 30, n_power=0, seed=0)
+        out_p4 = randomized_svd_implicit_center(X, 30, n_power=4, seed=0)
+
+        # Compare tail (PCs 20-30) explained-variance error
+        err_p0 = np.abs(out_p0["variance"][20:] - out_ref["variance"][20:]).sum()
+        err_p4 = np.abs(out_p4["variance"][20:] - out_ref["variance"][20:]).sum()
+        # Power iterations must help, or the plumbing is broken
+        assert err_p4 < err_p0
+
+
+# ---------------------------------------------------------------------------
+# GPU
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.skipif(not HAS_CUPY, reason="cupy not installed")
+class TestGPU:
+    """Same comparisons with a CUDA backend."""
+
+    def test_gpu_matches_cpu_randomized(self, low_rank_sparse):
+        """GPU and CPU randomized SVD recover the same subspace.
+
+        We don't compare bytes (cupy and numpy RNG streams differ, and float
+        accumulation order differs on GPU) — we compare subspace angles.
+        """
+        X, _, _ = low_rank_sparse
+        bk_gpu = Backend("cuda")
+        X_gpu = bk_gpu.to_device(X)
+
+        out_cpu = randomized_svd_implicit_center(X, 30, seed=0, bk=Backend("cpu"))
+        out_gpu = randomized_svd_implicit_center(X_gpu, 30, seed=0, bk=bk_gpu)
+
+        angles = subspace_angles(
+            out_cpu["components"][:_SIGNAL_PCS].T,
+            out_gpu["components"][:_SIGNAL_PCS].T,
+        )
+        assert angles.max() < 0.1
+
+        # Variances should agree to within numerical tolerance — both are
+        # approximating the same singular values.
+        np.testing.assert_allclose(out_cpu["variance"][:10], out_gpu["variance"][:10], rtol=0.02)
+
+    def test_gpu_vs_arpack(self, low_rank_sparse):
+        """GPU randomized SVD matches CPU ARPACK ground truth."""
+        X, _, _ = low_rank_sparse
+        bk_gpu = Backend("cuda")
+        X_gpu = bk_gpu.to_device(X)
+
+        out_ref = _pca_with_sparse(X, 30, svd_solver="arpack", seed=0)
+        out_gpu = randomized_svd_implicit_center(X_gpu, 30, seed=0, bk=bk_gpu)
+
+        rel_err = np.abs(out_ref["variance"][:10] - out_gpu["variance"][:10])
+        rel_err /= out_ref["variance"][:10]
+        assert rel_err.max() < 0.02
diff --git a/tests/unit/test_projection.py b/tests/unit/test_projection.py
new file mode 100644
index 0000000..4d89cc4
--- /dev/null
+++ b/tests/unit/test_projection.py
+"""Equivalence tests for the precomposed feature-translation path.
+
+These pin the new :func:`_mapping_window_fast` / :func:`_compute_sigma` /
+:func:`_projection_precompute` against a direct reimplementation of the
+legacy materialise-Xtr path. Both must agree to ~1e-6 rtol — the rewrite
+is an algebraic reshuffling, not an approximation.
+
+If these tests start failing after the backward-compat ``_mapping_window``
+wrapper is removed, they should still pass: they exercise
+:func:`_mapping_window_fast` directly.
+""" + +from __future__ import annotations + +import numpy as np +import pytest +import scipy.sparse as spp +from sklearn.preprocessing import StandardScaler + +from samap.core._backend import Backend +from samap.core.projection import ( + _compute_sigma, + _mapping_window_fast, + _projection_precompute, +) + +# --------------------------------------------------------------------------- # +# Fixtures — small synthetic 2- and 3-species inputs # +# --------------------------------------------------------------------------- # + + +class _MockAdata: + """Minimal adata stand-in: just the fields projection.py reads.""" + + def __init__( + self, + X: spp.csr_matrix, + var_names: np.ndarray, + weights: np.ndarray, + PCs: np.ndarray, + ) -> None: + self.X = X + self.var_names = var_names + import pandas as pd + + self.var = pd.DataFrame({"weights": weights}, index=var_names) + self.varm = {"PCs_SAMap": PCs} + + def __getitem__(self, key): # adata[:, gene_names] + _, cols = key + # map gene names → column indices (projection.py slices by var_names match) + name_to_ix = {n: i for i, n in enumerate(self.var_names)} + ix = np.array([name_to_ix[c] for c in cols]) + return _MockAdata( + X=self.X[:, ix], + var_names=self.var_names[ix], + weights=self.var["weights"].values[ix], + PCs=self.varm["PCs_SAMap"][ix], + ) + + +class _MockSAM: + def __init__(self, adata: _MockAdata) -> None: + self.adata = adata + + +def _make_species( + sid: str, + n_cells: int, + n_genes: int, + npcs: int, + rng: np.random.Generator, +) -> tuple[_MockSAM, np.ndarray]: + """Build one mock species with random sparse counts, weights, and PC loadings.""" + var_names = np.array([f"{sid}_gene{i:03d}" for i in range(n_genes)]) + X = spp.random(n_cells, n_genes, density=0.35, format="csr", random_state=rng) + X.data *= 10 + X = X.astype(np.float64) + weights = rng.uniform(0.1, 1.0, n_genes) + PCs = rng.standard_normal((n_genes, npcs)).astype(np.float64) + sam = _MockSAM(_MockAdata(X, var_names, weights, PCs)) + 
return sam, var_names + + +def _make_gnnm( + gns_list: list[np.ndarray], + rng: np.random.Generator, + density: float = 0.2, +) -> tuple[spp.csr_matrix, np.ndarray]: + """Random block-off-diagonal homology graph (no within-species edges).""" + gns = np.concatenate(gns_list) + g_total = gns.size + # build per-species-pair off-diagonal blocks + sizes = [g.size for g in gns_list] + offsets = np.cumsum([0, *sizes]) + A = spp.lil_matrix((g_total, g_total), dtype=np.float64) + for i in range(len(sizes)): + for j in range(len(sizes)): + if i == j: + continue + r0, r1 = offsets[i], offsets[i + 1] + c0, c1 = offsets[j], offsets[j + 1] + block = spp.random(sizes[i], sizes[j], density=density, format="csr", random_state=rng) + block.data = np.abs(block.data) + 0.01 # strictly positive edges + A[r0:r1, c0:c1] = block + return A.tocsr(), gns + + +@pytest.fixture +def bk() -> Backend: + return Backend("cpu") + + +@pytest.fixture +def synth2(bk): # 2 species + rng = np.random.default_rng(42) + sam_a, gns_a = _make_species("aa", n_cells=60, n_genes=25, npcs=8, rng=rng) + sam_b, gns_b = _make_species("bb", n_cells=45, n_genes=20, npcs=6, rng=rng) + sams = {"aa": sam_a, "bb": sam_b} + gnnm, gns = _make_gnnm([gns_a, gns_b], rng) + return sams, gnnm, gns + + +@pytest.fixture +def synth3(bk): # 3 species + rng = np.random.default_rng(123) + sam_a, gns_a = _make_species("aa", n_cells=40, n_genes=18, npcs=5, rng=rng) + sam_b, gns_b = _make_species("bb", n_cells=35, n_genes=15, npcs=5, rng=rng) + sam_c, gns_c = _make_species("cc", n_cells=30, n_genes=12, npcs=4, rng=rng) + sams = {"aa": sam_a, "bb": sam_b, "cc": sam_c} + gnnm, gns = _make_gnnm([gns_a, gns_b, gns_c], rng) + return sams, gnnm, gns + + +# --------------------------------------------------------------------------- # +# Reference implementation — legacy materialise-Xtr path # +# --------------------------------------------------------------------------- # + + +def _legacy_wpca(sams, gnnm, gns, pairwise: bool): + """Direct 
transcription of the pre-refactor _mapping_window wpca logic. + + Kept here as a test oracle — the production path no longer materialises Xtr. + """ + from samap.core.homology import _tanh_scale + from samap.utils import q as _q + + std = StandardScaler(with_mean=False) + + gnnm_corr = gnnm.copy() + gnnm_corr.data[:] = _tanh_scale(gnnm_corr.data) + + gs, adatas, Ws, ss = {}, {}, {}, {} + species_indexer, genes_indexer = [], [] + for sid in sams: + gs[sid] = gns[np.isin(gns, _q(sams[sid].adata.var_names))] + adatas[sid] = sams[sid].adata[:, gs[sid]] + Ws[sid] = adatas[sid].var["weights"].values + ss[sid] = std.fit_transform(adatas[sid].X).multiply(Ws[sid][None, :]).tocsr() + species_indexer.append(np.arange(ss[sid].shape[0])) + genes_indexer.append(np.arange(gs[sid].size)) + for i in range(1, len(species_indexer)): + species_indexer[i] += species_indexer[i - 1].max() + 1 + genes_indexer[i] += genes_indexer[i - 1].max() + 1 + + su = np.asarray(gnnm_corr.sum(0)) + su[su == 0] = 1 + gnnm_corr = gnnm_corr.multiply(1 / su).tocsr() + + X = spp.block_diag(list(ss.values())).tocsr() + W = np.concatenate(list(Ws.values())).flatten() + + if pairwise: + Xtr_rows = [] + for i in range(len(sams)): + xtr = [] + for j in range(len(sams)): + if i != j: + gsub = gnnm_corr[genes_indexer[i]][:, genes_indexer[j]] + su = np.asarray(gsub.sum(0)) + su[su == 0] = 1 + gsub = gsub.multiply(1 / su).tocsr() + x = X[species_indexer[i]][:, genes_indexer[i]].dot(gsub) + xtr.append(std.fit_transform(x).multiply(W[genes_indexer[j]][None, :])) + else: + xtr.append(spp.csr_matrix((species_indexer[i].size, genes_indexer[i].size))) + Xtr_rows.append(spp.hstack(xtr)) + Xtr = spp.vstack(Xtr_rows) + else: + Xtr_rows = [] + for i in range(len(sams)): + x = X[species_indexer[i]].dot(gnnm_corr) + Xtr_rows.append(std.fit_transform(x).multiply(W[None, :])) + Xtr = spp.vstack(Xtr_rows) + Xc = (X + Xtr).tocsr() + + mus = [np.asarray(Xc[species_indexer[i]].mean(0)).flatten() for i in range(len(sams))] + + import 
scipy as sp_full + + C = sp_full.linalg.block_diag(*[adatas[sid].varm["PCs_SAMap"] for sid in sams]) + M = np.vstack(mus).dot(C) + it = 0 + PCAs = [] + for sid in sams: + PCAs.append(Xc[:, it : it + gs[sid].size].dot(adatas[sid].varm["PCs_SAMap"])) + it += gs[sid].size + wpca = np.hstack(PCAs) + for i in range(len(sams)): + wpca[species_indexer[i]] -= M[i] + + return wpca, gnnm_corr + + +# --------------------------------------------------------------------------- # +# Tests # +# --------------------------------------------------------------------------- # + + +class TestComputeSigma: + """Sigma quadratic-form must match sklearn's StandardScaler exactly.""" + + @pytest.mark.parametrize("seed", [0, 1, 42, 999]) + def test_matches_sklearn(self, bk, seed): + rng = np.random.default_rng(seed) + n, g1, g2 = 50, 30, 20 + X = spp.random(n, g1, density=0.3, format="csr", random_state=rng).astype(np.float64) + G = spp.random(g1, g2, density=0.25, format="csr", random_state=rng).astype(np.float64) + + truth = StandardScaler(with_mean=False).fit(X @ G).scale_ + + XtX = (X.T @ X).tocsr() + mu = np.asarray(X.mean(0)).flatten() + sigma = _compute_sigma(XtX, mu, G, n, bk) + + np.testing.assert_allclose(sigma, truth, rtol=1e-12, atol=1e-14) + + def test_zero_variance_columns_map_to_one(self, bk): + """StandardScaler replaces zero-variance scale with 1.0; so must we.""" + n, g1, g2 = 20, 10, 5 + X = spp.random(n, g1, density=0.3, format="csr", random_state=0).astype(np.float64) + # G with an all-zero column → zero-variance output column + G = spp.random(g1, g2, density=0.3, format="csr", random_state=1).astype(np.float64).tolil() + G[:, 2] = 0 + G = G.tocsr() + + truth = StandardScaler(with_mean=False).fit(X @ G).scale_ + + XtX = (X.T @ X).tocsr() + mu = np.asarray(X.mean(0)).flatten() + sigma = _compute_sigma(XtX, mu, G, n, bk) + + assert sigma[2] == 1.0 + np.testing.assert_allclose(sigma, truth, rtol=1e-12, atol=1e-14) + + +class TestWPCAEquivalence: + """The full wpca output 
must match the legacy materialise-Xtr path.""" + + def test_2species_pairwise(self, synth2, bk): + sams, gnnm, gns = synth2 + pre = _projection_precompute(sams, gns, bk) + out = _mapping_window_fast(gnnm, pre, K=5, pairwise=True) + wpca_old, gnnm_corr_old = _legacy_wpca(sams, gnnm, gns, pairwise=True) + + np.testing.assert_allclose(out["wPCA"], wpca_old, rtol=1e-6, atol=1e-10) + np.testing.assert_allclose( + out["gnnm_corr"].toarray(), gnnm_corr_old.toarray(), rtol=1e-12, atol=1e-14 + ) + + def test_2species_all_to_all(self, synth2, bk): + sams, gnnm, gns = synth2 + pre = _projection_precompute(sams, gns, bk) + out = _mapping_window_fast(gnnm, pre, K=5, pairwise=False) + wpca_old, _ = _legacy_wpca(sams, gnnm, gns, pairwise=False) + + np.testing.assert_allclose(out["wPCA"], wpca_old, rtol=1e-6, atol=1e-10) + + def test_3species_pairwise(self, synth3, bk): + sams, gnnm, gns = synth3 + pre = _projection_precompute(sams, gns, bk) + out = _mapping_window_fast(gnnm, pre, K=5, pairwise=True) + wpca_old, _ = _legacy_wpca(sams, gnnm, gns, pairwise=True) + + np.testing.assert_allclose(out["wPCA"], wpca_old, rtol=1e-6, atol=1e-10) + + def test_3species_all_to_all(self, synth3, bk): + """3+ species: pairwise vs all-to-all differ due to normalisation scope. + + With 2 species the global and per-pair column-normalisations of the + homology graph coincide; with 3+ they don't (each column gets + contributions from multiple species). This test guards the all-to-all + branch specifically. 
+ """ + sams, gnnm, gns = synth3 + pre = _projection_precompute(sams, gns, bk) + out = _mapping_window_fast(gnnm, pre, K=5, pairwise=False) + wpca_old, _ = _legacy_wpca(sams, gnnm, gns, pairwise=False) + + np.testing.assert_allclose(out["wPCA"], wpca_old, rtol=1e-6, atol=1e-10) + + def test_precompute_is_iteration_invariant(self, synth2, bk): + """Precompute dict shouldn't depend on gnnm — reuse across iterations.""" + sams, gnnm, gns = synth2 + pre = _projection_precompute(sams, gns, bk) + + # Two different homology graphs, same precompute + rng = np.random.default_rng(7) + gnnm2 = gnnm.copy() + gnnm2.data = rng.uniform(0.01, 1.0, gnnm2.data.size) + + out1 = _mapping_window_fast(gnnm, pre, K=5, pairwise=True) + out2 = _mapping_window_fast(gnnm2, pre, K=5, pairwise=True) + + # Different inputs → different outputs (sanity: precompute isn't stale-caching) + assert not np.allclose(out1["wPCA"], out2["wPCA"]) + + # But both match their respective legacy oracles + wpca_old1, _ = _legacy_wpca(sams, gnnm, gns, pairwise=True) + wpca_old2, _ = _legacy_wpca(sams, gnnm2, gns, pairwise=True) + np.testing.assert_allclose(out1["wPCA"], wpca_old1, rtol=1e-6, atol=1e-10) + np.testing.assert_allclose(out2["wPCA"], wpca_old2, rtol=1e-6, atol=1e-10) + + +class TestBackwardCompatWrapper: + """The old _mapping_window signature still works (via internal precompute).""" + + def test_wrapper_equivalent_to_fast_path(self, synth2, bk): + from samap.core.projection import _mapping_window + + sams, gnnm, gns = synth2 + out_wrapper = _mapping_window(sams, gnnm, gns, K=5, pairwise=True) + + pre = _projection_precompute(sams, gns, bk) + out_fast = _mapping_window_fast(gnnm, pre, K=5, pairwise=True) + + np.testing.assert_allclose(out_wrapper["wPCA"], out_fast["wPCA"], rtol=1e-12) + # knn structure should be identical since wpca is identical + np.testing.assert_allclose( + out_wrapper["knn"].toarray(), out_fast["knn"].toarray(), rtol=1e-12 + ) diff --git a/tests/unit/test_rsc_compat.py 
b/tests/unit/test_rsc_compat.py new file mode 100644 index 0000000..3d4184b --- /dev/null +++ b/tests/unit/test_rsc_compat.py @@ -0,0 +1,79 @@ +"""Tests for the rapids-singlecell optional-dispatch layer. + +Only the CPU fallback path is testable here (rsc not installed in CI). +The GPU path is a thin passthrough to rsc — trusted upstream. +""" + +from __future__ import annotations + +import numpy as np +import scipy.sparse as sp +from anndata import AnnData + +from samap._rsc_compat import HAS_RSC, leiden, umap +from samap.core._backend import Backend + + +def _tiny_adata_with_neighbors() -> AnnData: + """Minimal AnnData with pre-set neighbors — enough for sc.tl.umap/leiden.""" + rng = np.random.default_rng(0) + n = 40 + # 2D blobs so UMAP/Leiden have something to find + X = np.vstack( + [ + rng.normal([0, 0], 0.1, (n // 2, 2)), + rng.normal([5, 5], 0.1, (n // 2, 2)), + ] + ).astype(np.float32) + adata = AnnData(X) + # Fake a connectivity graph (scanpy's neighbors output) + from sklearn.neighbors import kneighbors_graph + + conn = kneighbors_graph(X, 5, mode="connectivity", include_self=False) + adata.obsp["connectivities"] = sp.csr_matrix(conn) + adata.obsp["distances"] = kneighbors_graph(X, 5, mode="distance") + adata.uns["neighbors"] = {"params": {"n_neighbors": 5, "method": "umap"}} + return adata + + +class TestRscCompatModule: + def test_has_rsc_false_in_test_env(self): + """rsc is not installed in the CI env — dispatch takes CPU path.""" + assert HAS_RSC is False + + def test_imports_cleanly_without_rsc(self): + """Module must import without rsc present (optional dependency).""" + # Re-import to prove the import itself doesn't require rsc + import importlib + + import samap._rsc_compat as mod + + importlib.reload(mod) + assert mod.HAS_RSC is False + + +class TestCPUFallback: + """With a CPU backend and no rsc, both wrappers should call scanpy.""" + + def test_umap_cpu_path(self): + adata = _tiny_adata_with_neighbors() + bk = Backend("cpu") + # Should not 
raise; writes X_umap + umap(adata, bk) + assert "X_umap" in adata.obsm + assert adata.obsm["X_umap"].shape == (40, 2) + + def test_leiden_cpu_path(self): + adata = _tiny_adata_with_neighbors() + bk = Backend("cpu") + leiden(adata, bk, resolution=0.5) + assert "leiden" in adata.obs + # Two well-separated blobs → should find ≥2 clusters + assert adata.obs["leiden"].nunique() >= 2 + + def test_leiden_respects_key_added(self): + adata = _tiny_adata_with_neighbors() + bk = Backend("cpu") + leiden(adata, bk, key_added="my_clusters", resolution=0.5) + assert "my_clusters" in adata.obs + assert "leiden" not in adata.obs # default key not used diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 2d7bec6..1e34d50 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -4,7 +4,6 @@ import numpy as np import pandas as pd -import pytest import scipy.sparse as sp from samap.utils import df_to_dict, sparse_knn, substr, to_vn, to_vo @@ -149,9 +148,6 @@ class TestToVo: def test_converts_back_from_vn(self) -> None: """Test that to_vo converts semicolon strings back to pairs.""" - # This test requires samalg to be installed - pytest.importorskip("samalg") - vn_strings = np.array(["gene1;gene2", "gene3;gene4"]) result = to_vo(vn_strings)