forked from calum-green/OpenLSR-X
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathblocks.py
More file actions
110 lines (88 loc) · 2.72 KB
/
blocks.py
File metadata and controls
110 lines (88 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
This module defines the building blocks of the SRGAN architecture
described in the original paper, https://arxiv.org/abs/1609.04802v5
(note: ConvBlock normalises with PixelNorm rather than the paper's BatchNorm).
"""
import pytorch_lightning as pl
from torch import nn
from loss import PixelNorm
class ConvBlock(pl.LightningModule):
    """
    Conv2d -> normalisation -> activation building block.

    Unlike the BatchNorm used in the SRGAN paper, normalisation here is
    ``PixelNorm`` (imported from the project's ``loss`` module).

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        discriminator: if truthy, use ``LeakyReLU(0.2)`` as the activation;
            otherwise use a per-channel ``PReLU``.
        use_act: if truthy, apply the activation after normalisation.
        use_bn: if truthy, apply ``PixelNorm`` after the convolution and drop
            the convolution bias; otherwise normalisation is an identity and
            the convolution keeps its bias.
        **kwargs: forwarded to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``, ``padding_mode``). Must not contain
            ``bias``; it is derived from ``use_bn``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        discriminator=False,
        use_act=True,
        use_bn=True,
        **kwargs,
    ):
        super().__init__()
        # Use truthiness (not `is True`) so the bias decision below and the
        # normalisation decision always agree for any truthy/falsy flag.
        use_bn = bool(use_bn)
        # A conv bias is redundant when a normalisation layer follows it.
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
        # Alternatives previously tried here: nn.BatchNorm2d(out_channels)
        # and nn.InstanceNorm2d(out_channels).
        self.bn = PixelNorm() if use_bn else nn.Identity()
        # NOTE(review): `act` is built even when use_act is False so that
        # state_dict keys stay stable across configurations; its PReLU weights
        # are then unused, which may require find_unused_parameters=True under
        # DDP — confirm against the training setup.
        self.act = (
            nn.LeakyReLU(0.2, inplace=True)
            if discriminator
            else nn.PReLU(num_parameters=out_channels)
        )
        self.use_act = bool(use_act)

    def forward(self, x):
        """Apply conv and normalisation, then the activation if enabled."""
        out = self.bn(self.conv(x))
        return self.act(out) if self.use_act else out
class UpsampleBlock(pl.LightningModule):
    """
    Sub-pixel upsampling stage: Conv2d -> PixelShuffle -> PReLU.

    A 3x3, stride-1, reflect-padded convolution expands the channel count
    by ``scale_factor**2``. ``nn.PixelShuffle`` then rearranges those extra
    channels into space, producing ``in_channels`` channels at
    ``scale_factor`` times the input height and width.

    Args:
        in_channels: channel count of the incoming feature map (and of the
            output).
        scale_factor: spatial upscaling factor applied by the pixel shuffle.
    """

    def __init__(self, in_channels, scale_factor):
        super().__init__()
        expanded_channels = in_channels * scale_factor**2
        self.conv = nn.Conv2d(
            in_channels,
            expanded_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode="reflect",
        )
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        # After the shuffle the channel count is back to in_channels.
        self.act = nn.PReLU(num_parameters=in_channels)

    def forward(self, x):
        """Expand channels, shuffle them into space, then activate."""
        expanded = self.conv(x)
        upscaled = self.pixel_shuffle(expanded)
        return self.act(upscaled)
class ResidualBlock(pl.LightningModule):
    """
    Residual block of the SRGAN generator.

    Two channel-preserving ``ConvBlock``s (3x3 kernel, stride 1, padding 1,
    reflect padding) run in sequence. The first applies its activation; the
    second does not. Their output is added element-wise to the block input
    (skip connection), so input and output shapes match.

    Args:
        in_channels: channel count of the input, kept unchanged by both
            convolutions.
    """

    def __init__(self, in_channels):
        super().__init__()
        conv_options = {
            "kernel_size": 3,
            "stride": 1,
            "padding": 1,
            "padding_mode": "reflect",
        }
        self.block = ConvBlock(
            in_channels=in_channels,
            out_channels=in_channels,
            **conv_options,
        )
        self.block2 = ConvBlock(
            in_channels=in_channels,
            out_channels=in_channels,
            use_act=False,
            **conv_options,
        )

    def forward(self, x):
        """Run both conv blocks on ``x`` and return their output plus ``x``."""
        residual = self.block2(self.block(x))
        return residual + x