-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathSViT.py
More file actions
189 lines (126 loc) · 7.56 KB
/
SViT.py
File metadata and controls
189 lines (126 loc) · 7.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# https://github.com/hhb072/SViT
""""
以下是该模块的主要组件和功能:
Unfold 操作:Unfold 类定义了一个卷积操作,用于将输入图像进行解展开(unfolding)。具体来说,它将输入图像划分成不重叠的局部块,并将这些块展平成向量。这有助于在局部区域之间建立联系。
Fold 操作:Fold 类定义了一个卷积转置操作,用于将展开的局部块还原为原始的图像形状。这有助于将局部特征重新组合成图像。
Attention 操作:Attention 类定义了一个加性注意力机制,用于计算局部块之间的关联权重。通过对展开的局部块执行注意力操作,可以确定不同块之间的相关性,从而更好地捕获局部特征。
Stoken 操作:StokenAttention 类将图像划分为多个小块,并在这些小块之间执行加性注意力操作。它还包括对块之间的关系进行迭代更新的逻辑,以更好地捕获图像中的局部特征。
直接传递操作:direct_forward 方法用于直接传递输入图像,而不进行块划分和注意力操作。这对于某些情况下不需要局部特征建模的情况很有用。
Stoken 操作和直接传递操作的选择:根据 self.stoken_size 参数的设置,模块可以选择执行 Stoken 操作或直接传递操作。如果 self.stoken_size 的值大于 1,则执行 Stoken 操作,否则执行直接传递操作。
总的来说,这个模块提供了一种有效的方式来处理图像数据,并在图像的不同局部区域之间建立关联,以捕获局部特征。这对于许多计算机视觉任务,如目标检测和图像分割,都具有重要意义。
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class Unfold(nn.Module):
    """Im2col-style unfold implemented as a per-channel convolution with
    one-hot (identity) kernels.

    Maps (B, C, H, W) -> (B, C * kernel_size**2, H * W): each output channel
    block holds one shifted copy of the input, i.e. every kernel position of
    every channel becomes its own flattened map. Spatial size is preserved by
    `kernel_size // 2` zero padding (intended for odd kernel sizes).
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        # One one-hot single-channel filter per kernel position; convolving
        # with these copies each shifted neighbour into its own channel.
        weights = torch.eye(kernel_size ** 2)
        weights = weights.reshape(kernel_size ** 2, 1, kernel_size, kernel_size)
        # Frozen Parameter (not a buffer) to keep the original state_dict layout.
        self.weights = nn.Parameter(weights, requires_grad=False)

    def forward(self, x):
        """x: (B, C, H, W) -> (B, C * kernel_size**2, H * W)."""
        b, c, h, w = x.shape
        # Fold channels into the batch so each channel is unfolded independently.
        x = F.conv2d(x.reshape(b * c, 1, h, w), self.weights, stride=1,
                     padding=self.kernel_size // 2)
        # Fix: use kernel_size**2 instead of the hard-coded 9 so kernel sizes
        # other than 3 also work.
        return x.reshape(b, c * self.kernel_size ** 2, h * w)
class Fold(nn.Module):
    """Inverse of `Unfold`: a transposed convolution with one-hot kernels that
    sums the kernel_size**2 shifted maps back into a single map.

    Input (B, kernel_size**2, h, w) -> output (B, 1, h, w); overlapping
    contributions are accumulated, matching `torch.nn.functional.fold`.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        # One one-hot filter per kernel position (same identity weights as
        # in Unfold), kept as a frozen Parameter.
        eye = torch.eye(kernel_size ** 2)
        self.weights = nn.Parameter(
            eye.reshape(kernel_size ** 2, 1, kernel_size, kernel_size),
            requires_grad=False,
        )

    def forward(self, x):
        """Accumulate the shifted maps: (B, k*k, h, w) -> (B, 1, h, w)."""
        pad = self.kernel_size // 2
        return F.conv_transpose2d(x, self.weights, stride=1, padding=pad)
class Attention(nn.Module):
    """Multi-head attention over the spatial positions of a 2-D feature map.

    Projections are 1x1 convolutions; the softmax is taken over the *key*
    axis (dim=-2), i.e. the "efficient attention" formulation. Input and
    output are both (B, C, H, W).
    """

    def __init__(self, dim, window_size=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.window_size = window_size  # stored but not used in forward()
        self.scale = qk_scale or head_dim ** -0.5
        # 1x1 convs take the place of the usual linear projections.
        self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim, dim, 1)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, channels, height, width = x.shape
        tokens = height * width
        # Project once, then split the channel axis into q / k / v, each of
        # shape (B, num_heads, head_dim, N).
        qkv = self.qkv(x).reshape(batch, self.num_heads, channels // self.num_heads * 3, tokens)
        q, k, v = qkv.chunk(3, dim=2)
        # (B, heads, N, N) similarity; normalised over the key axis.
        scores = k.transpose(-1, -2) @ q
        scores = scores * self.scale
        weights = self.attn_drop(scores.softmax(dim=-2))
        out = (v @ weights).reshape(batch, channels, height, width)
        return self.proj_drop(self.proj(out))
class StokenAttention(nn.Module):
    """Super-token attention.

    The map is pooled into a coarse grid of "super tokens" (one per
    stoken_size cell). Pixel-to-stoken soft assignments are refined for
    `n_iter` EM-style steps (without gradients), full attention is run on the
    small stoken map, and the refined stokens are scattered back to pixels
    through the same assignments. When stoken_size is (1, 1) this degenerates
    to plain attention on the full map (direct_forward).
    """

    def __init__(self, dim, stoken_size, n_iter=1, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.n_iter = n_iter          # number of assignment-refinement iterations
        self.stoken_size = stoken_size  # (h, w) pixel extent of one super token
        self.scale = dim ** - 0.5
        self.unfold = Unfold(3)
        self.fold = Fold(3)
        # Attention applied on the coarse stoken map.
        self.stoken_refine = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                       attn_drop=attn_drop, proj_drop=proj_drop)

    def stoken_forward(self, x):
        '''
        x: (B, C, H, W) -> (B, C, H, W)
        '''
        B, C, H0, W0 = x.shape
        h, w = self.stoken_size
        # Pad right/bottom so H and W divide evenly into (h, w) cells.
        pad_l = pad_t = 0
        pad_r = (w - W0 % w) % w
        pad_b = (h - H0 % h) % h
        if pad_r > 0 or pad_b > 0:
            x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))
        _, _, H, W = x.shape
        hh, ww = H // h, W // w
        # Initial super tokens: mean of each cell.
        stoken_features = F.adaptive_avg_pool2d(x, (hh, ww))  # (B, C, hh, ww)
        # Pixels regrouped per cell: (B, hh*ww, h*w, C).
        pixel_features = x.reshape(B, C, hh, h, ww, w).permute(0, 2, 4, 3, 5, 1).reshape(B, hh * ww, h * w, C)
        # Refine pixel->stoken affinities without tracking gradients
        # (EM-style assignment updates).
        with torch.no_grad():
            for idx in range(self.n_iter):
                # Each cell sees its 3x3 neighbourhood of super tokens.
                stoken_features = self.unfold(stoken_features)  # (B, C*9, hh*ww)
                stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)
                affinity_matrix = pixel_features @ stoken_features * self.scale  # (B, hh*ww, h*w, 9)
                affinity_matrix = affinity_matrix.softmax(-1)  # (B, hh*ww, h*w, 9)
                # Total affinity mass received by each stoken (normaliser).
                affinity_matrix_sum = affinity_matrix.sum(2).transpose(1, 2).reshape(B, 9, hh, ww)
                affinity_matrix_sum = self.fold(affinity_matrix_sum)
                if idx < self.n_iter - 1:
                    # Re-estimate super tokens as affinity-weighted pixel means.
                    stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix  # (B, hh*ww, C, 9)
                    stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(
                        B, C, hh, ww)
                    stoken_features = stoken_features / (affinity_matrix_sum + 1e-12)  # (B, C, hh, ww)
        # Final aggregation, outside no_grad so gradients flow through
        # pixel_features; affinity_matrix itself stays detached (note the
        # explicit .detach() on the normaliser below).
        stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix  # (B, hh*ww, C, 9)
        stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(B, C, hh, ww)
        stoken_features = stoken_features / (affinity_matrix_sum.detach() + 1e-12)  # (B, C, hh, ww)
        # Attention on the (much smaller) stoken map.
        stoken_features = self.stoken_refine(stoken_features)
        # Scatter refined stokens back to pixels through the same affinities.
        stoken_features = self.unfold(stoken_features)  # (B, C*9, hh*ww)
        stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)  # (B, hh*ww, C, 9)
        pixel_features = stoken_features @ affinity_matrix.transpose(-1, -2)  # (B, hh*ww, C, h*w)
        pixel_features = pixel_features.reshape(B, hh, ww, C, h, w).permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)
        # Drop the padding added above.
        if pad_r > 0 or pad_b > 0:
            pixel_features = pixel_features[:, :, :H0, :W0]
        return pixel_features

    def direct_forward(self, x):
        """Run the refinement attention directly on the full-resolution map."""
        B, C, H, W = x.shape
        stoken_features = x
        stoken_features = self.stoken_refine(stoken_features)
        return stoken_features

    def forward(self, x):
        # Use super-token grouping only when a cell actually spans >1 pixel.
        if self.stoken_size[0] > 1 or self.stoken_size[1] > 1:
            return self.stoken_forward(x)
        else:
            return self.direct_forward(x)
# Input: (N, C, H, W) -> output: (N, C, H, W)
if __name__ == '__main__':
    # Smoke test: output shape must match input shape (N, C, H, W).
    # Fix: fall back to CPU when CUDA is unavailable instead of crashing,
    # and avoid shadowing the builtin `input`.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(3, 64, 32, 64, device=device)
    se = StokenAttention(64, stoken_size=[8, 8]).to(device)
    output = se(x)
    print(output.shape)