import argparse
import glob
import os
import time

import pandas as pd
import torch
from torch.utils.data import DataLoader
from prefetch_generator import BackgroundGenerator
from rdkit import RDLogger

import utils.utils as utils
from dataset.dataset import ESDataset
from model.equiscore import EquiScore
# seed_torch, SequentialDistributedSampler, distributed_concat and
# get_available_gpu come from here and are used in run() below
from utils.dist_utils import *

# surface CUDA errors at the call site instead of asynchronously
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
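# DataLoader variant that prefetches batches in a background thread,
# so graph collation on the CPU overlaps with GPU compute.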
class DataLoaderX(DataLoader):
    def __iter__(self):
        return BackgroundGenerator(super().__iter__())


# silence RDKit parsing warnings
RDLogger.DisableLog('rdApp.*')

# log the launch time and run relative to this script's directory
now = time.localtime()
print(time.strftime("%Y-%m-%d %H:%M:%S", now))
os.chdir(os.path.abspath(os.path.dirname(__file__)))
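# Per-process worker: each spawned process joins the process group as one rank,
# loads the trained EquiScore checkpoint and scores its shard of the test set.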
def run(local_rank, args, *more_args, **kwargs):
    args.local_rank = local_rank
    # join the NCCL process group for distributed inference
    torch.distributed.init_process_group(backend="nccl", init_method='env://',
                                         rank=args.local_rank, world_size=args.ngpu)
    torch.cuda.set_device(args.local_rank)
    # offset the seed per rank so workers do not share an RNG stream
    seed_torch(seed=args.seed + args.local_rank)
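    # the FP flag (presumably extra fingerprint-style descriptors, per its name)
    # widens the per-atom input from 28 to 39 features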
    if args.FP:
        args.N_atom_features = 39
    else:
        args.N_atom_features = 28
    model = EquiScore(args) if args.model == 'EquiScore' else None
    args.device = args.local_rank
    best_name = args.save_model
    model_name = best_name.split('/')[-1]
    save_path = best_name.replace(model_name, '')
    os.makedirs(save_path, exist_ok=True)
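    # every file under test_path/test_name is one screening sample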
    args.test_path = os.path.join(args.test_path, args.test_name)
    test_keys_pro = glob.glob(args.test_path + '/*')
    test_dataset = ESDataset(test_keys_pro, args, args.test_path, args.debug)
    test_sampler = SequentialDistributedSampler(test_dataset, args.batch_size) if args.ngpu >= 1 else None
    test_dataloader = DataLoaderX(test_dataset, batch_size=args.batch_size,
                                  shuffle=False, num_workers=8, collate_fn=test_dataset.collate,
                                  pin_memory=True, sampler=test_sampler)
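    # note: SequentialDistributedSampler pads the dataset so every rank gets an
    # equal share; distributed_concat trims that padding after gathering below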
    # restore the trained weights and switch to inference mode
    model = utils.initialize_model(model, args.device, args, load_save_file=best_name)[0]
    model.eval()
    with torch.no_grad():
        test_pred = []
        for i_batch, (g, full_g, Y) in enumerate(test_dataloader):
            g = g.to(args.local_rank, non_blocking=True)
            full_g = full_g.to(args.local_rank, non_blocking=True)
            pred = model(g, full_g)
            if pred.dim() == 2:
                # two-logit output: keep the probability of the positive class
                pred = torch.softmax(pred, dim=-1)[:, 1]
            test_pred.append(pred.data)
    # gather the per-GPU predictions into a single tensor
    if args.ngpu >= 1:
        test_pred = distributed_concat(torch.cat(test_pred, dim=0),
                                       len(test_sampler.dataset)).cpu().numpy()
    else:
        test_pred = torch.cat(test_pred, dim=0).cpu().numpy()
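    # only rank 0 writes the results, ranked from highest to lowest score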
    if args.local_rank == 0:
        os.makedirs(os.path.dirname(args.pred_save_path), exist_ok=True)
        pd.DataFrame({"test_sample_path": test_keys_pro,
                      "test_pred": test_pred}).sort_values("test_pred", ascending=False).to_csv(args.pred_save_path, index=False)
if __name__ == '__main__':
    # distributed screening entry point: one process per GPU
    from utils.parsing import parse_train_args
    args = parse_train_args()
    if args.ngpu > 0:
        # select idle GPUs with enough free memory and expose only those
        cmd = get_available_gpu(num_gpu=args.ngpu, min_memory=6000, sample=3, nitro_restriction=False, verbose=True)
        os.environ['CUDA_VISIBLE_DEVICES'] = cmd.rstrip(',')
    os.environ["MASTER_ADDR"] = args.MASTER_ADDR
    os.environ["MASTER_PORT"] = args.MASTER_PORT
    from torch.multiprocessing import Process
    world_size = args.ngpu
    processes = []
    # spawn one worker per GPU; run() handles rank setup and scoring
    for rank in range(world_size):
        p = Process(target=run, args=(rank, args))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
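# Example launch (illustrative only; the flag names are assumptions based on the
# args used above and are actually defined in utils.parsing.parse_train_args):
#   python Screening.py --ngpu 2 --test_path data/screening --test_name target_A \
#       --save_model ckpts/equiscore_best.pt --pred_save_path results/target_A.csv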