import torch
from torch import nn
from torch.nn import functional as F


def generic_rule(activation, e, params, feedback, Theta, vec, fbk):
    """
    Plasticity rule.

    This function receives the network weights as input and updates them. Note
    that the weights provided are a cloned copy of the model's parameters; to
    evaluate the model with the updated copy, replace the module parameters
    using :mod:`torch.nn.utils._stateless`.

    :param activation: (list) model activations,
    :param e: (list) modulatory signals,
    :param params: (dict) model parameters (weights),
    :param feedback: (dict) feedback connections,
    :param Theta: (list) plasticity coefficients,
    :param vec: (list) vector of plasticity rule indices. It determines which
        plasticity rule is applied during the parameter update process.
    :param fbk: (str) the type of feedback matrix used in the model:
        1) 'sym', indicating that the feedback matrix is symmetric;
        2) 'fix', indicating that the feedback matrix is a fixed random matrix.
    :return: None.
    """
    # -- weight update
    i = 0
    for k, p in params.items():
        if 'fd' in k:
            # -- pseudo-gradient update rule
            if 'Linear' in p.classname:
                p.update = - Theta[0] * torch.matmul(e[i + 1].T, activation[i].view(activation[i].size(0), -1))
            elif 'Conv2d' in p.classname:
                p.update = - Theta[0] * F.conv2d(activation[i].transpose(0, 1), e[i + 1].transpose(0, 1),
                                                 padding=p.padding).transpose(0, 1)

            # -- eHebb update rule
            if '2' in vec:
                if 'Linear' in p.classname:
                    p.update -= Theta[1] * torch.matmul(e[i + 1].T, e[i])
                elif 'Conv2d' in p.classname:
                    pass  # todo: add eHebb rule for conv layers

            # -- Oja update rule
            if '9' in vec:
                if 'Linear' in p.classname:
                    p.update -= Theta[2] * (torch.matmul(activation[i + 1].T, activation[i]) -
                                            torch.matmul(torch.matmul(activation[i + 1].T, activation[i + 1]), p))
                elif 'Conv2d' in p.classname:
                    pass  # todo: add Oja rule for conv layers

            params[k] = p + p.update
            params[k].__dict__.update({'classname': p.classname, 'padding': p.padding,
                                       'out_channels': p.out_channels, 'pool_setting': p.pool_setting})

            # -- max-pool the pre-pooling post-synaptic activation
            if p.pool_setting is not None:
                activation[i + 1] = F.max_pool2d(activation[i + 1], *p.pool_setting)

            i += 1

    # -- enforce symmetric feedback
    if fbk == 'sym':
        forward = [k for k, v in params.items() if 'fd' in k and 'weight' in k]
        for (k, B), fd_k in zip(feedback.items(), forward):
            params[k].data = params[fd_k]
            params[k].__dict__.update({'classname': B.classname, 'padding': B.padding,
                                       'out_channels': B.out_channels, 'pool_setting': B.pool_setting})
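
# For dense layers, the updates applied in `generic_rule` are, with a_i the
# pre-synaptic activation, a_{i+1} the post-synaptic activation, e_i the
# modulatory signal, and W_i the forward weight of layer i:
#
#   pseudo-gradient:  W_i <- W_i - Theta[0] * e_{i+1}^T a_i
#   eHebb ('2'):      W_i <- W_i - Theta[1] * e_{i+1}^T e_i
#   Oja ('9'):        W_i <- W_i - Theta[2] * (a_{i+1}^T a_i - a_{i+1}^T a_{i+1} W_i)
#
# The pseudo-gradient term is applied unconditionally; only the eHebb and Oja
# terms are gated by the rule indices in `vec`.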


class MyOptimizer:
    def __init__(self, update_rule, vec, fbk, Theta):
        """
        Initialize the optimizer.

        :param update_rule: (function) weight update function,
        :param vec: (list) vector of plasticity rule indices. It determines which
            plasticity rule is applied during the parameter update process.
        :param fbk: (str) the type of feedback matrix used in the model:
            1) 'sym', indicating that the feedback matrix is symmetric;
            2) 'fix', indicating that the feedback matrix is a fixed random matrix.
        :param Theta: (list) plasticity meta-parameters.
        """
        self.update_rule = update_rule
        self.vec = vec
        self.fbk = fbk
        self.Theta = Theta
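
    # Typical construction (the concrete values below are hypothetical): the
    # rule indices in `vec` are strings, so, e.g.,
    #   MyOptimizer(generic_rule, vec=['2', '9'], fbk='sym', Theta=Theta)
    # enables the eHebb and Oja terms on top of the always-applied
    # pseudo-gradient term, with symmetric feedback connections.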

    @staticmethod
    def error_signal(logits, label, activation, feedback, Beta):
        """
        Error signal.

        This function computes the error signal for each layer of the network.

        :param logits: (torch.Tensor) logits,
        :param label: (torch.Tensor) labels,
        :param activation: (list) model activations,
        :param feedback: (dict) feedback connections,
        :param Beta: (float) non-linearity smoothness parameter.
        :return: e: (list) error signals.
        """
        # -- max-pool the pre-pooling activations
        activation = list(activation)
        for i, (y, B) in enumerate(zip(activation[1:], feedback.values())):
            if B.pool_setting is not None:
                activation[i + 1], B.pool_indices = F.max_pool2d(y, *B.pool_setting, return_indices=True)

        # -- error signal at the output layer
        e = [F.softmax(logits, dim=1) - F.one_hot(label, num_classes=logits.shape[1])]
        # -- propagate the error backwards through the feedback connections
        for y, B in zip(reversed(activation), reversed(list(feedback.values()))):
            if 'Linear' in B.classname:
                e.insert(0, torch.matmul(e[0], B) * (1 - torch.exp(-Beta * y.view(y.size(0), -1))))
            elif 'Conv2d' in B.classname:
                # -- reshape the error signal for the conv2d -> linear transition
                if len(e[0].shape) == 2:
                    dim = int((e[0].shape[1] / B.out_channels) ** 0.5)
                    e[0] = e[0].view(-1, B.out_channels, dim, dim)
                # -- max-unpool the post-pooling error signal
                if B.pool_setting is not None:
                    e[0] = F.max_unpool2d(e[0], B.pool_indices, *B.pool_setting)
                e.insert(0, F.conv_transpose2d(e[0], B, padding=B.padding) * (1 - torch.exp(-Beta * y)))

        return e
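
    # Note: for a dense layer, the backward pass above computes
    #   e_i = (e_{i+1} @ B_i) * (1 - exp(-Beta * y_i)),
    # i.e. the downstream error is projected through the feedback connection
    # B_i and gated by a Beta-smoothed surrogate of the activation derivative
    # (for non-negative activations, 1 - exp(-Beta * y) approaches the ReLU
    # derivative as Beta grows).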

    def __call__(self, params, logits, label, activation, Beta):
        """
        One-step update of the inner loop (derived formulation).

        :param params: (dict) model parameters,
        :param logits: (torch.Tensor) unnormalized prediction values,
        :param label: (torch.Tensor) target class,
        :param activation: (list) vector of activations,
        :param Beta: (float) smoothness coefficient for the non-linearity.
        :return: None.
        """
        # -- error signal
        feedback = {k: v for k, v in params.items() if 'fk' in k}
        e = self.error_signal(logits, label, activation, feedback, Beta)

        # -- weight update
        self.update_rule([*activation, F.softmax(logits, dim=1)], e, params, feedback, self.Theta,
                         self.vec, self.fbk)
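

if __name__ == '__main__':
    # Minimal smoke test: a sketch, not part of the training pipeline. It
    # assumes the naming conventions used above ('fd' marks forward weights,
    # 'fk' marks feedback connections) and attaches the metadata attributes
    # that `generic_rule` and `error_signal` expect; the key names
    # ('fd1.weight', ...) and the `dense` helper are hypothetical.
    torch.manual_seed(0)
    batch, d_in, d_hid, n_cls = 4, 8, 6, 3

    def dense(shape):
        # hypothetical helper: a dense weight carrying the expected metadata
        w = nn.Parameter(0.1 * torch.randn(*shape))
        w.__dict__.update({'classname': 'Linear', 'padding': None,
                           'out_channels': None, 'pool_setting': None})
        return w

    params = {'fd1.weight': dense((d_hid, d_in)), 'fd2.weight': dense((n_cls, d_hid)),
              'fk1.weight': dense((d_hid, d_in)), 'fk2.weight': dense((n_cls, d_hid))}

    # -- forward pass on random data, keeping the per-layer activations
    x = torch.rand(batch, d_in)
    h = torch.relu(F.linear(x, params['fd1.weight']))
    logits = F.linear(h, params['fd2.weight'])
    label = torch.randint(0, n_cls, (batch,))

    # -- one inner-loop plasticity step with symmetric feedback
    optimizer = MyOptimizer(generic_rule, vec=['2', '9'], fbk='sym', Theta=[1e-2, 1e-3, 1e-3])
    optimizer(params, logits, label, [x, h], Beta=10)
    print({k: tuple(v.shape) for k, v in params.items()})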