-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpathways.py
More file actions
148 lines (118 loc) · 4.86 KB
/
pathways.py
File metadata and controls
148 lines (118 loc) · 4.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
MSigDB Hallmarks pathway utilities.
Downloads, parses, and converts MSigDB Hallmark gene sets into
a pathway membership matrix for initializing the SpatialTranscriptFormer.
"""
import json
import os
import urllib.request
from typing import Dict, List, Optional, Tuple

import torch
# MSigDB collection GMT download URLs (release v2024.1.Hs, gene-symbol
# identifiers), keyed by a short collection name. All collections live under
# the same release directory, so each URL is that directory plus a filename.
MSIGDB_URLS = {
    key: "https://data.broadinstitute.org/gsea-msigdb/msigdb/release/2024.1.Hs/" + gmt_name
    for key, gmt_name in (
        ("hallmarks", "h.all.v2024.1.Hs.symbols.gmt"),
        ("c2_medicus", "c2.cp.kegg_medicus.v2024.1.Hs.symbols.gmt"),
        ("c2_cgp", "c2.cgp.v2024.1.Hs.symbols.gmt"),
    )
}
def download_msigdb_gmt(url: str, filename: str, cache_dir: str = ".cache") -> str:
    """
    Download an MSigDB GMT file if not already cached.

    The file is first downloaded to a temporary file inside ``cache_dir`` and
    then atomically renamed into place. An interrupted or failed download
    therefore never leaves a partial file at the cache path, where it would
    otherwise be mistaken for a valid cached copy on the next call.

    Args:
        url: URL of the GMT file to fetch.
        filename: Name of the file inside ``cache_dir``.
        cache_dir: Directory used as the download cache (created if missing).

    Returns:
        str: Path to the local GMT file.

    Raises:
        Any exception raised by ``urllib.request.urlretrieve`` (e.g.
        ``urllib.error.URLError``) is propagated after cleanup.
    """
    import tempfile

    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, filename)
    if os.path.exists(local_path):
        return local_path

    print(f"Downloading MSigDB gene sets from {url}...")
    # Temp file in the same directory so os.replace is an atomic rename.
    fd, tmp_path = tempfile.mkstemp(dir=cache_dir, prefix=filename + ".", suffix=".part")
    os.close(fd)
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, local_path)
    finally:
        # After a successful replace tmp_path no longer exists; otherwise this
        # removes the partial download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"Saved to {local_path}")
    return local_path
def parse_gmt(gmt_path: str) -> Dict[str, List[str]]:
    """
    Parse a GMT file into a dict of {pathway_name: [gene_symbols]}.

    GMT format: each line is tab-separated:
        pathway_name \\t description \\t gene1 \\t gene2 \\t ...

    Lines with fewer than three fields (after stripping surrounding
    whitespace) are skipped. Gene fields are whitespace-stripped, and empty
    fields (e.g. from doubled tabs) are dropped, so they never appear as
    bogus "" gene symbols. If a pathway name repeats, the later line wins.

    Args:
        gmt_path: Path to a local GMT file, decoded as UTF-8.

    Returns:
        Dict mapping each pathway name to its gene symbols in file order.
    """
    pathways: Dict[str, List[str]] = {}
    # Explicit encoding: do not depend on the platform's default locale.
    with open(gmt_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t")
            if len(parts) < 3:
                continue
            name = parts[0]
            # Skip the description at index 1; drop blank gene fields.
            genes = [g.strip() for g in parts[2:] if g.strip()]
            pathways[name] = genes
    return pathways
def build_membership_matrix(
    pathway_dict: Dict[str, List[str]], gene_list: List[str], scale: float = 1.0
) -> Tuple[torch.Tensor, List[str]]:
    """
    Build a binary membership matrix (num_pathways x num_genes).

    Args:
        pathway_dict: {pathway_name: [gene_symbols]} from parse_gmt().
        gene_list: Ordered list of gene symbols (e.g., from global_genes.json).
            Column j of the matrix corresponds to gene_list[j].
        scale: Value for member genes (default 1.0). Non-members are 0.

    Returns:
        tuple: (matrix, pathway_names) where ``matrix`` is a tensor of shape
        (num_pathways, num_genes) in torch's default float dtype, and
        ``pathway_names`` lists the row names in ``pathway_dict`` iteration
        order. Pathway genes absent from ``gene_list`` are ignored.
    """
    gene_to_idx = {g: i for i, g in enumerate(gene_list)}
    pathway_names = list(pathway_dict)
    matrix = torch.zeros(len(pathway_names), len(gene_list))

    # Collect (row, col) pairs first and write them in one indexed assignment;
    # per-element tensor writes in a Python loop are very slow for large GMTs.
    rows: List[int] = []
    cols: List[int] = []
    for p_idx, genes in enumerate(pathway_dict.values()):
        for gene in genes:
            g_idx = gene_to_idx.get(gene)
            if g_idx is not None:
                rows.append(p_idx)
                cols.append(g_idx)

    if rows:
        matrix[torch.tensor(rows), torch.tensor(cols)] = scale
    return matrix, pathway_names
def get_pathway_init(
    gene_list: List[str],
    gmt_urls: Optional[List[str]] = None,
    filter_names: Optional[List[str]] = None,
    cache_dir: str = ".cache",
    verbose: bool = True,
) -> Tuple[torch.Tensor, List[str]]:
    """
    Main entry point: download GMTs, match to gene list, return init matrix.

    Args:
        gene_list: Ordered list of gene symbols from global_genes.json.
        gmt_urls: MSigDB GMT URLs to download and merge, in order. Defaults to
            the Hallmarks collection only (``MSIGDB_URLS["hallmarks"]``).
        filter_names: If provided, only include these specific pathway names.
        cache_dir: Directory in which downloaded GMT files are cached.
        verbose: Print pathway coverage statistics.

    Returns:
        tuple: (membership_matrix [Tensor (P, G)], pathway_names [list of str])

    Raises:
        ValueError: If no pathways remain after downloading and filtering.
    """
    if gmt_urls is None:
        gmt_urls = [MSIGDB_URLS["hallmarks"]]
    # Build the filter set once; it is probed for every parsed pathway.
    wanted = None if filter_names is None else set(filter_names)

    combined_dict: Dict[str, List[str]] = {}
    for url in gmt_urls:
        filename = url.split("/")[-1]
        local_path = download_msigdb_gmt(url, filename, cache_dir)
        for name, genes in parse_gmt(local_path).items():
            if wanted is not None and name not in wanted:
                continue
            # First occurrence wins if two collections share a pathway name.
            combined_dict.setdefault(name, genes)

    if not combined_dict:
        raise ValueError("No pathways matched the provided filter or URLs.")

    matrix, pathway_names = build_membership_matrix(combined_dict, gene_list)

    if verbose:
        total_genes = len(gene_list)
        covered = int((matrix.sum(dim=0) > 0).sum().item())
        # An empty gene_list would otherwise raise ZeroDivisionError here.
        pct = 100 * covered / total_genes if total_genes else 0.0
        print(f"Pathways initialized: {len(pathway_names)}")
        print(f"Gene coverage: {covered}/{total_genes} ({pct:.1f}%)")
        for i, name in enumerate(pathway_names):
            n_matched = int(matrix[i].sum().item())
            if n_matched > 0:
                short_name = name.replace("HALLMARK_", "")
                print(f" {short_name}: {n_matched} genes")

    return matrix, pathway_names