Skip to content

Commit a95e2fd

Browse files
committed
add gene2path and initial filtering
1 parent 8971ad1 commit a95e2fd

3 files changed

Lines changed: 478 additions & 12 deletions

File tree

enrichment_auc/gene2path.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
"""
2+
Transform Gene-level Data to Pathway-level Scores Using Single-Sample Methods
3+
4+
This module provides a comprehensive function similar to FUNCellA's gene2path.R
5+
for reducing gene-level data to pathway-level activity scores using various
6+
single-sample pathway enrichment methods with filtering capabilities.
7+
"""
8+
9+
import numpy as np
10+
import pandas as pd
11+
from typing import Dict, List, Optional, Union, Literal
12+
import warnings
13+
14+
# Import the package's existing single-sample scoring metrics
15+
from enrichment_auc.metrics.mean import MEAN
16+
from enrichment_auc.metrics.bina import BINA
17+
from enrichment_auc.metrics.cerno import AUC
18+
from enrichment_auc.metrics.aucell import AUCELL
19+
from enrichment_auc.metrics.jasmine import JASMINE
20+
from enrichment_auc.metrics.z import Z as ZSCORE
21+
from enrichment_auc.metrics.gsea import SSGSEA
22+
23+
# Import preprocessing functions
24+
from enrichment_auc.preprocess.filter import (
25+
filter as variance_filter,
26+
filter_coverage,
27+
filter_size,
28+
)
29+
30+
31+
def gene2path(
    data: Union[np.ndarray, pd.DataFrame],
    genesets: Dict[str, List[str]],
    genes: Optional[List[str]] = None,
    method: Literal[
        "CERNO", "MEAN", "BINA", "AUCELL", "JASMINE", "ZSCORE", "SSGSEA", "AUC"
    ] = "CERNO",
    filt_cov: float = 0,
    filt_min: int = 15,
    filt_max: int = 500,
    aucell_threshold: float = 0.05,
    variance_filter_threshold: Optional[float] = None,
) -> pd.DataFrame:
    """
    Transform gene-level data to pathway-level scores using single-sample methods.

    This function reduces gene-level data to pathway-level activity scores using a variety of
    single-sample pathway enrichment methods. It supports filtering pathways based on coverage
    and size, similar to FUNCellA's gene2path function.

    Args:
        data: Gene expression matrix with genes as rows and samples as columns.
            Must be two-dimensional.
        genesets: Dictionary mapping pathway names to lists of gene identifiers
        genes: List of gene names (if None, uses the DataFrame index, or generates
            "Gene_<i>" names for array input)
        method: Single-sample enrichment method to use
        filt_cov: Minimum fraction of pathway genes that must be present in data (0-1);
            coverage filtering is skipped when this is 0
        filt_min: Minimum number of genes a pathway must have
        filt_max: Maximum number of genes a pathway can have
        aucell_threshold: Threshold parameter for AUCELL method (fraction of top genes to consider)
        variance_filter_threshold: If provided, filter genes by variance (keep top fraction)

    Returns:
        DataFrame with pathways as rows and samples as columns containing pathway activity
        scores. An empty DataFrame is returned (with a warning) when no pathway survives
        filtering.

    Raises:
        ValueError: If data or genesets are missing/empty, data is not two-dimensional,
            the number of gene names does not match the number of rows, the method is
            unknown, or the variance filter returns an unexpected type.

    Methods:
        - CERNO: Non-parametric method based on gene expression ranks using Mann-Whitney U statistic
        - AUC: Accepted as an alias of CERNO (both are scored with the same AUC function)
        - MEAN: Simple mean expression of pathway genes per sample
        - BINA: Binary scoring based on proportion of expressed genes with logit transformation
        - AUCELL: Area Under the Curve method for gene set enrichment
        - JASMINE: Dropout-aware method for single-cell data with effect size adjustment
        - ZSCORE: Z-score based method using Stouffer integration
        - SSGSEA: Single-sample Gene Set Enrichment Analysis (requires R)

    Example:
        >>> import pandas as pd
        >>> import numpy as np
        >>> # Create example data
        >>> data = pd.DataFrame(np.random.randn(1000, 50))  # 1000 genes, 50 samples
        >>> data.index = [f"Gene_{i}" for i in range(1000)]
        >>> genesets = {
        ...     "Pathway1": [f"Gene_{i}" for i in range(0, 20)],
        ...     "Pathway2": [f"Gene_{i}" for i in range(10, 30)]
        ... }
        >>> scores = gene2path(data, genesets, method="CERNO")
    """

    # Input validation
    if data is None or len(data) == 0:
        raise ValueError("No data provided")

    if genesets is None or len(genesets) == 0:
        raise ValueError("No pathway list provided")

    # Convert to numpy array if needed and get gene names
    if isinstance(data, pd.DataFrame):
        if genes is None:
            genes = list(data.index)
        data_array = data.values
        sample_names = list(data.columns)
    else:
        data_array = np.array(data)
        # Reject 1-D / N-D input up front; otherwise shape[1] below fails
        # with an opaque IndexError instead of a meaningful message.
        if data_array.ndim != 2:
            raise ValueError(
                "Data must be a 2-dimensional matrix (genes x samples), "
                f"got {data_array.ndim} dimension(s)"
            )
        if genes is None:
            genes = [f"Gene_{i}" for i in range(data_array.shape[0])]
        sample_names = [f"Sample_{i}" for i in range(data_array.shape[1])]

    if len(genes) != data_array.shape[0]:
        raise ValueError("Number of genes must match number of rows in data")

    # Validate method. "AUC" is listed in the signature's Literal type, so it
    # must be accepted here as well; it shares the AUC scorer with "CERNO".
    valid_methods = [
        "CERNO",
        "MEAN",
        "BINA",
        "AUCELL",
        "JASMINE",
        "ZSCORE",
        "SSGSEA",
        "AUC",
    ]
    if method not in valid_methods:
        raise ValueError(f"Method must be one of {valid_methods}")

    print(f"Starting gene2path transformation using {method} method")
    print(f"Input data: {data_array.shape[0]} genes x {data_array.shape[1]} samples")
    print(f"Input pathways: {len(genesets)}")

    # Apply variance filtering if requested
    if variance_filter_threshold is not None:
        print(f"Applying variance filtering (keep top {variance_filter_threshold:.2%})")
        if isinstance(data, pd.DataFrame):
            filtered_data = variance_filter(data, leave_best=variance_filter_threshold)
            data_array = filtered_data.values
            genes = list(filtered_data.index)
        else:
            # Use enhanced filter function for numpy arrays
            filtered_result = variance_filter(
                data_array, leave_best=variance_filter_threshold, genes=genes
            )
            if isinstance(filtered_result, tuple):
                data_array, genes = filtered_result
            else:
                raise ValueError("Unexpected return type from variance filter")
        print(f"After variance filtering: {len(genes)} genes")

    # Ensure genes is not None for the rest of the function (the variance
    # filter is the only step above that could hand back a missing gene list)
    if genes is None:
        raise ValueError("Gene names are required for filtering operations")

    # Filter pathways by size (filt_min, filt_max)
    if filt_min > 0 or filt_max < float("inf"):
        genesets = filter_size(genesets, min_size=filt_min, max_size=filt_max)

    # Filter pathways by coverage (filt_cov)
    if filt_cov > 0:
        genesets = filter_coverage(genesets, genes, min_coverage=filt_cov)

    print(f"Final pathways for analysis: {len(genesets)}")

    if len(genesets) == 0:
        warnings.warn("No pathways remain after filtering")
        return pd.DataFrame()

    # Calculate pathway scores based on method
    print(f"Calculating {method} scores...")

    if method in ("CERNO", "AUC"):
        scores = AUC(genesets, data_array, genes)  # Use existing AUC function
    elif method == "MEAN":
        scores = MEAN(genesets, data_array, genes)
    elif method == "BINA":
        scores = BINA(genesets, data_array, genes)
    elif method == "AUCELL":
        scores = AUCELL(genesets, data_array, genes, aucell_threshold)
    elif method == "JASMINE":
        scores = JASMINE(genesets, data_array, genes)
    elif method == "ZSCORE":
        scores = ZSCORE(genesets, data_array, genes)
    elif method == "SSGSEA":
        scores = SSGSEA(genesets, data_array, genes)
    else:
        raise ValueError(f"Method {method} not implemented")

    # Create result DataFrame; rows follow the (filtered) geneset order
    pathway_names = list(genesets)
    result_df = pd.DataFrame(scores, index=pathway_names, columns=sample_names)

    print(f"{method} scores calculated successfully")
    print(f"Output: {result_df.shape[0]} pathways x {result_df.shape[1]} samples")

    return result_df
190+
191+
192+
# Example usage and utility functions
193+
def create_example_data(
    n_genes: int = 1000, n_samples: int = 50, n_pathways: int = 10
) -> tuple:
    """
    Build a random expression matrix and random gene sets for trying gene2path.

    Args:
        n_genes: Number of genes (rows of the matrix)
        n_samples: Number of samples (columns of the matrix)
        n_pathways: Number of pathways to generate

    Returns:
        Tuple of (data, genesets, genes): a standard-normal matrix of shape
        (n_genes, n_samples), a dict "Pathway_<i>" -> list of member genes,
        and the list of "Gene_<i>" names.
    """
    # Pathway sizes are drawn from [size_low, size_high); randint excludes the top.
    size_low, size_high = 15, 50

    # The order of random draws (matrix first, then size and members per
    # pathway) is kept fixed so seeded runs stay reproducible.
    expression = np.random.randn(n_genes, n_samples)
    gene_names = ["Gene_%d" % idx for idx in range(n_genes)]

    pathways = {
        f"Pathway_{idx}": list(
            np.random.choice(
                gene_names,
                size=np.random.randint(size_low, size_high),
                replace=False,
            )
        )
        for idx in range(n_pathways)
    }

    return expression, pathways, gene_names
223+
224+
225+
def run_example():
    """
    Demonstrate gene2path on synthetic data, trying each method name in turn.

    Any exception raised while handling one method is reported on stdout and
    the loop moves on to the next method.
    """
    print("=== Gene2Path Example ===")

    # Synthetic input: small enough to run quickly
    expression, pathways, gene_names = create_example_data(
        n_genes=500, n_samples=20, n_pathways=5
    )
    print(
        f"Created example data: {len(gene_names)} genes, {len(pathways)} pathways, {expression.shape[1]} samples"
    )

    # Every method name advertised by gene2path's signature
    candidate_methods = (
        "CERNO",
        "MEAN",
        "BINA",
        "AUCELL",
        "JASMINE",
        "ZSCORE",
        "SSGSEA",
        "AUC",
    )

    for name in candidate_methods:
        try:
            print(f"\nTesting {name} method...")
            result = gene2path(
                data=expression,
                genesets=pathways,
                genes=gene_names,
                method=name,  # type: ignore
            )
            n_rows, n_cols = result.shape[0], result.shape[1]
            print(
                f"✓ {name}: Generated {n_rows} pathway scores for {n_cols} samples"
            )
        except Exception as err:
            # Best-effort demo: report the failure and keep going
            print(f"✗ {name}: Failed with error: {str(err)}")

    print("\n=== Example Complete ===")
258+
259+
if __name__ == "__main__":
    # Executed as a script (not imported): run the demonstration
    run_example()

0 commit comments

Comments
 (0)