Commit f34e8c3

feat: add SVG-aware gene selection and spatial coherence validation

Introduce Moran's I-based spatial variability scoring to improve gene vocabulary selection and add a spatial coherence metric for validation.

Spatial Statistics Module (NEW):
- data/spatial_stats.py: lightweight Moran's I via KNN weights (numpy + scipy only, no new dependencies)
- morans_i(), morans_i_batch(), spatial_coherence_score()

SVG-aware Gene Selection:
- build_vocab.py: --svg-weight (0-1) blends expression rank with Moran's I rank; --svg-k controls KNN graph size
- Default svg_weight=0.0 preserves original behaviour
- Stats CSV now includes morans_i column

Spatial Coherence Validation:
- engine.py: computes Moran's I correlation between predicted and ground-truth expression on top-50 SVGs during validation
- train.py: logs spatial_coherence to SQLite

Tests:
- test_spatial_stats.py: 14 tests covering Moran's I (uniform, clustered, checkerboard, gradient) and coherence scoring

Docs:
- SC_BEST_PRACTICES.md: marked SVG selection and spatial coherence as implemented

1 parent 7d75160 commit f34e8c3

7 files changed

Lines changed: 515 additions & 13 deletions

README.md

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ Visualization plots and spatial expression maps will be saved to the `./results`
 - **[Pathway Mapping](docs/PATHWAY_MAPPING.md)**: Clinical interpretability, pathway bottleneck design, and MSigDB integration.
 - **[Gene Analysis](docs/GENE_ANALYSIS.md)**: Modeling strategies for mapping morphology to high-dimensional gene spaces.
 - **[Data Structure](docs/DATA_STRUCTURE.md)**: Detailed breakdown of the HEST data structure on disk, metadata conventions, and preprocessing invariants.
-- **[Single-cell Best Practices](docs/SC_BEST_PRACTICES.md)**: Gap analysis and roadmap for alignment with industry standard recommendations.
+- **[Single-cell Best Practices](docs/SC_BEST_PRACTICES.md)**: Gap analysis and roadmap for alignment with standard recommendations.
 
 ## Development
 
docs/SC_BEST_PRACTICES.md

Lines changed: 9 additions & 5 deletions
@@ -21,11 +21,13 @@ These are areas where the project already follows industry best practices:
 
 The following items are recommended for future sprints to improve model robustness and biological accuracy.
 
-### 1. SVG-aware Gene Selection (Moran's I)
+### 1. SVG-aware Gene Selection (Moran's I)
 
-**Priority: High**
+**Priority: High** **Implemented**
 **Rationale**: Currently, genes are selected based on total expression or pathway membership. However, the model's primary task is to learn spatial patterns. Selecting genes based on **Spatially Variable Gene (SVG)** metrics like Moran's I (available in Squidpy) would prioritise genes that exhibit spatial coherence over those that are merely highly expressed (such as housekeeping genes).
 
+**Usage**: `stf-build-vocab --svg-weight 0.5 --svg-k 6` enables a hybrid ranking that blends total expression with Moran's I spatial variability. See `data/spatial_stats.py` for the implementation.
+
 ### 2. Standardised Preprocessing Pipeline
 
 **Priority: Medium-High**
@@ -41,10 +43,12 @@ The following items are recommended for future sprints to improve model robustne
 **Priority: Medium**
 **Rationale**: Adding explicit QC thresholds (e.g., minimum UMI count, minimum detected genes, maximum mitochondrial fraction) to the dataset loading scripts would protect the model from training on low-quality "noise" spots.
 
-### 5. Spatial Coherence Validation Metrics
+### 5. Spatial Coherence Validation Metrics
 
-**Priority: Medium**
-**Rationale**: Aggregate metrics like MSE or PCC don't capture whether the *spatial distribution* of predictions is realistic. Adding a validation step that compares the Moran's I of predicted vs. ground-truth expression would provide a much stronger biological validation signal.
+**Priority: Medium** **Implemented**
+**Rationale**: Aggregate metrics like MSE or PCC don't capture whether the *spatial distribution* of predictions is realistic. A validation step now compares the Moran's I of predicted vs. ground-truth expression for the top-50 spatially variable genes, reporting the Pearson correlation between the two sets of Moran's I values as the **Spatial Coherence Score**.
+
+**Integration**: Computed automatically during validation in `training/engine.py` and logged to SQLite as `spatial_coherence`. See `data/spatial_stats.py:spatial_coherence_score()`.
 
 ### 6. Preprocessing Documentation
 
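
The rationale for item 5 can be illustrated with a small sketch. It is not taken from the repository: the grid, the error field, and the helper name `knn_morans_i` are assumptions made for illustration. Two predictions carry exactly the same error values, so their MSE is identical, yet one arranges the error smoothly in space and the other scatters it, and only Moran's I tells them apart.

```python
import numpy as np
from scipy.spatial import KDTree


def knn_morans_i(x, coords, k=4):
    """Moran's I with row-normalised KNN weights (illustrative helper).

    With row-normalised weights the weights sum to N, so the N / W_sum
    prefactor equals 1 and I reduces to sum(z * lag) / sum(z ** 2).
    Assumes distinct coordinates and a non-constant x.
    """
    z = x - x.mean()
    _, idx = KDTree(coords).query(coords, k=k + 1)
    lag = z[idx[:, 1:]].mean(axis=1)  # column 0 is the spot itself
    return float(z @ lag / (z @ z))


# Hypothetical 8 x 8 grid of spots with a left-to-right expression gradient.
side = 8
yy, xx = np.mgrid[0:side, 0:side]
coords = np.column_stack([xx.ravel(), yy.ravel()]).astype(float)
truth = xx.ravel().astype(float)

# One error field, arranged two ways: smooth in space, then shuffled.
smooth_error = 4.0 * np.sin(2 * np.pi * yy.ravel() / side)
shuffled_error = np.random.default_rng(0).permutation(smooth_error)

pred_smooth = truth + smooth_error
pred_shuffled = truth + shuffled_error

mse_smooth = float(np.mean((pred_smooth - truth) ** 2))
mse_shuffled = float(np.mean((pred_shuffled - truth) ** 2))
i_truth = knn_morans_i(truth, coords)
i_smooth = knn_morans_i(pred_smooth, coords)
i_shuffled = knn_morans_i(pred_shuffled, coords)

print(f"MSE:       smooth={mse_smooth:.3f}  shuffled={mse_shuffled:.3f}")
print(f"Moran's I: truth={i_truth:.3f}  smooth={i_smooth:.3f}  shuffled={i_shuffled:.3f}")
```

The two MSE values match because shuffling only permutes the error terms, while the shuffled prediction loses most of its spatial autocorrelation. This is the gap the spatial coherence metric is meant to cover.
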
Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
+"""
+Spatial statistics utilities for gene selection.
+
+Provides lightweight Moran's I computation (NumPy and SciPy only) for
+identifying spatially variable genes (SVGs) from spatial
+transcriptomics data.
+
+Moran's I measures spatial autocorrelation: whether nearby spots tend
+to have similar (positive I) or dissimilar (negative I) expression
+for a given gene. Genes with high Moran's I show distinct spatial
+patterns and are the strongest learning targets for
+SpatialTranscriptFormer.
+"""
+
+import numpy as np
+from scipy.spatial import KDTree
+from scipy.sparse import csr_matrix
+
+
+def _build_knn_weights(coords: np.ndarray, k: int = 6) -> csr_matrix:
+    """Build a row-normalised KNN spatial weight matrix.
+
+    Args:
+        coords: (N, 2) array of spatial coordinates.
+        k: Number of nearest neighbours per spot.
+
+    Returns:
+        (N, N) sparse CSR matrix where ``W[i, j] = 1/k`` if j is one
+        of the k nearest neighbours of i, else 0. Row-normalisation
+        ensures that the weight contribution is independent of local
+        spot density.
+    """
+    n = coords.shape[0]
+    tree = KDTree(coords)
+    # k+1 because the first neighbour returned is the point itself
+    _, indices = tree.query(coords, k=min(k + 1, n))
+
+    rows = []
+    cols = []
+    for i in range(n):
+        neighbours = indices[i]
+        neighbours = neighbours[neighbours != i][:k]
+        for j in neighbours:
+            rows.append(i)
+            cols.append(j)
+
+    data = np.ones(len(rows), dtype=np.float64) / k
+    W = csr_matrix((data, (rows, cols)), shape=(n, n))
+    return W
+
+
+def morans_i(x: np.ndarray, W: csr_matrix) -> float:
+    """Compute Moran's I for a single variable.
+
+    .. math::
+
+        I = \\frac{N}{W_{sum}} \\cdot
+            \\frac{\\sum_i \\sum_j w_{ij} (x_i - \\bar{x})(x_j - \\bar{x})}
+                 {\\sum_i (x_i - \\bar{x})^2}
+
+    Args:
+        x: (N,) array of values (e.g. gene expression per spot).
+        W: (N, N) sparse spatial weight matrix.
+
+    Returns:
+        Moran's I statistic. Ranges roughly from -1 (perfect
+        dispersion) through 0 (random) to +1 (perfect clustering).
+        Returns 0.0 if variance is zero (constant gene).
+    """
+    n = len(x)
+    x_mean = x.mean()
+    z = x - x_mean
+
+    denominator = np.sum(z ** 2)
+    if denominator < 1e-12:
+        return 0.0  # Constant expression → no spatial pattern
+
+    # W @ z gives the spatially-lagged deviation for each spot
+    lag = W.dot(z)
+    numerator = np.sum(z * lag)
+
+    W_sum = W.sum()
+    if W_sum < 1e-12:
+        return 0.0
+
+    I = (n / W_sum) * (numerator / denominator)
+    return float(I)
+
+
+def morans_i_batch(
+    expression: np.ndarray,
+    coords: np.ndarray,
+    k: int = 6,
+) -> np.ndarray:
+    """Compute Moran's I for all genes in an expression matrix.
+
+    Args:
+        expression: (N, G) dense expression matrix (spots × genes).
+        coords: (N, 2) spatial coordinates for each spot.
+        k: Number of nearest neighbours for the spatial weight graph.
+
+    Returns:
+        (G,) array of Moran's I scores, one per gene.
+    """
+    if expression.shape[0] < k + 1:
+        # Too few spots to build a meaningful KNN graph
+        return np.zeros(expression.shape[1], dtype=np.float64)
+
+    W = _build_knn_weights(coords, k=k)
+    n_genes = expression.shape[1]
+    scores = np.empty(n_genes, dtype=np.float64)
+
+    for g in range(n_genes):
+        scores[g] = morans_i(expression[:, g], W)
+
+    return scores
+
+
+def spatial_coherence_score(
+    predicted: np.ndarray,
+    ground_truth: np.ndarray,
+    coords: np.ndarray,
+    k: int = 6,
+    top_k_genes: int = 50,
+) -> float:
+    """Compare spatial structure of predictions vs ground truth.
+
+    Computes Moran's I for both the predicted and ground-truth
+    expression matrices, then returns the Pearson correlation between
+    the two Moran's I vectors. A score near 1.0 means the model
+    reproduces the correct spatial patterns; near 0 means random.
+
+    To keep computation fast (this runs every validation epoch), only
+    the ``top_k_genes`` with highest ground-truth spatial variability
+    are evaluated.
+
+    Args:
+        predicted: (N, G) predicted expression matrix.
+        ground_truth: (N, G) ground-truth expression matrix.
+        coords: (N, 2) spatial coordinates.
+        k: KNN neighbours for the spatial weight graph.
+        top_k_genes: Number of top-Moran's-I genes to evaluate.
+
+    Returns:
+        Pearson correlation between predicted and ground-truth
+        Moran's I vectors. Returns 0.0 if computation fails.
+    """
+    n_spots, n_genes = ground_truth.shape
+    if n_spots < k + 1 or n_genes < 2:
+        return 0.0
+
+    W = _build_knn_weights(coords, k=k)
+
+    # Compute Moran's I for ground truth
+    mi_gt = np.empty(n_genes, dtype=np.float64)
+    for g in range(n_genes):
+        mi_gt[g] = morans_i(ground_truth[:, g], W)
+
+    # Select top-K genes by ground-truth Moran's I (most spatially variable)
+    top_indices = np.argsort(mi_gt)[-top_k_genes:]
+
+    # Compute Moran's I for predictions on those genes only
+    mi_pred = np.empty(len(top_indices), dtype=np.float64)
+    mi_gt_top = mi_gt[top_indices]
+    for i, g in enumerate(top_indices):
+        mi_pred[i] = morans_i(predicted[:, g], W)
+
+    # Pearson correlation between the two Moran's I vectors
+    if np.std(mi_gt_top) < 1e-12 or np.std(mi_pred) < 1e-12:
+        return 0.0
+
+    corr = np.corrcoef(mi_gt_top, mi_pred)[0, 1]
+    return float(corr) if np.isfinite(corr) else 0.0
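
The commit message names uniform, clustered and checkerboard patterns among the test cases. The sketch below is not the test file from this commit. It is a compact, vectorised re-implementation of the same row-normalised KNN formulation, and the helper names `knn_weights` and `morans_i_sketch` are illustrative. It assumes distinct coordinates and more than `k` spots, so each spot's own index is always the first neighbour returned.

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.spatial import KDTree


def knn_weights(coords, k):
    """Row-normalised KNN weights; assumes > k distinct points."""
    n = coords.shape[0]
    _, idx = KDTree(coords).query(coords, k=k + 1)
    neighbours = idx[:, 1:]  # drop column 0, which is each point itself
    rows = np.repeat(np.arange(n), k)
    data = np.full(n * k, 1.0 / k)
    return csr_matrix((data, (rows, neighbours.ravel())), shape=(n, n))


def morans_i_sketch(x, W):
    """Moran's I for one variable; returns 0.0 for a constant input."""
    z = x - x.mean()
    denom = float(z @ z)
    if denom < 1e-12:
        return 0.0
    return float((len(x) / W.sum()) * (z @ (W @ z)) / denom)


# Hypothetical 8 x 8 grid of spots, 4 nearest neighbours per spot.
side = 8
yy, xx = np.mgrid[0:side, 0:side]
coords = np.column_stack([xx.ravel(), yy.ravel()]).astype(float)
W = knn_weights(coords, k=4)

clustered = (xx.ravel() < side // 2).astype(float)      # left half on, right half off
checkerboard = ((xx + yy).ravel() % 2).astype(float)    # alternating spots
uniform = np.ones(side * side)                          # constant expression

i_clustered = morans_i_sketch(clustered, W)
i_checkerboard = morans_i_sketch(checkerboard, W)
i_uniform = morans_i_sketch(uniform, W)

print(f"clustered:    {i_clustered:+.3f}")
print(f"checkerboard: {i_checkerboard:+.3f}")
print(f"uniform:      {i_uniform:+.3f}")
```

The clustered pattern should give a positive value, the checkerboard a negative one, and the constant input exactly 0.0 through the zero-variance guard, matching the sign conventions in the `morans_i` docstring.
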

src/spatial_transcript_former/recipes/hest/build_vocab.py

Lines changed: 90 additions & 7 deletions
@@ -8,6 +8,7 @@
 import sys
 from collections import defaultdict
 from scipy.sparse import csr_matrix
+from spatial_transcript_former.data.spatial_stats import morans_i_batch
 
 # Add src to path
 sys.path.append(os.path.abspath("src"))
@@ -42,15 +43,23 @@ def scan_h5ad_files(data_dir):
     return sample_ids
 
 
-def calculate_global_genes(data_dir, ids, num_genes=1000, target_pathways=None):
+def calculate_global_genes(
+    data_dir, ids, num_genes=1000, target_pathways=None,
+    svg_weight=0.0, svg_k=6,
+):
     st_dir = os.path.join(data_dir, "st")
     if not ids:
         print("No samples provided for calculation.")
         return [], []
 
     print(f"Scanning {len(ids)} samples in {st_dir}...")
+    if svg_weight > 0:
+        print(f"SVG mode: weight={svg_weight}, k={svg_k}")
 
     gene_totals = defaultdict(float)
+    # Moran's I accumulators (sum and count for averaging across samples)
+    gene_morans_sum = defaultdict(float)
+    gene_morans_count = defaultdict(int)
 
     for sample_id in tqdm(ids):
         h5ad_path = os.path.join(st_dir, f"{sample_id}.h5ad")
@@ -78,10 +87,27 @@ def calculate_global_genes(data_dir, ids, num_genes=1000, target_pathways=None):
                 for i, gene in enumerate(gene_names):
                     gene_totals[gene] += float(sums[i])
 
+                # --- SVG: compute Moran's I per gene for this sample ---
+                if svg_weight > 0 and "obsm" in f and "spatial" in f["obsm"]:
+                    coords = f["obsm"]["spatial"][:]
+                    # Densify the expression matrix for Moran's I
+                    if isinstance(mat, csr_matrix):
+                        dense_mat = mat.toarray()
+                    else:
+                        dense_mat = np.asarray(mat)
+
+                    mi_scores = morans_i_batch(dense_mat, coords, k=svg_k)
+
+                    for i, gene in enumerate(gene_names):
+                        gene_morans_sum[gene] += mi_scores[i]
+                        gene_morans_count[gene] += 1
+
         except Exception as e:
             print(f"Error processing {sample_id}: {e}")
 
     print(f"Aggregated counts for {len(gene_totals)} unique genes.")
+    if svg_weight > 0:
+        print(f"Computed Moran's I for {len(gene_morans_sum)} genes.")
 
     prioritized_genes = set()
     if target_pathways:
@@ -108,18 +134,61 @@ def calculate_global_genes(data_dir, ids, num_genes=1000, target_pathways=None):
 
         print(f"Found {len(prioritized_genes)} valid target pathway genes.")
 
-    # Sort all by total expression
-    sorted_all = sorted(gene_totals.items(), key=lambda x: x[1], reverse=True)
+    # --- Ranking: expression-only or hybrid ---
+    all_genes = list(gene_totals.keys())
+
+    if svg_weight > 0 and gene_morans_sum:
+        # Compute average Moran's I per gene
+        gene_morans_avg = {
+            g: gene_morans_sum[g] / gene_morans_count[g]
+            for g in all_genes
+            if gene_morans_count.get(g, 0) > 0
+        }
+
+        # Rank by expression (lower rank = higher expression)
+        expr_sorted = sorted(all_genes, key=lambda g: gene_totals[g], reverse=True)
+        expr_rank = {g: r for r, g in enumerate(expr_sorted)}
+
+        # Rank by Moran's I (lower rank = higher spatial variability)
+        mi_sorted = sorted(
+            all_genes, key=lambda g: gene_morans_avg.get(g, 0.0), reverse=True
+        )
+        mi_rank = {g: r for r, g in enumerate(mi_sorted)}
+
+        # Hybrid score: weighted sum of ranks (lower = better)
+        alpha = svg_weight
+        hybrid_score = {
+            g: (1 - alpha) * expr_rank[g] + alpha * mi_rank[g]
+            for g in all_genes
+        }
+        sorted_all_genes = sorted(all_genes, key=lambda g: hybrid_score[g])
+
+        # Build stats list with Moran's I column
+        sorted_all = [
+            (g, gene_totals[g], gene_morans_avg.get(g, 0.0))
+            for g in sorted_all_genes
+        ]
+        print(
+            f"Hybrid ranking: expression weight={(1 - alpha):.1f}, "
+            f"SVG weight={alpha:.1f}"
+        )
+    else:
+        # Expression-only ranking (original behaviour)
+        sorted_all = sorted(gene_totals.items(), key=lambda x: x[1], reverse=True)
+        sorted_all_genes = [g for g, _ in sorted_all]
+        # Pad stats tuples with 0.0 Moran's I for consistent CSV format
+        sorted_all = [(g, c, 0.0) for g, c in sorted_all]
 
     top_genes = list(prioritized_genes)
-    for g, _ in sorted_all:
+    for g in sorted_all_genes:
         if len(top_genes) >= num_genes:
             break
         if g not in prioritized_genes:
             top_genes.append(g)
 
     print(
-        f"Final set: {len(prioritized_genes)} pathway genes + {len(top_genes) - len(prioritized_genes)} global genes"
+        f"Final set: {len(prioritized_genes)} pathway genes + "
+        f"{len(top_genes) - len(prioritized_genes)} global genes"
     )
 
     return top_genes, sorted_all
@@ -147,6 +216,19 @@ def main():
         default=None,
         help="List of MSigDB pathway names to explicitly prioritize (e.g., HALLMARK_P53_PATHWAY)",
     )
+    parser.add_argument(
+        "--svg-weight",
+        type=float,
+        default=0.0,
+        help="Weight for spatial variability (Moran's I) in gene ranking. "
+        "0.0=expression-only (default), 1.0=SVG-only, 0.5=balanced.",
+    )
+    parser.add_argument(
+        "--svg-k",
+        type=int,
+        default=6,
+        help="Number of KNN neighbours for spatial weight matrix (default: 6).",
+    )
 
     args = parser.parse_args()
 
@@ -160,14 +242,15 @@ def main():
         sys.exit(1)
 
     top_genes, all_stats = calculate_global_genes(
-        args.data_dir, ids, args.num_genes, target_pathways=args.pathways
+        args.data_dir, ids, args.num_genes, target_pathways=args.pathways,
+        svg_weight=args.svg_weight, svg_k=args.svg_k,
     )
 
     print(f"Saving top {len(top_genes)} genes to {output_path}")
     with open(output_path, "w") as f:
         json.dump(top_genes, f, indent=4)
 
-    stats_df = pd.DataFrame(all_stats, columns=["gene", "total_counts"])
+    stats_df = pd.DataFrame(all_stats, columns=["gene", "total_counts", "morans_i"])
     stats_df.to_csv(output_path.replace(".json", "_stats.csv"), index=False)
     print("Saved stats to CSV.")
 
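
The hybrid ranking in `calculate_global_genes` can be tried in isolation. The sketch below mirrors the rank-blending step from the diff as a standalone function; the function name `hybrid_rank`, the gene totals, and the Moran's I values are made up for illustration and are not results from any dataset.

```python
def hybrid_rank(totals, morans, alpha):
    """Order genes by a weighted sum of expression rank and Moran's I rank.

    Rank 0 is best in both lists, so a lower blended score is better.
    Genes missing from `morans` default to 0.0, as in the diff.
    """
    genes = list(totals)
    expr_sorted = sorted(genes, key=lambda g: totals[g], reverse=True)
    mi_sorted = sorted(genes, key=lambda g: morans.get(g, 0.0), reverse=True)
    expr_rank = {g: r for r, g in enumerate(expr_sorted)}
    mi_rank = {g: r for r, g in enumerate(mi_sorted)}
    score = {g: (1 - alpha) * expr_rank[g] + alpha * mi_rank[g] for g in genes}
    return sorted(genes, key=lambda g: score[g])


# Illustrative values only: two highly expressed but spatially flat genes,
# and two lower-expressed genes with strong spatial structure.
totals = {"ACTB": 9000.0, "GAPDH": 8000.0, "KRT5": 1200.0, "CD3E": 300.0}
morans = {"ACTB": 0.02, "GAPDH": 0.05, "KRT5": 0.60, "CD3E": 0.45}

expression_only = hybrid_rank(totals, morans, alpha=0.0)
balanced = hybrid_rank(totals, morans, alpha=0.5)
svg_only = hybrid_rank(totals, morans, alpha=1.0)

print("alpha=0.0:", expression_only)
print("alpha=0.5:", balanced)
print("alpha=1.0:", svg_only)
```

Because the blend operates on ranks rather than raw values, the very different scales of total counts and Moran's I do not need to be normalised against each other. One consequence is that ties in the blended score are possible, in which case Python's stable sort keeps the original gene order.
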
src/spatial_transcript_former/train.py

Lines changed: 2 additions & 0 deletions
@@ -184,6 +184,8 @@ def main():
                 epoch_row["val_pcc"] = round(val_metrics["val_pcc"], 4)
             if val_metrics.get("pred_variance") is not None:
                 epoch_row["pred_variance"] = round(val_metrics["pred_variance"], 6)
+            if val_metrics.get("spatial_coherence") is not None:
+                epoch_row["spatial_coherence"] = round(val_metrics["spatial_coherence"], 4)
             if val_metrics.get("attn_correlation") is not None:
                 epoch_row["attn_correlation"] = round(val_metrics["attn_correlation"], 4)
 