-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvisualization.py
More file actions
40 lines (33 loc) · 1.1 KB
/
visualization.py
File metadata and controls
40 lines (33 loc) · 1.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from torchfusion.layers import *
from torchfusion.metrics import *
from torchfusion.datasets import *
import torch.nn as nn
import torch.cuda as cuda
from torchfusion.learners import *
from torch.optim import Adam
from torchfusion.utils import VisdomLogger
# Fully connected classifier: the input is flattened to 784 features, passed
# through four 100-unit Linear layers each followed by a Swish activation,
# and projected onto 10 outputs.
model = nn.Sequential(
    Flatten(),
    Linear(784, 100),
    Swish(),
    # Three identical hidden stages, instantiated in order (Linear, then Swish).
    *(
        stage
        for _ in range(3)
        for stage in (Linear(100, 100), Swish())
    ),
    Linear(100, 10),
)

# Move the parameters to the GPU when one is available; otherwise stay on CPU.
if cuda.is_available():
    model.cuda()
# Optimisation setup: Adam over all model parameters (no hyperparameters
# overridden) and a cross-entropy objective.
optimizer = Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

# Fashion-MNIST loaders with batches of 128.
# NOTE(review): the first positional argument (28) is presumably the image
# size and train=False presumably selects the held-out split -- confirm
# against the torchfusion `fashionmnist_loader` signature.
fmnist_train = fashionmnist_loader(28, batch_size=128)
fmnist_test = fashionmnist_loader(28, batch_size=128, train=False)

# Separate Accuracy instances for the training and test metric lists.
train_metrics = [Accuracy()]
test_metrics = [Accuracy()]

# NOTE(review): the identifier below is spelled with three g's; it is kept
# as-is because the training call at the bottom of the file uses this name.
visdom_loggger = VisdomLogger()
if __name__ == "__main__":
    # Wrap the module-level model in a learner and print its summary for an
    # input of shape (1, 28, 28).
    learner = StandardLearner(model)
    print(learner.summary((1, 28, 28)))

    # Train for 25 epochs on the training loader, passing the test loader and
    # both metric lists, a TensorBoard log directory, the Visdom logger and a
    # model directory to the learner.
    learner.train(
        fmnist_train,
        loss_fn,
        optimizer=optimizer,
        train_metrics=train_metrics,
        test_loader=fmnist_test,
        test_metrics=test_metrics,
        tensorboard_log="./tboard_logs",
        visdom_log=visdom_loggger,
        num_epochs=25,
        model_dir="./fashion-mnist-model",
    )