-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathConv.py
More file actions
115 lines (92 loc) · 3.65 KB
/
Conv.py
File metadata and controls
115 lines (92 loc) · 3.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import torch.nn as nn
import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
import matplotlib.pyplot as plt
def construct_toeplitz_2d(input_size, filter_weights):
    """
    Construct the (doubly block) Toeplitz matrix of a single-channel 2D convolution.

    The returned matrix T satisfies ``T @ x.flatten() == y.flatten()``, where ``y``
    is the "valid" (no padding, stride 1) cross-correlation of the image ``x``
    with the kernel -- the same operation ``nn.Conv2d`` performs.

    Args:
        input_size: ``(input_h, input_w)`` of the image.
        filter_weights: kernel tensor whose last two dims are
            ``(filter_h, filter_w)``; e.g. a Conv2d weight of shape
            ``(out, in, filter_h, filter_w)`` or a plain ``(filter_h, filter_w)``
            kernel. Only the first kernel (``[0, 0]`` for a 4-D weight) is used.

    Returns:
        Tensor of shape ``(output_h * output_w, input_h * input_w)`` on the
        kernel's device, with the kernel's dtype (the default float dtype for
        non-floating kernels).

    Raises:
        ValueError: if the kernel is larger than the input in either dimension.
    """
    input_h, input_w = input_size
    filter_h, filter_w = filter_weights.shape[-2:]
    output_h = input_h - filter_h + 1
    output_w = input_w - filter_w + 1
    if output_h < 1 or output_w < 1:
        # Without this check, e.g. output_h = output_w = -1 yields a bogus 1x1 matrix.
        raise ValueError(
            f"filter {filter_h}x{filter_w} does not fit input {input_h}x{input_w}"
        )
    # Flatten leading dims; index 0 is filter_weights[0, 0] for a 4-D weight.
    kernel = filter_weights.reshape(-1, filter_h, filter_w)[0]
    if not kernel.is_floating_point():
        # Matrix is used in float matmuls; also index_put_ needs matching dtypes.
        kernel = kernel.to(torch.get_default_dtype())
    toeplitz_matrix = torch.zeros(
        (output_h * output_w, input_h * input_w),
        dtype=kernel.dtype,
        device=kernel.device,
    )
    # Row r of the matrix corresponds to output pixel (r // output_w, r % output_w).
    rows = torch.arange(output_h * output_w, device=kernel.device)
    out_i = rows // output_w
    out_j = rows % output_w
    for m in range(filter_h):
        for n in range(filter_w):
            # Kernel tap (m, n) reads input pixel (i + m, j + n) for every output (i, j).
            cols = (out_i + m) * input_w + (out_j + n)
            toeplitz_matrix[rows, cols] = kernel[m, n]
    return toeplitz_matrix
def model(input, w):
    """
    Apply the linear map ``w`` followed by a sharpened softplus.

    Computes ``softplus(input @ w.T)`` with ``beta=10`` and the same default
    threshold as ``nn.Softplus``, i.e. a smooth approximation of ReLU.
    """
    pre_activation = input @ w.T
    return nn.functional.softplus(pre_activation, beta=10)
def stats(w, test_loader):
    """
    Evaluate weights ``w`` on the first batch of a fresh ``test_loader`` iterator.

    Each image is flattened to a row ``x`` and passed through ``model`` to get ``y``.

    Returns:
        rec_err_a: batch mean of ||x - y w||^2 (reconstruction in input space).
        rec_err_b: batch mean of ||x w^T - y w w^T||^2 (the same error after
            projecting with ``w``).
        y_norm: ``[mean ||x||, mean ||y||]`` -- input and activation norms.
    """
    batch, _labels = next(iter(test_loader))
    flat = batch.view(batch.size(0), -1)
    act = model(flat, w)

    gram = torch.matmul(w, w.T)
    err_input = flat - torch.matmul(act, w)
    err_proj = torch.matmul(flat, w.T) - torch.matmul(act, gram)
    rec_a = (err_input.norm(dim=1) ** 2).mean().item()
    rec_b = (err_proj.norm(dim=1) ** 2).mean().item()

    norms = [flat.norm(dim=1).mean().item(), act.norm(dim=1).mean().item()]
    return rec_a, rec_b, norms
# -- params
Theta = 0.01  # step size of the Oja update in the training loop
n_train, n_test = 5000, 10
# -- load data (images resized to inp_h x inp_w)
inp_h, inp_w = 7, 7
data_dir = '../../data'
transform = transforms.Compose([transforms.Resize((inp_h, inp_w)), transforms.ToTensor()])
dataset_train = datasets.MNIST(data_dir, train=True, download=False, transform=transform)
dataset_test = datasets.MNIST(data_dir, train=False, download=False, transform=transform)
# random subsets without replacement; train indices are drawn only from range(50000)
train_sampler = SubsetRandomSampler(np.random.choice(range(50000), n_train, False))
test_sampler = SubsetRandomSampler(np.random.choice(range(10000), n_test, False))
train_loader = DataLoader(dataset_train, batch_size=1, sampler=train_sampler)
# batch_size == n_test, so one batch covers the whole test subset
test_loader = DataLoader(dataset_test, batch_size=n_test, sampler=test_sampler)
# -- construct filter
in_channels, out_channels, filter_h, filter_w = 1, 1, 3, 3
conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(filter_h, filter_w), bias=False)
toeplitz_matrix = construct_toeplitz_2d((inp_h, inp_w), conv_layer.weight.data.clone())
# Structural mask: 1 at every position a filter tap can occupy. It is built from an
# all-ones kernel so a random weight that happens to be exactly 0 is not masked out.
mask = construct_toeplitz_2d((inp_h, inp_w), torch.ones_like(conv_layer.weight.data)).ne(0).int()
# -- training loop: one image per step, masked Oja's subspace rule
rec_err_a, rec_err_b, y_norm = [], [], []
for i, (input_tensor, _labels) in enumerate(train_loader):
    # -- forward pass; named x_row rather than `input` to avoid shadowing the builtin
    x_row = input_tensor.view(1, -1)
    output = model(x_row, toeplitz_matrix)
    # -- apply Oja's subspace rule: W += Theta * (y^T x - y^T y W)
    toeplitz_matrix += Theta * (torch.matmul(output.T, x_row) -
                                torch.matmul(torch.matmul(output.T, output), toeplitz_matrix))
    # -- re-impose the convolutional sparsity pattern.
    # NOTE(review): entries that started as copies of the same filter tap are
    # updated independently, so only the zero pattern survives, not weight sharing.
    toeplitz_matrix *= mask
    # -- show the weights every 100 samples during the first 1000 samples
    if i % 100 == 0 and i < 1000:
        plt.imshow(toeplitz_matrix.detach().numpy())
        plt.title(f'Toeplitz matrix after {i + 1} training samples')
        plt.show()
        plt.close()
    # -- compute statistics on a batch from test_loader
    err_a, err_b, act_norms = stats(toeplitz_matrix, test_loader)
    rec_err_a.append(err_a)
    rec_err_b.append(err_b)
    y_norm.append(act_norms)
# -- plot statistics (one value per training sample)
for curve, title in ((rec_err_a, 'Reconstruction Error A'),
                     (rec_err_b, 'Reconstruction Error B')):
    plt.plot(curve, label=title)
    plt.xlabel('training sample')
    plt.legend()  # `label=` is only displayed once a legend is drawn
    plt.show()
    plt.close()
# each y_norm entry is [mean ||x|| (input), mean ||y|| (activation)];
# reshape keeps the column indexing valid even if nothing was recorded
norms = np.asarray(y_norm, dtype=float).reshape(-1, 2)
plt.plot(norms[:, 0], label='Input Norm')
plt.plot(norms[:, 1], label='Activation Norm')
plt.xlabel('training sample')
plt.legend()
plt.show()
plt.close()