-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
382 lines (307 loc) · 13.8 KB
/
app.py
File metadata and controls
382 lines (307 loc) · 13.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
# Libraries
import time
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# ============================================================================
# PAPER IMPLEMENTATION: Perceptron-based Pooling Operations
# ============================================================================
class PerceptronPool2d(nn.Module):
    """Learned pooling: a single perceptron applied to each pooling window.

    Paper's core contribution. Deliberately has NO activation function
    (the paper found this works better - see Table 2).
    Parameter count: W*H weights + 1 bias.
    """

    def __init__(self, kernel_size=2, stride=2):
        super(PerceptronPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        # Start near average pooling (1/k^2 per tap, e.g. 0.25 for 2x2)
        # with small random perturbations, per the paper's init strategy.
        avg_init = 1.0 / (kernel_size * kernel_size)
        noise = torch.randn(kernel_size * kernel_size) * 0.1
        self.weight = nn.Parameter(noise + avg_init)
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Pool x of shape [batch, channels, H, W] window by window."""
        batch, channels, height, width = x.shape
        # Output spatial size for a valid (un-padded) sliding window.
        out_h = (height - self.kernel_size) // self.stride + 1
        out_w = (width - self.kernel_size) // self.stride + 1
        # unfold -> [batch, channels * k^2, num_windows]; then split the
        # channel and per-window tap dimensions apart.
        cols = nn.functional.unfold(
            x, kernel_size=self.kernel_size, stride=self.stride
        )
        cols = cols.reshape(batch, channels, self.kernel_size ** 2, -1)
        # Perceptron: weighted sum over the k^2 taps plus bias; no activation.
        pooled = torch.einsum('bckp,k->bcp', cols, self.weight) + self.bias
        return pooled.reshape(batch, channels, out_h, out_w)
class NeuralNetPool2d(nn.Module):
    """Learned pooling: a tiny two-layer MLP per window (NN-4-1, Table 3).

    - Hidden layer: 4 perceptrons (no activation)
    - Output layer: 1 perceptron (no activation)
    - Total: 50 additional parameters
    """

    def __init__(self, kernel_size=2, stride=2, hidden_neurons=4):
        super(NeuralNetPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        taps = kernel_size * kernel_size
        # Hidden layer starts near the averaging kernel plus small noise.
        self.layer1_weights = nn.Parameter(
            torch.randn(hidden_neurons, taps) * 0.1 + 1.0 / taps
        )
        self.layer1_bias = nn.Parameter(torch.zeros(hidden_neurons))
        # Output layer starts near the mean of the hidden neurons.
        self.layer2_weights = nn.Parameter(
            torch.randn(1, hidden_neurons) * 0.1 + (1.0 / hidden_neurons)
        )
        self.layer2_bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Pool x of shape [batch, channels, H, W] with the per-window MLP."""
        batch, channels, height, width = x.shape
        out_h = (height - self.kernel_size) // self.stride + 1
        out_w = (width - self.kernel_size) // self.stride + 1
        # Gather sliding windows and split channel / tap dimensions.
        cols = nn.functional.unfold(
            x, kernel_size=self.kernel_size, stride=self.stride
        )
        cols = cols.reshape(batch, channels, self.kernel_size ** 2, -1)
        # Hidden layer: linear map over the k^2 taps (no activation).
        hidden = torch.einsum('bckp,hk->bchp', cols, self.layer1_weights)
        hidden = hidden + self.layer1_bias.reshape(1, 1, -1, 1)
        # Output layer: collapse the hidden neurons to one value per window.
        out = torch.einsum('bchp,oh->bcop', hidden, self.layer2_weights)
        out = out + self.layer2_bias.reshape(1, 1, -1, 1)
        return out.squeeze(2).reshape(batch, channels, out_h, out_w)
# ============================================================================
# Network Architecture
# ============================================================================
class Net(nn.Module):
    """CNN for the pooling experiment.

    Choose pool_type: 'perceptron', 'nn-4-1', or 'max' (baseline).
    Only the pooling operator differs between configurations; the three
    conv/BN stages and the classifier head are identical.
    """

    def __init__(self, pool_type='perceptron'):
        super(Net, self).__init__()
        # Convolutional backbone (unchanged across pool types).
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        # Pooling operator selection - the experimental variable.
        if pool_type == 'perceptron':
            print("Using Perceptron Pooling (10 params per layer)")
            make_pool = lambda: PerceptronPool2d(kernel_size=2, stride=2)
        elif pool_type == 'nn-4-1':
            print("Using NN-4-1 Pooling (50 params per layer)")
            make_pool = lambda: NeuralNetPool2d(
                kernel_size=2, stride=2, hidden_neurons=4
            )
        else:  # max pooling baseline
            print("Using Max Pooling (baseline)")
            make_pool = lambda: nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool1 = make_pool()
        self.pool2 = make_pool()
        self.pool3 = make_pool()
        # Head: 32x32 input halved three times -> 4x4 spatial, 100 classes.
        self.fc1 = nn.Linear(128 * 4 * 4, 100)

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
        )
        for conv, bn, pool in stages:
            x = pool(bn(torch.relu(conv(x))))
        return self.fc1(torch.flatten(x, 1))
# ============================================================================
# Improved Initialization and Training
# ============================================================================
def weights_init(m):
    """Kaiming-normal init for Conv2d/Linear weights; zero their biases.

    No-op for every other module type, so it is safe to pass to
    `Module.apply`.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    torch.nn.init.kaiming_normal_(m.weight)
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)
def evaluate(model, dataloader, device):
    """Return the model's accuracy (in percent) over `dataloader`.

    Switches the model to eval mode for the pass, then back to train
    mode afterwards (callers invoke this mid-training).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # argmax over class logits; avoids the legacy `.data` access
            # which bypasses autograd tracking (unneeded under no_grad).
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    model.train()
    return 100 * correct / total
# ============================================================================
# Main Training Loop (Following Paper's CIFAR-100 Setup)
# ============================================================================
def main():
    """Train the pooling-experiment network on CIFAR-100 per the paper's setup.

    Downloads CIFAR-100 into ./data, trains for 160 epochs with SGD and the
    paper's LR schedule, and prints progress/accuracy. Returns
    (net, final_acc) on success; returns None early if the loss diverges
    to NaN/Inf.
    """
    start_time = time.perf_counter()
    # ========================================================================
    # Data Setup (Following paper specifications - page 5)
    # ========================================================================
    # Training transforms: normalize to zero mean and unit std + data augmentation
    transform_train = transforms.Compose([
        transforms.Pad(4, fill=0),  # Paper: 4 pixels zero padding on each side
        transforms.RandomCrop(32),  # Paper: crop 32x32 from 40x40
        transforms.ToTensor(),
        # Paper specifies: normalize to zero mean and one standard deviation
        # Using CIFAR-100 statistics
        transforms.Normalize(
            mean=[0.5071, 0.4867, 0.4408],
            std=[0.2675, 0.2565, 0.2761]
        )
    ])
    # Test transforms: only normalize (no augmentation)
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5071, 0.4867, 0.4408],
            std=[0.2675, 0.2565, 0.2761]
        )
    ])
    # Loading datasets (downloads on first run)
    trainset = torchvision.datasets.CIFAR100(
        root='./data', train=True, download=True, transform=transform_train
    )
    testset = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True, transform=transform_test
    )
    # Data loaders; pin_memory only helps when transferring to CUDA
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=100, shuffle=True, num_workers=2,
        pin_memory=torch.cuda.is_available()
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=2,
        pin_memory=torch.cuda.is_available()
    )
    # ========================================================================
    # Model Setup
    # ========================================================================
    # Device preference order: CUDA, then Apple MPS, then CPU
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using device: {device}")
    # Pooling
    net = Net(pool_type='perceptron').to(device)
    net.apply(weights_init)
    # ========================================================================
    # Training Setup (Following paper specifications - page 5)
    # ========================================================================
    criterion = nn.CrossEntropyLoss()
    # Paper specifies for CIFAR-100:
    # - Optimizer: SGD with momentum (0.9)
    # - Learning rate: 0.1
    # - Weight decay: 5e-4
    # - LR schedule: reduce by 0.1 at epochs 80 and 120
    # - Total epochs: 160
    # Separating pooling parameters for reduced learning rate
    # (matched by the substring 'pool' in the parameter name)
    pooling_params = []
    other_params = []
    for name, param in net.named_parameters():
        if 'pool' in name:
            pooling_params.append(param)
        else:
            other_params.append(param)
    # Paper: reduce learning rate by factor 10^-1 for pooling layers;
    # pooling params also get no weight decay
    optimizer = torch.optim.SGD([
        {'params': other_params},
        {'params': pooling_params, 'lr': 0.01, 'weight_decay': 0}  # Factor 0.1
    ], lr=0.1, momentum=0.9, weight_decay=5e-4)
    # Paper schedule: reduce LR at epochs 80 and 120
    total_epochs = 160
    print(f"\nTraining for {total_epochs} epochs")
    print(f"Pooling parameters: {len(pooling_params)}")
    print(f"Other parameters: {len(other_params)}")
    print("-" * 60)
    # ========================================================================
    # Training Loop
    # ========================================================================
    best_acc = 0.0
    for epoch in range(total_epochs):
        # Adjusting learning rate at specified epochs
        # (scales BOTH param groups, keeping the 10x ratio between them)
        if epoch == 80 or epoch == 120:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
            print(f"Learning rate reduced at epoch {epoch + 1}")
        net.train()
        running_loss = 0.0
        total_loss = 0.0  # Track full epoch loss (fixed bug)
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            # Zero gradients
            optimizer.zero_grad()
            # Forward pass
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            # Backward pass
            loss.backward()
            optimizer.step()
            # Track loss (running_loss resets every report; total_loss doesn't)
            running_loss += loss.item()
            total_loss += loss.item()
            # Print every 200 mini-batches
            if i % 200 == 199:
                avg_loss = running_loss / 200
                print(f'[{epoch + 1:3d}, {i + 1:5d}] loss: {avg_loss:.3f}')
                running_loss = 0.0
            # Safety check: abort the whole run on divergence
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"NaN or Inf detected in loss at epoch {epoch}, batch {i}")
                return
        # Calculating epoch statistics
        epoch_loss = total_loss / len(trainloader)
        # Evaluating on test set every 5 epochs (and after the first epoch)
        if (epoch + 1) % 5 == 0 or epoch == 0:
            test_acc = evaluate(net, testloader, device)
            print(f'Epoch {epoch + 1}/{total_epochs} | '
                  f'Loss: {epoch_loss:.3f} | '
                  f'Test Acc: {test_acc:.2f}%')
            if test_acc > best_acc:
                best_acc = test_acc
                # Optionally save best model
                # torch.save(net.state_dict(), 'best_model.pth')
        else:
            print(f'Epoch {epoch + 1}/{total_epochs} | Loss: {epoch_loss:.3f}')
    # ========================================================================
    # Final Evaluation
    # ========================================================================
    final_acc = evaluate(net, testloader, device)
    end_time = time.perf_counter()
    elapsed_time = end_time - start_time
    print("\n" + "=" * 60)
    print(f'Training Completed!')
    print(f'Time taken: {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)')
    print(f'Best test accuracy: {best_acc:.2f}%')
    print(f'Final test accuracy: {final_acc:.2f}%')
    print("=" * 60)
    # Calculation of additional parameters from pooling
    pooling_param_count = sum(p.numel() for p in pooling_params)
    print(f'\nAdditional parameters from pooling: {pooling_param_count}')
    return net, final_acc
# Script entry point: run the full CIFAR-100 training experiment.
if __name__ == '__main__':
    main()