-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path dataset.py
More file actions
194 lines (149 loc) · 7.24 KB
/
dataset.py
File metadata and controls
194 lines (149 loc) · 7.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import numpy as np
import pandas as pd
import warnings
import rdkit.Chem as Chem
from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator
from sklearn.decomposition import PCA, KernelPCA
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
import gc
import hashlib
def fp_str_to_array(str_):
    """Turn a bit string such as ``'0110'`` into a numpy array of its digits.

    Each character is converted with ``int`` and becomes one element of the
    result, in order. An empty string yields an empty array.
    """
    digits = list(map(int, str_))
    return np.array(digits)
def calc_fps(smiles, fpgen):
    """Return the fingerprint of a molecule as a numpy array of bits.

    Args:
        smiles: SMILES string of the molecule.
        fpgen: RDKit fingerprint generator (e.g. the object returned by
            ``GetMorganGenerator``); its ``GetFingerprint(mol).ToBitString()``
            output is converted with ``fp_str_to_array``.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles signals a parse failure by returning None instead of
    # raising; handing None to GetFingerprint produces an opaque RDKit
    # argument error that does not mention the offending SMILES.
    if mol is None:
        raise ValueError(f'Could not parse SMILES string: {smiles!r}')
    return fp_str_to_array(fpgen.GetFingerprint(mol).ToBitString())
class SynergyDataset:
    """Featurize drug-pair / cell-line data for the train, valid and test splits.

    Each input DataFrame should provide the columns ``Drug1`` and ``Drug2``
    (SMILES strings), ``cell_id`` and ``Y``. After :meth:`load`,
    ``self.splits[split]`` holds one row per input row with the columns
    ``Drug1``/``Drug2`` (fingerprint arrays), ``gene_features`` (reduced,
    scaled gene-table vector) and ``body_zone`` (one-hot vector).
    """

    # Columns every input split must provide (read by load() and get_item()).
    _REQUIRED_COLUMNS = frozenset(['Drug1', 'Drug2', 'cell_id', 'Y'])
    # Split names handled by load().
    _SPLIT_NAMES = ('train', 'valid', 'test')

    def __init__(self,
                 input_train,
                 input_valid,
                 input_test,
                 gene_cell_df_path='data/gene_cell_df.csv',
                 meta_info_cell_df_path='data/meta_info_cell_df.csv'):
        """Read the per-cell lookup tables and store the three input splits.

        Args:
            input_train, input_valid, input_test: DataFrames with the columns
                ``Drug1``, ``Drug2``, ``cell_id`` and ``Y``.
            gene_cell_df_path: CSV joined to the inputs on ``cell_id``; its
                non-key columns become the gene features.
            meta_info_cell_df_path: CSV joined to the inputs on ``cell_id``;
                must contain a ``body_zone`` column for :meth:`load`.
        """
        self.gene_cell_df = pd.read_csv(gene_cell_df_path, index_col=0)
        self.meta_info_cell_df = pd.read_csv(meta_info_cell_df_path, index_col=0)
        self.input_train = input_train
        self.input_valid = input_valid
        self.input_test = input_test
        self.inputs = {'train': input_train, 'valid': input_valid, 'test': input_test}
        # Check every split, not only train: load() and get_item() read these
        # columns from all of them.
        if any(df is not None and not self._REQUIRED_COLUMNS.issubset(df.columns)
               for df in self.inputs.values()):
            warnings.warn('Columns "Drug1", "Drug2", "cell_id", "Y" should be in input dataframes')
        self.mol_embed = {}
        self.gene_embed = {}
        self.gene_data = {}
        self.gene_meta_data = {}
        self.splits = {}
        self.scaler = None
        self.fpgen = None
        self.dim_red_algo = None
        self.body_zones = None
        self.gene_names = None

    def supported_cell_id_names(self):
        """Return the ``cell_id`` values listed in the cell meta-info table."""
        return self.meta_info_cell_df['cell_id'].tolist()

    def load(self, mol_embed='fp', gene_embed='pca', **params):
        """Build molecule, gene and body-zone features for every split.

        Args:
            mol_embed: molecule featurization; only ``'fp'`` (Morgan
                fingerprints) is supported. Reads ``params['radius']``
                (default 4) and ``params['fpSize']`` (default 64).
            gene_embed: ``'pca'`` or ``'kernel_pca'``. Reads
                ``params['n_components']`` (default 10) and, for kernel PCA,
                ``params['kernel']`` (default ``'rbf'``).

        The MinMax scaler, the dimensionality reducer and the one-hot encoder
        are fitted on the train split only and then applied to valid/test.

        Raises:
            ValueError: for an unknown ``mol_embed`` or ``gene_embed``.
        """
        # Validate both options up front so a typo fails before the slow
        # fingerprinting and merging work.
        if mol_embed != 'fp':
            raise ValueError('Unknown molecule feature extraction method')
        if gene_embed not in ('pca', 'kernel_pca'):
            raise ValueError('Unknown gene expression encoding method')

        print('Molecule embeddings creation')
        self.fpgen = GetMorganGenerator(radius=params.get('radius', 4),
                                        fpSize=params.get('fpSize', 64))
        fpgen = self.fpgen

        def featurize(smiles):
            # Reuse the module-level helper instead of duplicating it here.
            return calc_fps(smiles, fpgen)

        print('Gene data extraction')
        for split in self._SPLIT_NAMES:
            df = self.inputs[split]
            self.mol_embed[split] = pd.DataFrame({
                'Drug1': df['Drug1'].apply(featurize),
                'Drug2': df['Drug2'].apply(featurize),
            })
            self._load_gene_data(df, split)
            self._load_gene_meta_data(df, split)

        # Gene columns are the ones contributed by gene_cell_df (merge key
        # excluded). Selecting them by name instead of by position ([4:])
        # keeps any extra columns of the input frames out of the features.
        gene_names = [col for col in self.gene_cell_df.columns if col != 'cell_id']
        self.gene_names = gene_names

        self.scaler = MinMaxScaler()
        scaled = {'train': self.scaler.fit_transform(self.gene_data['train'][gene_names])}
        for split in ('valid', 'test'):
            scaled[split] = self.scaler.transform(self.gene_data[split][gene_names])

        if gene_embed == 'pca':
            print('Gene PCA features creation')
            self.dim_red_algo = PCA(n_components=params.get('n_components', 10))
        else:
            print('Gene kernel PCA features creation')
            self.dim_red_algo = KernelPCA(n_components=params.get('n_components', 10),
                                          kernel=params.get('kernel', 'rbf'))
        gene_features = {'train': self.dim_red_algo.fit_transform(scaled['train'])}
        for split in ('valid', 'test'):
            gene_features[split] = self.dim_red_algo.transform(scaled[split])

        print('Cell body zone information encoding')
        enc = OneHotEncoder(handle_unknown='ignore')
        enc.fit(self.gene_meta_data['train'][['body_zone']])
        self.body_zones = enc.categories_[0].tolist()

        for split in self._SPLIT_NAMES:
            # Copy so self.mol_embed keeps only the molecule columns instead
            # of being mutated through an alias.
            split_df = self.mol_embed[split].copy()
            # Lists are assigned positionally, matching the input row order.
            split_df['gene_features'] = gene_features[split].tolist()
            split_df['body_zone'] = enc.transform(
                self.gene_meta_data[split][['body_zone']]
            ).toarray().tolist()
            self.splits[split] = split_df

    @staticmethod
    def _merge_on_cell_id(df, table):
        """Left-join ``table`` onto ``df`` by ``cell_id``, keeping ``df``'s row order."""
        # pd.merge returns a new frame, so no defensive copy of df is needed.
        return pd.merge(df, table, how='left', on='cell_id')

    def _load_gene_data(self, df, split_type):
        """Store ``df`` left-joined with the gene table under ``self.gene_data[split_type]``."""
        self.gene_data[split_type] = self._merge_on_cell_id(df, self.gene_cell_df)

    def _load_gene_meta_data(self, df, split_type):
        """Store ``df`` left-joined with the meta-info table under ``self.gene_meta_data[split_type]``."""
        self.gene_meta_data[split_type] = self._merge_on_cell_id(df, self.meta_info_cell_df)

    def get_item(self, idx, split_type='train'):
        """Return ``(smiles1, smiles2, cell_id, features_row, y)`` for row ``idx``.

        ``idx`` is positional (``iloc``) in both the raw input split and
        ``self.splits[split_type]``; the latter is only populated by
        :meth:`load`.
        """
        row = self.inputs[split_type].iloc[idx]
        features = self.splits[split_type].iloc[idx]
        return row['Drug1'], row['Drug2'], row['cell_id'], features, row['Y']