-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
88 lines (76 loc) · 2.76 KB
/
model.py
File metadata and controls
88 lines (76 loc) · 2.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch.nn as nn
import numpy as np
class Generator(nn.Module):
    """Simplified DCGAN-style generator: latent vector -> 1 x 128 x 128 image.

    Kept intentionally small (vs. full DCGAN) to speed up training.
    """

    def __init__(self):
        super(Generator, self).__init__()

        def up_block(c_in, c_out):
            # One upsampling unit: stride-2 transposed conv doubles the
            # spatial size, followed by batch norm and ReLU.
            return [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        layers = [
            # input 100 x 1 x 1 -> 1024 x 4 x 4
            nn.ConvTranspose2d(100, 1024, 4, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
        ]
        layers += up_block(1024, 512)  # -> 512 x 8 x 8
        layers += up_block(512, 256)   # -> 256 x 16 x 16
        layers += up_block(256, 128)   # -> 128 x 32 x 32
        layers += up_block(128, 64)    # -> 64 x 64 x 64
        layers += [
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            # Hardsigmoid squashes the output into [0, 1].
            nn.Hardsigmoid(),
            # output 1 x 128 x 128
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Map latent codes z, shaped (B, 100) or (B, 100, 1, 1), to images."""
        if z.shape[-1] != 1:
            # Promote (batch_size, 100) to (batch_size, 100, 1, 1).
            z = z.unsqueeze(-1).unsqueeze(-1)
        return self.net(z)
class Discriminator(nn.Module):
    """Two-head DCGAN-style discriminator for 1 x 128 x 128 inputs.

    forward() returns a pair: (real/fake probability, class logits).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Shared convolutional trunk:
        # 1 x 128 x 128 -> 256 x 64 x 64 -> 512 x 32 x 32 -> flat 524288.
        self.net = nn.Sequential(
            nn.Conv2d(1, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            nn.Flatten(),
        )
        # Head 1: real/fake probability in [0, 1].
        self.disc = nn.Sequential(
            nn.Linear(524288, 1),
            nn.Sigmoid())
        # Head 2: raw logits over 19 classes (softmax left to the loss).
        self.classify = nn.Sequential(
            nn.Linear(524288, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 19)
        )

    def load(self, backup):
        """Copy the weights of every Linear layer from `backup` into self.

        Assumes `backup` has the same module layout so the paired iteration
        lines layers up; conv/batch-norm layers are deliberately not copied.
        """
        for src, dst in zip(backup.modules(), self.modules()):
            if not isinstance(dst, nn.Linear):
                continue
            dst.weight.data = src.weight.data.clone()
            if dst.bias is not None:
                dst.bias.data = src.bias.data.clone()

    def forward(self, img):
        """Return (real/fake probability, class logits) for a batch of images."""
        features = self.net(img)
        return self.disc(features), self.classify(features)