forked from Aaditya-Singh/SAFIN
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
92 lines (73 loc) · 3.37 KB
/
utils.py
File metadata and controls
92 lines (73 loc) · 3.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import numpy as np
import torch.nn as nn
def get_wav(in_channels, pool=True):
    """Build the four fixed 2-D Haar wavelet filters as depthwise conv layers.

    The 1-D low-pass filter is ``[1, 1] / sqrt(2)`` and the high-pass filter is
    ``[-1, 1] / sqrt(2)``.  Each 2-D filter is the outer product of two of them:
    the row (vertical) axis uses the first factor, the column (horizontal) axis
    the second, giving LL, LH, HL and HH.

    Args:
        in_channels: number of channels.  Every layer uses
            ``groups=in_channels``, so each channel is filtered independently
            and the channel count is unchanged.
        pool: if True build ``nn.Conv2d`` layers (2x2 kernel, stride 2, i.e.
            downsampling analysis filters); otherwise build
            ``nn.ConvTranspose2d`` layers with the same kernel/stride
            (upsampling synthesis filters).

    Returns:
        Tuple ``(LL, LH, HL, HH)`` of layers with no bias and frozen weights
        (``requires_grad`` is False).
    """
    low = np.full((1, 2), 1 / np.sqrt(2))
    high = low.copy()
    high[0, 0] = -high[0, 0]
    # Outer products of the (2, 1) column and the (1, 2) row factors.
    kernels = (
        low.T * low,    # LL
        low.T * high,   # LH
        high.T * low,   # HL
        high.T * high,  # HH
    )
    conv = nn.Conv2d if pool else nn.ConvTranspose2d
    layers = []
    for kernel in kernels:
        layer = conv(in_channels, in_channels,
                     kernel_size=2, stride=2, padding=0, bias=False,
                     groups=in_channels)
        weight = torch.from_numpy(kernel).float().view(1, 1, 2, 2)
        # Copy into the layer's own (in_channels, 1, 2, 2) storage rather than
        # assigning an ``expand``-ed view: an expanded tensor aliases a single
        # memory block across channels, so later in-place writes such as
        # ``load_state_dict`` would fail with an internal-overlap error.
        with torch.no_grad():
            layer.weight.copy_(weight.expand_as(layer.weight))
        layer.weight.requires_grad = False
        layers.append(layer)
    return tuple(layers)
class WavePool(nn.Module):
    """Split a feature map into four Haar wavelet sub-bands.

    The filters come from ``get_wav(in_channels)`` (its default ``pool=True``),
    one fixed layer per sub-band.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.LL, self.LH, self.HL, self.HH = get_wav(in_channels)

    def forward(self, x):
        """Apply each sub-band filter to ``x``; return ``(LL, LH, HL, HH)`` outputs."""
        bands = (self.LL, self.LH, self.HL, self.HH)
        return tuple(band(x) for band in bands)
class WaveUnpool(nn.Module):
    """Recombine four Haar sub-bands into one feature map.

    Uses the transposed-convolution filters from ``get_wav(..., pool=False)``.
    Presumably intended as the counterpart of ``WavePool`` — confirm with
    callers.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.LL, self.LH, self.HL, self.HH = get_wav(self.in_channels, pool=False)

    def reshape(self, x, y):
        """Crop the last two dims of 4-D ``y`` to those of 4-D ``x``.

        Only crops: a ``y`` smaller than ``x`` is returned unchanged in size.
        """
        assert len(x.shape) == 4 and len(y.shape) == 4
        height, width = x.shape[2], x.shape[3]
        return y[:, :, :height, :width]

    def forward(self, ll, lh, hl, hh):
        """Crop the detail bands to ``ll``'s size, upsample each band, and sum them."""
        lh, hl, hh = (self.reshape(ll, band) for band in (lh, hl, hh))
        return self.LL(ll) + self.LH(lh) + self.HL(hl) + self.HH(hh)
def calc_mean_std(feat, eps=1e-5):
    """Return the per-sample, per-channel mean and standard deviation of ``feat``.

    Args:
        feat: 4-D tensor.  Statistics are computed over every element after the
            first two dimensions (``N``, ``C``), i.e. the spatial dims for an
            NCHW layout.
        eps: small value added to the variance before the square root, to
            avoid divide-by-zero when the std is later used as a divisor.

    Returns:
        ``(mean, std)``, each shaped ``(N, C, 1, 1)`` so they broadcast against
        ``feat``.
    """
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work.
    flat = feat.reshape(N, C, -1)
    # The unbiased variance of a single element is NaN.  For 1x1 maps, fall
    # back to the population variance (0) so eps still keeps the std finite.
    unbiased = flat.size(2) > 1
    feat_var = flat.var(dim=2, unbiased=unbiased) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std
def stat_transform(content_feat, style_feat):
    """Give ``content_feat`` the channel-wise mean/std of ``style_feat`` (AdaIN-style).

    Both inputs must share their first two dimensions (``N``, ``C``).  Content
    features are standardized with their own statistics from ``calc_mean_std``,
    then rescaled by the style std and shifted by the style mean.
    """
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    # The (N, C, 1, 1) statistics broadcast over the content feature's
    # remaining dims exactly as an explicit expand to its size would.
    standardized = (content_feat - content_mean) / content_std
    return standardized * style_std + style_mean