-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnmrshiftdb_to_graphs.py
More file actions
49 lines (40 loc) · 1.22 KB
/
nmrshiftdb_to_graphs.py
File metadata and controls
49 lines (40 loc) · 1.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
import pandas as pd
import pickle
import rdkit
from rdkit import Chem
from graph import Graph
# d_name = "nmrshiftdb.1H.pickle"
# NOTE(review): unpickling can execute arbitrary code; only load dataset files
# that come from a trusted source.
# The context manager closes the input file as soon as the load finishes
# instead of leaving the handle open until garbage collection.
with open('data/nmrshiftdb.13C.pickle', 'rb') as data_file:
    data = pickle.load(data_file)
# Numeric label for each supported RDKit bond type; load_graph looks bonds up
# here to fill the edge-feature list. Aromatic bonds use the fractional 1.5.
BOND_TYPE_TO_LABEL = {
    getattr(rdkit.Chem.rdchem.BondType, bond_name): label
    for bond_name, label in (
        ("SINGLE", 1),
        ("DOUBLE", 2),
        ("TRIPLE", 3),
        ("AROMATIC", 1.5),
    )
}
def load_graph(row):
    """Convert one dataset row into a ``Graph`` plus a per-atom label list.

    Args:
        row: Record indexable by ``'rdmol'`` (handed to ``Chem.Mol``) and
            ``'spect_dict'`` (only its first element is used; it is looked
            up by 0-based atom index).

    Returns:
        A ``(graph, labels)`` tuple. ``graph`` is
        ``Graph(n, edges, vfs, efs)`` where ``n`` is the atom count,
        ``edges`` holds one ``(begin, end)`` pair per bond with atom indices
        shifted to start at 1, ``vfs`` holds the atomic number of every atom
        and ``efs`` holds the ``BOND_TYPE_TO_LABEL`` value of every bond.
        ``labels`` has one entry per atom: the value stored for that atom
        index in the first spectrum dict, or ``-100`` if there is none.

    Raises:
        KeyError: If a bond has a type missing from ``BOND_TYPE_TO_LABEL``.
    """
    molecule = Chem.Mol(row['rdmol'])
    atom_count = molecule.GetNumAtoms()

    edge_list = []
    edge_features = []
    for bond in molecule.GetBonds():
        # RDKit atom indices are 0-based; the edge list is built 1-based.
        begin = bond.GetBeginAtomIdx() + 1
        end = bond.GetEndAtomIdx() + 1
        edge_list.append((begin, end))
        edge_features.append(BOND_TYPE_TO_LABEL[bond.GetBondType()])

    vertex_features = [atom.GetAtomicNum() for atom in molecule.GetAtoms()]

    # Only the first entry of 'spect_dict' is consulted; atoms without an
    # entry receive the placeholder label -100.
    shifts = row['spect_dict'][0]
    labels = [
        shifts[atom_idx] if atom_idx in shifts else -100
        for atom_idx in range(atom_count)
    ]

    graph = Graph(atom_count, edge_list, vertex_features, edge_features)
    return graph, labels
# Split the rows into train and test sets keyed by the DataFrame index.
# The split is decided by the row's 'morgan4_crc32' value: a remainder of at
# most 1 after division by 10 (i.e. 0 or 1 for integer values) sends the row
# to the test set, everything else to the training set.
d_train = {}
d_test = {}
for i, row in data.iterrows():
    g, labels = load_graph(row)
    if row['morgan4_crc32'] % 10 <= 1:
        d_test[i] = (g, labels)
    else:
        d_train[i] = (g, labels)

# Write each split inside a ``with`` block so the output file is flushed and
# closed deterministically rather than whenever the handle is collected.
with open('data/nmrshiftdb.13C.train.pickle', 'wb') as train_file:
    pickle.dump(d_train, train_file)
with open('data/nmrshiftdb.13C.test.pickle', 'wb') as test_file:
    pickle.dump(d_test, test_file)