# File: utils.py
# Utility functions for the project.
import torch
from torch import nn
from tests.bernstein_comparison import ispsd  # to put back

# A polynomial is represented by its coefficients in the monomial basis
# [1, x, x^2, x^3, ..., x^d].

def batch_multiply_poly_tensors(x: torch.Tensor, y: torch.Tensor):
    # Inputs: x = (batch, m)
    #         y = (batch, n)
    # Output: (batch, m + n - 1)
    # The batched outer product of the coefficient vectors holds the product's
    # coefficients along its antidiagonals.
    z = torch.bmm(x.unsqueeze(-1), y.unsqueeze(1))
    return batch_sum_antidiagonals(z)
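
# A hand-checkable sketch (not from the original file): multiplying a batch of
# two copies of 1 + x, i.e. coefficients [1, 1], by itself gives (1 + x)^2:
#   p = torch.ones(2, 2)               # two copies of [1, 1]
#   batch_multiply_poly_tensors(p, p)  # tensor([[1., 2., 1.], [1., 2., 1.]])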

def batch_sum_antidiagonals(z: torch.Tensor):
    # z is a 3D tensor, z = (batch, n1, n2); the output is 2D, (batch, n1 + n2 - 1).
    b, n1, n2 = z.shape
    # Pad each matrix with n1 - 1 zero columns, then re-stride so that row i is
    # shifted right by i positions; summing the columns then sums the
    # antidiagonals of each batch element.
    zpad = torch.cat((z, torch.zeros((b, n1, n1 - 1), dtype=z.dtype, device=z.device)), -1)
    zpad = zpad.as_strided(zpad.shape, (zpad.shape[2] * n1, zpad.shape[2] - 1, 1))
    return torch.sum(zpad, 1)  # sums the columns
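
# Worked illustration of the stride trick (my own sketch): for one batch
# element z = [[1, 2], [3, 4]], the padded rows are [1, 2, 0] and [3, 4, 0];
# the strided view shifts row 1 right by one to [0, 3, 4], so the column sums
# [1, 2 + 3, 4] = [1, 5, 4] are exactly the antidiagonal sums of z.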

def multiply_poly_tensors(x: torch.Tensor, y: torch.Tensor):
    """
    Assumes that x and y are one-dimensional arrays whose entries are the
    coefficients in the monomial basis
    [y^d, x*y^(d-1), ..., x^d]
    (here x and y refer to the variables, not the function arguments).
    So [0, 1, 0, 0] represents the polynomial xy^2,
    and [2, 0, 0, 1, -1] represents 2y^4 + x^3y - x^4.

    >>> a = torch.Tensor([0, 1, 0, 0])
    >>> b = torch.Tensor([2, 0, 0, 1, -1])
    >>> multiply_poly_tensors(a, a)
    tensor([0., 0., 1., 0., 0., 0., 0.])

    We get x^2y^4 as expected.

    >>> multiply_poly_tensors(a, b)
    tensor([ 0.,  2.,  0.,  0.,  1., -1.,  0.,  0.])

    a*b is 2xy^6 + x^4y^3 - x^5y^2.

    If x and y have lengths m and n, then the output has length m + n - 1.
    Credit for the implementation idea:
    https://stackoverflow.com/questions/57347896/sum-all-diagonals-in-feature-maps-in-parallel-in-pytorch
    """
    z = torch.outer(x, y)
    return sum_antidiagonals(z)

def sum_antidiagonals(z: torch.Tensor):
    # z is a 2D tensor; the output is 1D with length n1 + n2 - 1.
    n1, n2 = z.shape
    # Same padding-and-striding trick as batch_sum_antidiagonals above.
    zpad = torch.cat((z, torch.zeros((n1, n1 - 1), dtype=z.dtype, device=z.device)), 1)
    zpad = zpad.as_strided(zpad.shape, (zpad.shape[1] - 1, 1))
    return torch.sum(zpad, 0)  # sums the columns
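
# Quick sanity check (a sketch, not an original doctest):
#   sum_antidiagonals(torch.tensor([[1., 2.], [3., 4.]]))
#   # -> tensor([1., 5., 4.]), the three antidiagonal sums.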

def count_parameters(model):
    # From https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
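
# Usage sketch (any nn.Module works; nn.Linear here is just an illustration):
#   count_parameters(nn.Linear(10, 5))  # 55 = 10 * 5 weights + 5 biases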

def diff_normalized_mse_loss(x, y, normalizer):
    # MSE normalized by a user-supplied normalizer instead of the target y.
    if x.shape != y.shape:
        raise ValueError(f'diff_normalized_mse_loss: shapes {x.shape} and {y.shape} do not match')
    if len(x.shape) < 2:
        # No batch dimension: add a trailing axis so the reductions below work.
        x = x.unsqueeze(-1)
        y = y.unsqueeze(-1)
        normalizer = normalizer.unsqueeze(-1)
    dims_to_reduce = tuple(range(1, len(x.shape)))  # all but the batch dim
    normalized_loss = torch.mean(
        torch.sum(torch.square(x - y), dim=dims_to_reduce)
        / (torch.sum(torch.square(normalizer), dim=dims_to_reduce) + 1)
    )
    return normalized_loss
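
# Hand-computed sanity check (a sketch, not from the original file): with
#   x = torch.tensor([[1., 1.]]), y = torch.zeros(1, 2), normalizer = torch.ones(1, 2)
# the loss is ||x - y||^2 / (||normalizer||^2 + 1) = 2 / (2 + 1) = 2/3.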

def normalized_mse_loss(x, y):
    # MSE normalized by the squared norm of the target y.
    if x.shape != y.shape:
        raise ValueError(f'normalized_mse_loss: shapes {x.shape} and {y.shape} do not match')
    if len(x.shape) < 2:
        # No batch dimension: add a trailing axis so the reductions below work.
        x = x.unsqueeze(-1)
        y = y.unsqueeze(-1)
    dims_to_reduce = tuple(range(1, len(x.shape)))  # all but the batch dim
    normalized_loss = torch.mean(
        torch.sum(torch.square(x - y), dim=dims_to_reduce)
        / (torch.sum(torch.square(y), dim=dims_to_reduce) + 1)
    )
    return normalized_loss
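
# In formula form, this is the mean over the batch of ||x_b - y_b||^2 / (||y_b||^2 + 1);
# the + 1 in the denominator guards against division by zero for all-zero targets.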

def prepare_for_logger(kwargs):
    # Flatten a kwargs dict into scalar entries a logger can record: lists are
    # expanded into one entry per element, and string values become indicator
    # keys mapped to 1.
    mydict = {}
    for ky, val in kwargs.items():
        if isinstance(val, list):
            for ind, elt in enumerate(val):
                mydict[f'{ky}_{ind}'] = elt
        elif isinstance(val, str):
            mydict[val] = 1
        else:
            mydict[ky] = val
    return mydict
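
# Usage sketch (hypothetical kwargs, for illustration):
#   prepare_for_logger({'lr': 0.01, 'dims': [4, 8], 'mode': 'fast'})
#   # -> {'lr': 0.01, 'dims_0': 4, 'dims_1': 8, 'fast': 1}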

def fraction_psd(mats, cutoff=0.0):
    # mats is batch x dim x dim; returns the fraction of matrices in the batch
    # that ispsd reports as positive semidefinite.
    numpsd = 0.0
    numbatch = mats.shape[0]
    with torch.no_grad():
        for i in range(numbatch):
            if ispsd(mats[i], cutoff=cutoff):
                numpsd += 1
    return numpsd / numbatch
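
# Usage sketch (assuming ispsd accepts a single matrix plus a cutoff, as the
# call above suggests): identity matrices are all PSD, so the fraction is 1.0:
#   fraction_psd(torch.eye(3).expand(8, 3, 3))  # 1.0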

if __name__ == "__main__":
    import doctest
    doctest.testmod()