structural_reparam.py
# -*- coding: utf-8 -*-
# @Author : youngx
# @Time : 9:26 2022-06-08
import torch
import torch.nn as nn
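
# Training-time "asymmetric convolution" block in the style of ACNet: a KxK convolution
# runs in parallel with 1xK and Kx1 convolutions, each followed by its own BatchNorm.
# switch_to_deploy() fuses every conv+BN pair and adds the two asymmetric kernels into the
# centre of the square kernel, so inference uses a single standard KxK convolution whose
# outputs are equivalent up to floating-point error.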


class Model(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, padding_mode='zeros', deploy=False,
                 use_affine=True, reduce_gamma=False, gamma_init=None):
        super(Model, self).__init__()
        self.deploy = deploy
        # Square KxK branch (bias-free conv followed by BN).
        self.square_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                     kernel_size=(kernel_size, kernel_size), stride=stride,
                                     padding=padding, dilation=dilation, groups=groups, bias=False,
                                     padding_mode=padding_mode)
        self.square_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
        if padding - kernel_size // 2 >= 0:
            # Common use case, e.g. k=3, p=1 or k=5, p=2.
            self.crop = 0
            # Compared to the KxK layer, the padding of the 1xK and Kx1 layers must be adjusted
            # to align the sliding windows (Fig 2 in the paper).
            hor_padding = [padding - kernel_size // 2, padding]
            ver_padding = [padding, padding - kernel_size // 2]
        else:
            # A negative "padding" (padding - kernel_size // 2 < 0, not a common use case) means cropping.
            # Since nn.Conv2d does not support negative padding, the crop is applied manually in forward().
            self.crop = kernel_size // 2 - padding
            hor_padding = [0, padding]
            ver_padding = [padding, 0]
        # Vertical (Kx1) and horizontal (1xK) branches, each with its own BN.
        self.ver_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=(kernel_size, 1),
                                  stride=stride,
                                  padding=ver_padding, dilation=dilation, groups=groups, bias=False,
                                  padding_mode=padding_mode)
        self.hor_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=(1, kernel_size),
                                  stride=stride,
                                  padding=hor_padding, dilation=dilation, groups=groups, bias=False,
                                  padding_mode=padding_mode)
        self.ver_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
        self.hor_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)

    def _fuse_bn_tensor(self, conv, bn):
        # Fold a BatchNorm into the preceding bias-free conv:
        #   W' = W * gamma / sqrt(var + eps),  b' = beta - mean * gamma / sqrt(var + eps)
        std = (bn.running_var + bn.eps).sqrt()
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        return conv.weight * t, bn.bias - bn.running_mean * bn.weight / std

    def _add_to_square_kernel(self, square_kernel, asym_kernel):
        # Add the 1xK / Kx1 kernel into the centre rows/columns of the KxK kernel (in place).
        asym_h = asym_kernel.size(2)
        asym_w = asym_kernel.size(3)
        square_h = square_kernel.size(2)
        square_w = square_kernel.size(3)
        square_kernel[:, :, square_h // 2 - asym_h // 2: square_h // 2 - asym_h // 2 + asym_h,
                      square_w // 2 - asym_w // 2: square_w // 2 - asym_w // 2 + asym_w] += asym_kernel

    def get_equivalent_kernel_bias(self):
        # Fuse each conv+BN pair, then merge the three branches into one KxK kernel and bias.
        hor_k, hor_b = self._fuse_bn_tensor(self.hor_conv, self.hor_bn)
        ver_k, ver_b = self._fuse_bn_tensor(self.ver_conv, self.ver_bn)
        square_k, square_b = self._fuse_bn_tensor(self.square_conv, self.square_bn)
        self._add_to_square_kernel(square_k, hor_k)
        self._add_to_square_kernel(square_k, ver_k)
        return square_k, hor_b + ver_b + square_b

    def switch_to_deploy(self):
        # Replace the three training-time branches with a single fused KxK convolution.
        deploy_k, deploy_b = self.get_equivalent_kernel_bias()
        self.deploy = True
        self.fused_conv = nn.Conv2d(in_channels=self.square_conv.in_channels,
                                    out_channels=self.square_conv.out_channels,
                                    kernel_size=self.square_conv.kernel_size, stride=self.square_conv.stride,
                                    padding=self.square_conv.padding, dilation=self.square_conv.dilation,
                                    groups=self.square_conv.groups, bias=True,
                                    padding_mode=self.square_conv.padding_mode)
        self.__delattr__('square_conv')
        self.__delattr__('square_bn')
        self.__delattr__('hor_conv')
        self.__delattr__('hor_bn')
        self.__delattr__('ver_conv')
        self.__delattr__('ver_bn')
        self.fused_conv.weight.data = deploy_k
        self.fused_conv.bias.data = deploy_b

    def forward(self, input):
        if self.deploy:
            return self.fused_conv(input)
        else:
            square_outputs = self.square_conv(input)
            square_outputs = self.square_bn(square_outputs)
            if self.crop > 0:
                # Negative-padding case: crop the input manually, since nn.Conv2d cannot pad negatively.
                ver_input = input[:, :, :, self.crop:-self.crop]
                hor_input = input[:, :, self.crop:-self.crop, :]
            else:
                ver_input = input
                hor_input = input
            vertical_outputs = self.ver_conv(ver_input)
            vertical_outputs = self.ver_bn(vertical_outputs)
            horizontal_outputs = self.hor_conv(hor_input)
            horizontal_outputs = self.hor_bn(horizontal_outputs)
            result = square_outputs + vertical_outputs + horizontal_outputs
            return result

    def forward_fuse(self, x):
        # Inference path once switch_to_deploy() has been called and the branches are fused.
        return self.fused_conv(x)
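

# Illustrative sketch (not part of the original file): when this block is used inside a larger
# network, every instance can be converted for deployment in one pass. The helper name
# convert_to_deploy is an assumption made for illustration; it simply calls the
# switch_to_deploy() defined above on each Model it finds.
def convert_to_deploy(net):
    for m in net.modules():
        if isinstance(m, Model) and not m.deploy:
            m.switch_to_deploy()
    return net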


if __name__ == '__main__':
    N = 1
    C = 2
    H = 62
    W = 62
    O = 8
    groups = 4
    x = torch.randn(N, C, H, W)
    print('input shape is ', x.size())
    # test_kernel_padding = [(3, 1), (3, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 6)]
    test_kernel_padding = [(3, 1)]
    for k, p in test_kernel_padding:
        acb = Model(C, O, kernel_size=k, padding=p, stride=1, deploy=False)
        acb.eval()
        # Randomise the BN running statistics and affine parameters so the equivalence check is
        # non-trivial (in eval mode the BN layers use these running statistics).
        for module in acb.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.uniform_(module.running_mean, 0, 0.1)
                nn.init.uniform_(module.running_var, 0, 0.2)
                nn.init.uniform_(module.weight, 0, 0.3)
                nn.init.uniform_(module.bias, 0, 0.4)
        out = acb(x)
        acb.switch_to_deploy()
        deployout = acb(x)
        print('difference between the outputs of the training-time and converted ACB is')
        print(((deployout - out) ** 2).sum())