-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrainer.py
More file actions
55 lines (51 loc) · 1.97 KB
/
trainer.py
File metadata and controls
55 lines (51 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import copy

import numpy as np
from tqdm import tqdm
class Trainer(object):
    """Drive one of several iterative update schemes for a factor pair (U, V).

    The constructor binds ``self.train`` to the routine selected by
    ``args.model_name``.  Every routine has the signature
    ``train(args, U, V)`` and returns the updated ``(U, V)`` pair.

    When a reference matrix ``L_star`` is supplied, the entrywise L1 distance
    ``sum(|U @ V.T - L_star|)`` is appended to ``error_ls`` once per outer
    iteration.

    None of the routines modify the ``U`` / ``V`` objects passed in; callers
    must use the returned pair.
    """

    def __init__(self, model, args, X, L_star=None):
        """Store the collaborators and select the training routine.

        Args:
            model: Callable invoked as ``model(U, V, X, target_grad='u'|'v')``.
                Its return value is used as an additive step (subgradient),
                a full replacement factor (A_IRLS_combined) or a single
                replacement row (A_IRLS).
            args: Object exposing ``model_name``; kept on ``self.args``.
            X: Data handed to ``model``.  ``train_A_ILRS`` additionally
                unpacks ``X.shape`` into two values and indexes ``X`` by
                column and by row.
            L_star: Optional reference matrix for error tracking.
        """
        self.model = model
        self.args = args
        self.L_star = L_star
        self.error_ls = []
        self.X = X

        routines = {
            'subgradient': self.train_subgradient,
            'A_IRLS_combined': self.train_A_ILRS_combined,
            'A_IRLS': self.train_A_ILRS,
        }
        selected = routines.get(args.model_name)
        # For an unrecognised name ``self.train`` is left unbound on the
        # instance, so the class-level fallback below reports the problem
        # only if training is actually attempted.
        if selected is not None:
            self.train = selected

    def train(self, args, U, V):
        """Fallback used when ``args.model_name`` matched no known routine.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "no training routine for model_name=%r; expected one of "
            "'subgradient', 'A_IRLS_combined', 'A_IRLS'"
            % (getattr(self.args, 'model_name', None),)
        )

    def _record_error(self, U, V):
        """Append ``sum(|U @ V.T - L_star|)`` to ``error_ls`` if L_star is set."""
        if self.L_star is not None:
            self.error_ls.append(np.sum(np.abs(U @ V.T - self.L_star)))

    def train_subgradient(self, args, U, V):
        """Run ``args.max_iter_subG`` additive updates scaled by ``args.lr``.

        Each iteration adds ``model(..., target_grad='u') * lr`` to U, then
        ``model(..., target_grad='v') * lr`` to V (the V step already sees
        the updated U).

        Returns:
            The updated ``(U, V)`` pair.
        """
        # ``+=`` updates array objects in place; copy first so the caller's
        # initial factors stay intact.
        U0, V0 = copy.deepcopy(U), copy.deepcopy(V)
        max_iter = args.max_iter_subG
        lr = args.lr
        for _ in tqdm(range(max_iter)):
            U0 += self.model(U0, V0, self.X, target_grad='u') * lr
            V0 += self.model(U0, V0, self.X, target_grad='v') * lr
            self._record_error(U0, V0)
        return U0, V0

    def train_A_ILRS_combined(self, args, U, V):
        """Run ``args.max_iter_A_ILRS_combined`` whole-factor replacements.

        Each iteration replaces U with ``model(..., target_grad='u')`` and
        then V with ``model(..., target_grad='v')``.

        Returns:
            The updated ``(U, V)`` pair.
        """
        # Factors are rebound, never written in place, so no copy is needed.
        U0, V0 = U, V
        max_iter = args.max_iter_A_ILRS_combined
        for _ in tqdm(range(max_iter)):
            U0 = self.model(U0, V0, self.X, target_grad='u')
            V0 = self.model(U0, V0, self.X, target_grad='v')
            self._record_error(U0, V0)
        return U0, V0

    def train_A_ILRS(self, args, U, V):
        """Run ``args.max_iter_A_ILRS`` sweeps of row-by-row replacements.

        With ``d, n = X.shape``, each sweep rewrites row ``k`` of V from
        column ``k`` of X for ``k < n``, then row ``k`` of U from row ``k``
        of X for ``k < d``.  Rows are written immediately, so later model
        calls in the same sweep receive the already-updated factors.

        Returns:
            The updated ``(U, V)`` pair.
        """
        # Row assignment writes in place; copy first so the caller's initial
        # factors stay intact.
        U0, V0 = copy.deepcopy(U), copy.deepcopy(V)
        d, n = self.X.shape
        max_iter = args.max_iter_A_ILRS
        for _ in tqdm(range(max_iter)):
            for k in range(n):
                V0[k, :] = self.model(U0, V0, self.X[:, k], target_grad='v')
            for k in range(d):
                U0[k, :] = self.model(U0, V0, self.X[k, :], target_grad='u')
            self._record_error(U0, V0)
        return U0, V0