-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
94 lines (72 loc) · 2.91 KB
/
utils.py
File metadata and controls
94 lines (72 loc) · 2.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from sklearn.metrics import roc_auc_score
from sklearn.metrics import auc as auc3
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pickle
import numpy as np
import torch
import torch.optim as optim
import os
import random
import warnings
warnings.filterwarnings(action='ignore')
def seed_everything(seed: int = 42):
    """Seed every RNG (Python, NumPy, PyTorch) for reproducible runs.

    Args:
        seed: Seed value applied to all random number generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # no-op when CUDA is unavailable
    torch.backends.cudnn.deterministic = True
    # BUG FIX: benchmark must be False for determinism. benchmark=True lets
    # cuDNN auto-tune convolution algorithms, which can select different
    # kernels between runs and defeat the deterministic setting above.
    torch.backends.cudnn.benchmark = False
def get_device(cuda_num=None):
    """Return the torch device string to run on.

    Args:
        cuda_num: Index of the CUDA device to use, or None for CPU.

    Returns:
        ``"cuda:<n>"`` when a valid index is given and CUDA is available,
        otherwise ``"cpu"``.
    """
    # Generalized: accept any non-negative GPU index instead of only 0-3,
    # so machines with more than four GPUs are usable. bool is excluded
    # because True/False are ints in Python.
    if isinstance(cuda_num, int) and not isinstance(cuda_num, bool) and cuda_num >= 0:
        return f"cuda:{cuda_num}" if torch.cuda.is_available() else "cpu"
    return "cpu"
def count_parameters(module):
    """Return the number of trainable (requires_grad) parameters in *module*."""
    total = 0
    for param in module.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def get_optimizer(params, opt_name, lr=1e-4, w_decay=None):
    """Build a torch optimizer by name.

    Args:
        params: Iterable of parameters (or parameter groups) to optimize.
        opt_name: One of "AdamW", "Adam", "SGD" (case-insensitive).
        lr: Learning rate.
        w_decay: Weight decay; None means 0.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: If *opt_name* is not a supported optimizer name.
    """
    weight_decay = 0 if w_decay is None else w_decay
    # Case-insensitive match replaces the original hand-listed spellings
    # (which contained a duplicate 'AdamW' entry) and accepts a superset
    # of the names the original accepted.
    name = opt_name.lower()
    if name == 'adamw':
        return optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    if name == 'adam':
        return optim.Adam(params, lr=lr, weight_decay=weight_decay)
    if name == 'sgd':
        return optim.SGD(params, lr=lr, weight_decay=weight_decay)
    # BUG FIX: the original fell through and silently returned None for an
    # unknown name, deferring the failure to the first optimizer.step().
    raise ValueError(f"Unsupported optimizer: {opt_name!r}")
def get_params(args_dict):
    """Pick the model hyper-parameters out of *args_dict*.

    Args:
        args_dict: Mapping that contains at least the keys 'name',
            'depth_g', 'dim_in', 'dim_out' and 'depth_d'.

    Returns:
        A new dict holding only those five entries.
    """
    wanted = ('name', 'depth_g', 'dim_in', 'dim_out', 'depth_d')
    return {key: args_dict[key] for key in wanted}
def split_data(labels, valid_test_ratio=0.2, seed=315):
    """Split node indices into stratified train, validation and test sets.

    Args:
        labels: 2-D array whose third column (index 2) holds the class
            label of each node.
        valid_test_ratio: Fraction of nodes held out, split evenly
            between validation and test.
        seed: Random state for reproducible splits.

    Returns:
        Tuple ``(train, valid, test)`` of index lists.
    """
    indices = list(range(len(labels)))
    classes = labels[:, 2]
    # First carve off the held-out pool, stratified on the class column.
    train, holdout, _, holdout_cls = train_test_split(
        indices, classes,
        test_size=valid_test_ratio, random_state=seed, stratify=classes)
    # Then split the pool 50/50 into validation and test.
    valid, test, _, _ = train_test_split(
        holdout, holdout_cls,
        test_size=0.5, random_state=seed, stratify=holdout_cls)
    return train, valid, test
def get_DataLoader(result, batch_size=128, shuffle=True):
    """Wrap *result* in a DataLoader of (gene id, disease id, label) samples.

    Args:
        result: 2-D array-like whose columns are gene ids, disease ids
            and labels, in that order.
        batch_size: Mini-batch size; None uses the whole dataset as a
            single batch.
        shuffle: Whether to reshuffle the samples every epoch.

    Returns:
        A torch DataLoader over a DatasetID built from *result*.
    """
    effective_batch = len(result) if batch_size is None else batch_size
    dataset = DatasetID(result[:, 0], result[:, 1], result[:, 2])
    return DataLoader(dataset, batch_size=effective_batch, shuffle=shuffle)
class DatasetID(Dataset):
    """Torch dataset of parallel (gene id, disease id, label) triples."""

    def __init__(self, ids_gene, ids_disease, labels):
        # The three sequences are parallel: index i yields the i-th triple.
        self.ids_gene = ids_gene
        self.ids_disease = ids_disease
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the (gene id, disease id, label) triple at *idx*."""
        return self.ids_gene[idx], self.ids_disease[idx], self.labels[idx]