@@ -41,16 +41,16 @@ def _min_norm_2d(vecs, dps):
4141 if (i ,j ) not in dps :
4242 dps [(i , j )] = 0.0
4343 for k in range (len (vecs [i ])):
44- dps [(i ,j )] += torch .dot (vecs [i ][k ], vecs [j ][k ]).data [ 0 ]
44+ dps [(i ,j )] += torch .mul (vecs [i ][k ], vecs [j ][k ]).sum (). data . cpu ()
4545 dps [(j , i )] = dps [(i , j )]
4646 if (i ,i ) not in dps :
4747 dps [(i , i )] = 0.0
4848 for k in range (len (vecs [i ])):
49- dps [(i ,i )] += torch .dot (vecs [i ][k ], vecs [i ][k ]).data [ 0 ]
49+ dps [(i ,i )] += torch .mul (vecs [i ][k ], vecs [i ][k ]).sum (). data . cpu ()
5050 if (j ,j ) not in dps :
5151 dps [(j , j )] = 0.0
5252 for k in range (len (vecs [i ])):
53- dps [(j , j )] += torch .dot (vecs [j ][k ], vecs [j ][k ]).data [ 0 ]
53+ dps [(j , j )] += torch .mul (vecs [j ][k ], vecs [j ][k ]).sum (). data . cpu ()
5454 c ,d = MinNormSolver ._min_norm_element_from2 (dps [(i ,i )], dps [(i ,j )], dps [(j ,j )])
5555 if d < dmin :
5656 dmin = d
@@ -184,16 +184,16 @@ def gradient_normalizers(grads, losses, normalization_type):
184184 gn = {}
185185 if normalization_type == 'l2' :
186186 for t in grads :
187- gn [t ] = np .sqrt (np .sum ([gr .pow (2 ).sum ().data [ 0 ] for gr in grads [t ]]))
187+ gn [t ] = np .sqrt (np .sum ([gr .pow (2 ).sum ().data . cpu () for gr in grads [t ]]))
188188 elif normalization_type == 'loss' :
189189 for t in grads :
190190 gn [t ] = losses [t ]
191191 elif normalization_type == 'loss+' :
192192 for t in grads :
193- gn [t ] = losses [t ] * np .sqrt (np .sum ([gr .pow (2 ).sum ().data [ 0 ] for gr in grads [t ]]))
193+ gn [t ] = losses [t ] * np .sqrt (np .sum ([gr .pow (2 ).sum ().data . cpu () for gr in grads [t ]]))
194194 elif normalization_type == 'none' :
195195 for t in grads :
196196 gn [t ] = 1.0
197197 else :
198198 print ('ERROR: Invalid Normalization Type' )
199- return gn
199+ return gn
0 commit comments