-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathduc.py
More file actions
104 lines (82 loc) · 2.86 KB
/
duc.py
File metadata and controls
104 lines (82 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.utils.model_zoo as model_zoo
from torchvision import models
import math
class DUC(nn.Module):
    """Dense upsampling convolution block.

    Applies a 3x3 convolution (padding 1, so height and width are kept),
    batch normalisation and a ReLU, then moves channels into space with
    ``nn.PixelShuffle``. With ``upscale_factor = r`` the result has
    ``planes / r**2`` channels and ``r`` times the input height and width,
    so ``planes`` must be divisible by ``r**2``.

    Args:
        inplanes: number of channels in the input feature map.
        planes: number of channels produced by the convolution, before the
            pixel shuffle.
        upscale_factor: spatial upsampling factor ``r`` (default 2).
    """

    def __init__(self, inplanes, planes, upscale_factor=2):
        super().__init__()
        # Attribute names double as state_dict keys, so keep them stable.
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(planes)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        """Return ``pixel_shuffle(relu(bn(conv(x))))``."""
        for stage in (self.conv, self.bn, self.relu, self.pixel_shuffle):
            x = stage(x)
        return x
class FCN(nn.Module):
    """Segmentation network: ResNet-50 encoder with a DUC decoder.

    The decoder climbs back up the encoder's feature hierarchy with ``DUC``
    blocks (conv + pixel shuffle, each doubling the spatial size). At every
    step it adds the matching encoder feature map and attaches a classifier
    head to the resulting decoder stage.

    Args:
        num_classes: number of channels in every classifier head's output.
        pretrained: passed straight to ``torchvision.models.resnet50``.
            When True, the backbone starts from torchvision's pretrained
            weights, which may need to be downloaded. Defaults to True,
            matching the previous behaviour; pass False to build the network
            without pretrained weights.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        self.num_classes = num_classes
        resnet = models.resnet50(pretrained=pretrained)

        # Encoder: the ResNet-50 stem and its four residual stages.
        self.conv1 = resnet.conv1
        self.bn0 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # Decoder: DUC(c, 2 * c) with the default upscale_factor of 2 yields
        # 2 * c / 2**2 = c / 2 channels at twice the spatial size. That lets
        # each output be added to the next-shallower skip connection
        # (assuming the standard ResNet-50 stage widths 256/512/1024/2048).
        self.duc1 = DUC(2048, 2048 * 2)  # -> 1024 ch, added to layer3 output
        self.duc2 = DUC(1024, 1024 * 2)  # -> 512 ch, added to layer2 output
        self.duc3 = DUC(512, 512 * 2)    # -> 256 ch, added to layer1 output
        self.duc4 = DUC(128, 128 * 2)    # -> 64 ch, added to the stem output
        self.duc5 = DUC(64, 64 * 2)      # -> 32 ch, last decoder stage

        # One classifier head per decoder stage, sized to its channel count.
        self.out1 = self._classifier(1024)
        self.out2 = self._classifier(512)
        self.out3 = self._classifier(128)
        self.out4 = self._classifier(64)
        self.out5 = self._classifier(32)

        # 1x1 conv fusing the 256-channel decoder map with the pooled stem
        # output (256 + 64 = 320 channels) down to 128 channels.
        self.transformer = nn.Conv2d(320, 128, kernel_size=1)

    def _classifier(self, inplanes):
        """Build a head mapping ``inplanes`` channels to ``num_classes``.

        The 32-channel (last) stage gets a light head: a 1x1 conv to the
        classes followed by a 3x3 conv. Every other stage gets
        3x3 conv (halving channels) -> BN -> ReLU -> Dropout(0.1) -> 1x1 conv.
        """
        if inplanes == 32:
            return nn.Sequential(
                nn.Conv2d(inplanes, self.num_classes, 1),
                nn.Conv2d(self.num_classes, self.num_classes,
                          kernel_size=3, padding=1),
            )
        # Integer division: ``inplanes / 2`` is a float on Python 3, and the
        # layer constructors below fail on float channel counts.
        mid = inplanes // 2
        return nn.Sequential(
            nn.Conv2d(inplanes, mid, 3, padding=1, bias=False),
            # NOTE(review): PyTorch's BatchNorm momentum weights the *new*
            # batch statistic (running = (1 - m) * running + m * batch).
            # With .95 the running stats track almost only the latest batch.
            # This looks like a Caffe/Keras-style decay value; confirm
            # whether 0.05 was intended.
            nn.BatchNorm2d(mid, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(.1),
            nn.Conv2d(mid, self.num_classes, 1),
        )

    def forward(self, x):
        """Run the encoder-decoder and return per-stage class score maps.

        Returns:
            Tuple ``(out, out2, out4, out8, out16)`` holding the outputs of
            heads ``out5``, ``out4``, ``out3``, ``out2`` and ``out1``. They
            are ordered from the finest to the coarsest decoder stage, and
            each has ``num_classes`` channels. The suffixes presumably denote
            output stride relative to the input. The skip additions require
            matching spatial sizes, which presumably means H and W divisible
            by 32 — confirm for your inputs.
        """
        # Stem; keep the pre-pool and post-pool activations as skips.
        x = self.relu(self.bn0(self.conv1(x)))
        conv_x = x
        x = self.maxpool(x)
        pool_x = x

        # Encoder stages.
        fm1 = self.layer1(x)
        fm2 = self.layer2(fm1)
        fm3 = self.layer3(fm2)
        fm4 = self.layer4(fm3)

        # Decoder: upsample, add the same-scale skip, classify.
        dfm1 = fm3 + self.duc1(fm4)
        out16 = self.out1(dfm1)
        dfm2 = fm2 + self.duc2(dfm1)
        out8 = self.out2(dfm2)
        dfm3 = fm1 + self.duc3(dfm2)
        dfm3_t = self.transformer(torch.cat((dfm3, pool_x), 1))
        out4 = self.out3(dfm3_t)
        dfm4 = conv_x + self.duc4(dfm3_t)
        out2 = self.out4(dfm4)
        dfm5 = self.duc5(dfm4)
        out = self.out5(dfm5)
        return out, out2, out4, out8, out16