-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
457 lines (390 loc) · 17.5 KB
/
model.py
File metadata and controls
457 lines (390 loc) · 17.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
from torch import nn
import copy
import torch
import torch.nn.functional as F
import math
import numpy as np
# Short alias for copy.deepcopy, used by the model builders so that each layer
# receives its own copy of a prototype sub-module.
c = copy.deepcopy
class SelfTransformer:
def __init__(self,N=6, d_embed=512, d_ff=2048, h=8, dropout=0.1):
self.N = N
self.d_embed = d_embed
self.d_ff = d_ff
self.h = h
self.dropout = dropout
self.ffn = FFN(d_embed, d_ff, dropout)
def pad_mask(self, input):
return (input != 0).unsqueeze(1).unsqueeze(2) # [B, 1, 1, T] ✅
def sub_mask(self, input):
if isinstance(input, int):
seq_len = input
else:
seq_len = input.size(1)
attn_shape = (1, seq_len, seq_len)
subsequent = np.triu(np.ones(attn_shape), k=1).astype('uint8')
return torch.from_numpy(subsequent) == 0
def ende_model(self,src_vocab_size,tar_vocab_size):
attn = MultiHeadedAttention(self.h, self.d_embed)
pos_emb = PositionalEmbedding(self.d_embed,self.dropout)
model = EnDecoder(
ED_Encoder(ED_EnLayer(self.d_embed, c(attn), c(self.ffn), self.dropout), self.N),
ED_Decoder(ED_DeLayer(self.d_embed, c(attn), c(attn), c(self.ffn), self.dropout), self.N),
nn.Sequential(WordEmbedding(self.d_embed, src_vocab_size), c(pos_emb)),
nn.Sequential(WordEmbedding(self.d_embed, tar_vocab_size), c(pos_emb)),
FC(tar_vocab_size,self.d_embed)
)
for p in model.parameters():
if p.dim() > 1:
nn.init.xavier_uniform(p)
return model
def deonly_model(self,vocab_size,PE="Sinusoidal"):
match PE:
case "Sinusoidal":
attn = MultiHeadedAttention(self.h, self.d_embed)
pos_emb = PositionalEmbedding(self.d_embed,self.dropout)
case "Alibi":
attn = AlibiAttention(self.h, self.d_embed, self.dropout)
pos_emb = nn.Identity()
case "SAliBi":
attn = SAlibiAttention(self.h, self.d_embed, self.dropout)
pos_emb = nn.Identity()
case "RoPE":
attn = RoPEAttention(self.h, self.d_embed, self.dropout)
pos_emb = nn.Identity()
case _:
raise ValueError(f"Unsupported PE type: {PE}")
model = DecoderOnly(
DO_Decoder(DO_DeLayer(self.d_embed, c(attn), c(self.ffn), self.dropout),self.N),
nn.Sequential(WordEmbedding(self.d_embed, vocab_size), c(pos_emb)),
FC(vocab_size,self.d_embed)
)
for p in model.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
return model
def clones(module, N):
    """
    Return an nn.ModuleList of N independent deep copies of `module`.

    The copies start with identical parameter values but share no storage, so
    each layer trains separately.
    """
    return nn.ModuleList(map(copy.deepcopy, [module] * N))
class AddNorm(nn.Module):
    """
    Post-norm residual wrapper: computes LayerNorm(x + Dropout(layer(x))).

    Args:
        d_embed: size of the last dimension that is normalised.
        dropout: dropout probability applied to the sub-layer output.
    """

    def __init__(self, d_embed, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(d_embed, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, layer):
        # `layer` must be a callable taking x as its only argument
        # (e.g. a lambda wrapping an attention call).
        sublayer_out = layer(x)
        residual = x + self.dropout(sublayer_out)
        return self.norm(residual)
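# Custom LayerNorm implementation. PyTorch's built-in nn.LayerNorm can be used instead,
# but the two are not numerically identical: this version divides by the *unbiased*
# standard deviation plus eps, while nn.LayerNorm uses the biased variance with eps
# inside the square root. Customise the class below if you need different behaviour.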
class LayerNorm(nn.Module):
    """
    Layer normalisation over the last dimension with a learnable gain and shift.

    Computes a_2 * (x - mean) / (std + eps) + b_2, where std is the unbiased
    standard deviation (the Tensor.std default) and eps is added to std rather
    than to the variance.

    Args:
        features: size of the last dimension.
        eps: small constant that prevents division by zero.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # per-feature gain, starts at 1
        self.b_2 = nn.Parameter(torch.zeros(features))  # per-feature shift, starts at 0
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        denom = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / denom + self.b_2
class FFN(nn.Module):
    """
    Position-wise feed-forward network: expand, apply a non-linearity, project back.

    Paper formula: FFN(x) = max(0, x W1 + b1) W2 + b2, with ReLU as the activation.
    This implementation additionally applies dropout to the hidden activation.

    - Expansion to d_ff lets the layer learn a richer intermediate representation;
      with d_ff = 4 * d_embed it holds roughly two thirds of a layer's parameters.
    - Position-wise: the same weights are applied independently at every position.
    - d_ff = 2048 balances expressiveness and compute (BERT-base uses 4 * 768).
    - Attention mixes information across positions; the FFN acts per position.

    Args:
        d_embed: input/output width.
        d_ff: hidden width.
        dropout: dropout probability on the hidden activation.
    """

    def __init__(self, d_embed=512, d_ff=2048, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_embed, d_ff)
        self.w_2 = nn.Linear(d_ff, d_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        return self.w_2(self.dropout(hidden))
class FC(nn.Module):
    """
    Output head: linear projection d_embed -> vocab_size followed by log-softmax.

    Args:
        vocab_size: number of output classes (tokens).
        d_embed: input feature size.

    forward(x) returns log-probabilities over the last dimension.
    """

    def __init__(self, vocab_size, d_embed=512):
        super().__init__()
        self.proj = nn.Linear(d_embed, vocab_size)

    def forward(self, x):
        logits = self.proj(x)
        return F.log_softmax(logits, dim=-1)
class WordEmbedding(nn.Module):
    """
    Token-id lookup table whose output is scaled by sqrt(d_embed).

    Args:
        d_embed: embedding width.
        vocab_size: number of rows in the table.
    """

    def __init__(self, d_embed, vocab_size):
        super().__init__()
        self.d_embed = d_embed
        self.emb_tab = nn.Embedding(vocab_size, d_embed)

    def forward(self, x):
        scale = math.sqrt(self.d_embed)
        return self.emb_tab(x) * scale
class PositionalEmbedding(nn.Module):
    """
    Fixed sinusoidal positional encoding (Vaswani et al., 2017), added to the input
    and followed by dropout.

    Even columns hold sin(pos * w_i) and odd columns cos(pos * w_i), with
    w_i = 10000 ** (-2i / d_embed). Odd d_embed is supported: the last (even)
    column is a sine with no cosine partner.

    Args:
        d_embed: embedding width.
        dropout: dropout probability applied after the addition.
        seq_len: number of positions precomputed (maximum input length).
    """

    def __init__(self, d_embed, dropout=0.1, seq_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # [seq_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_embed, 2, dtype=torch.float32) * -(math.log(10000.0) / d_embed)
        )  # [ceil(d_embed / 2)]
        angles = pos * div_term  # [seq_len, ceil(d_embed / 2)]
        pos_emb = torch.zeros(seq_len, d_embed)
        pos_emb[:, 0::2] = torch.sin(angles)
        # For odd d_embed there is one fewer odd column than even column; the
        # original assigned all ceil(d/2) cosines here and crashed with a shape error.
        pos_emb[:, 1::2] = torch.cos(angles[:, : d_embed // 2])
        self.register_buffer('pos_emb', pos_emb.unsqueeze(0))  # [1, seq_len, d_embed]

    def forward(self, x):
        # Adds the encodings of the first x.size(1) positions; x is presumably
        # [batch, seq, d_embed] with seq <= seq_len.
        return self.dropout(x + self.pos_emb[:, :x.size(1)])
def apply_rope(q, k, cos, sin):
    """
    Apply rotary position embedding to queries and keys.

    Adjacent feature pairs along the last dimension, (0, 1), (2, 3), ..., are each
    rotated by the angle whose cosine/sine are supplied.

    Args:
        q, k: tensors [B, H, L, d]; d (per-head size) must be even.
        cos, sin: [L, d // 2]; two leading singleton dims are prepended so they
            broadcast against [B, H, L, d // 2].

    Returns:
        (q_rot, k_rot), each with the shape of its input and the pairs re-interleaved.
    """
    assert q.size(-1) % 2 == 0
    cos4 = cos[None, None]  # [1, 1, L, d/2]
    sin4 = sin[None, None]

    def _rotate(t):
        even, odd = t[..., 0::2], t[..., 1::2]
        rotated = (even * cos4 - odd * sin4, even * sin4 + odd * cos4)
        # Stack on a new last axis then flatten to interleave back to [..., d].
        return torch.stack(rotated, dim=-1).flatten(-2)

    return _rotate(q), _rotate(k)
class RoPE(torch.nn.Module):
    """
    Rotary-position-embedding helper used by RoPEAttention.

    Precomputes float32 cos/sin tables of shape [max_seq_len, d_qkv // 2]
    (registered as buffers). On each call it slices the first L rows, casts them
    to q's dtype/device and passes them to apply_rope. Holds no learnable parameters.

    Args:
        d_qkv: per-head feature size.
        max_seq_len: number of positions precomputed.
        base: frequency base; pair i rotates by base ** (-2i / d_qkv) radians per position.
    """

    def __init__(self, d_qkv, max_seq_len=8192, base=10000.0):
        super().__init__()
        self.d_qkv = d_qkv
        exponents = torch.arange(0, d_qkv, 2, dtype=torch.float32) / d_qkv
        inv_freq = 1.0 / (base ** exponents)
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        angles = positions[:, None] * inv_freq[None, :]  # [max_seq_len, d_qkv // 2]
        self.register_buffer("cos_cached", angles.cos())
        self.register_buffer("sin_cached", angles.sin())

    def forward(self, q, k, seq_len=None):
        """
        q, k: [B, H, L, d_qkv]. Returns (q_rot, k_rot) using the first
        `seq_len` positions, or q.size(2) positions when seq_len is None.
        """
        length = q.size(2) if seq_len is None else seq_len
        target = dict(dtype=q.dtype, device=q.device)
        cos = self.cos_cached[:length].to(**target)
        sin = self.sin_cached[:length].to(**target)
        return apply_rope(q, k, cos, sin)
class MultiHeadedAttention(nn.Module):
    """
    Multi-head scaled dot-product attention (Vaswani et al., 2017).

    Args:
        h: number of heads; must divide d_embed.
        d_embed: model width.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if d_embed is not divisible by h.
    """

    def __init__(self, h, d_embed, dropout=0.1):
        super().__init__()
        if d_embed % h != 0:
            raise ValueError(f"d_embed ({d_embed}) must be divisible by h ({h})")
        self.d_qkv = d_embed // h
        self.h = h
        # linears[0..2] project query/key/value; linears[3] is the output projection.
        self.linears = clones(nn.Linear(d_embed, d_embed), 4)
        self.attn_prob = None  # attention weights from the latest forward(), for inspection
        self.dropout = nn.Dropout(p=dropout)

    def attention(self, query, key, value, mask=None):
        """
        Scaled dot-product attention on already-split heads.

        Args:
            query: [..., q_len, d]; key/value: [..., k_len, d].
            mask: optional tensor broadcastable to [..., q_len, k_len];
                positions where mask == 0 are excluded.

        Returns:
            (output [..., q_len, d], attention weights [..., q_len, k_len]).
        """
        d_qkv = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_qkv)
        if mask is not None:
            # Use the dtype's most negative finite value rather than -1e9: -1e9
            # cannot be represented in float16, so masked_fill raised under half
            # precision / autocast.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn_prob = F.softmax(scores, dim=-1)
        if self.dropout is not None:
            attn_prob = self.dropout(attn_prob)
        return torch.matmul(attn_prob, value), attn_prob

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query, key, value: [n_batch, seq_len, d_embed]; key/value may have a
                different seq_len from query (e.g. cross-attention).
            mask: optional tensor broadcastable to [n_batch, h, q_len, k_len].

        Returns:
            [n_batch, q_len, d_embed]
        """
        n_batch = query.size(0)
        # Project and split heads: [n_batch, seq_len, d_embed] -> [n_batch, h, seq_len, d_qkv].
        # zip stops at the 3 inputs, so linears[3] (output projection) is not used here.
        query, key, value = [
            linear(x).view(n_batch, -1, self.h, self.d_qkv).transpose(1, 2)
            for linear, x in zip(self.linears, (query, key, value))
        ]
        x, self.attn_prob = self.attention(query, key, value, mask=mask)
        # Merge heads: [n_batch, h, seq_len, d_qkv] -> [n_batch, seq_len, d_embed].
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_qkv)
        return self.linears[-1](x)
class AlibiAttention(MultiHeadedAttention):
    """
    ALiBi attention (Press et al., "Train Short, Test Long").

    Reuses MultiHeadedAttention's projections and forward() and overrides
    attention() to add a per-head linear distance penalty to the scores:

        scores_ij = q_i . k_j / sqrt(d) - m_h * |i - j|

    Query and key positions are both counted from 0.

    Args:
        h, d_embed, dropout: forwarded to MultiHeadedAttention.
        max_seq_len: kept for backward compatibility (stored as self.max_seq_len).
            The bias is now built on the fly for the actual (tgt_len, src_len),
            so longer sequences no longer fail.
        a: unused; accepted so the signature mirrors SAlibiAttention.
    """

    def __init__(self, h, d_embed, dropout=0.1, max_seq_len=8192, a=None):
        super().__init__(h, d_embed, dropout)
        self.max_seq_len = max_seq_len
        # Per-head slopes m_k = exp(-(5000 ** (0.8 * (k - 1) / (h - 1)))), k = 1..h; all > 0.
        # NOTE(review): this differs from the paper's geometric 2 ** (-8k / h); with
        # h = 8 heads 4..8 get slopes below 1e-8 (effectively no bias) - confirm intent.
        head_idx = torch.arange(1, h + 1, dtype=torch.float32)
        # With a single head, (head_idx - 1) / (h - 1) would be 0 / 0 = NaN.
        frac = (head_idx - 1) / (h - 1) if h > 1 else torch.zeros_like(head_idx)
        m = (-1.0 * torch.pow(5000.0, frac * 8.0 / 10.0)).exp()
        self.register_buffer('m', m[None, :, None, None])  # [1, h, 1, 1]

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Earlier versions stored a precomputed [1, h, max_seq_len, max_seq_len]
        # 'alibi_bias' buffer (~2 GB per layer for h=8, 8192). It is no longer kept;
        # drop it so old checkpoints still load with strict=True.
        state_dict.pop(prefix + 'alibi_bias', None)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def attention(self, query, key, value, mask=None):
        d_qkv = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_qkv)
        tgt_len, src_len = scores.shape[-2:]
        # |i - j| for the actual lengths, built in float32 so positions stay exact.
        q_pos = torch.arange(tgt_len, device=scores.device, dtype=torch.float32)
        k_pos = torch.arange(src_len, device=scores.device, dtype=torch.float32)
        dist = (q_pos[:, None] - k_pos[None, :]).abs()  # [T, S]
        # Penalise distance. The previous code added +m * (i - j), which rewarded
        # attending to *more distant* past tokens - the opposite of ALiBi.
        bias = (-self.m * dist).to(dtype=scores.dtype)  # [1, h, T, S]
        scores = scores + bias
        if mask is not None:
            # finfo(...).min instead of -1e9, which overflows float16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn_prob = F.softmax(scores, dim=-1)
        if self.dropout is not None:
            attn_prob = self.dropout(attn_prob)
        return torch.matmul(attn_prob, value), attn_prob
class SAlibiAttention(MultiHeadedAttention):
    """
    SAlibi4: additive log-sigmoid distance bias on top of MultiHeadedAttention.

        Attention(i, j) = Q_i K_j^T / sqrt(d) + bias_ij
        bias_ij  = -m_h * theta(a + |i - j|)
        theta(x) = log(sigma(x)) = -log(1 + e^x)
        sigma(x) = 1 / (1 + e^x)     (i.e. the standard logistic sigmoid of -x)

    `a` is a tunable offset hyper-parameter intended to control the overall decay strength.

    NOTE(review): theta(x) < 0 and m_h > 0, so bias_ij = m_h * softplus(a + |i - j|)
    is positive and grows with distance (for large a it is ~ m_h * (a + |i - j|);
    the m_h * a part is constant per row and cancels in softmax). That favours
    distant tokens, the opposite sign of standard ALiBi - confirm this is intended.

    Args:
        h, d_embed, dropout: forwarded to MultiHeadedAttention.
        max_seq_len: kept for backward compatibility (stored as self.max_seq_len);
            distances are now computed on the fly for the actual lengths.
        a: offset added to |i - j| before theta (stored as a buffer).
    """

    def __init__(self, h, d_embed, dropout=0.1, max_seq_len=8192, a: float = 1000.0):
        super().__init__(h, d_embed, dropout)
        self.max_seq_len = max_seq_len
        # Same head slopes as AlibiAttention: m_k = exp(-(5000 ** (0.8 * (k - 1) / (h - 1)))).
        head_idx = torch.arange(1, h + 1, dtype=torch.float32)
        # With a single head, (head_idx - 1) / (h - 1) would be 0 / 0 = NaN.
        frac = (head_idx - 1) / (h - 1) if h > 1 else torch.zeros_like(head_idx)
        m = (-1.0 * torch.pow(5000.0, frac * 8.0 / 10.0)).exp()
        self.register_buffer('m', m[None, :, None, None])  # [1, h, 1, 1]
        self.register_buffer('a', torch.tensor(float(a)))  # scalar hyper-parameter a

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Earlier versions stored a precomputed [1, 1, max_seq_len, max_seq_len]
        # 'dist' buffer (256 MB per layer for 8192). It is no longer kept; drop it
        # so old checkpoints still load with strict=True.
        state_dict.pop(prefix + 'dist', None)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def attention(self, query, key, value, mask=None):
        d_qkv = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_qkv)  # [B, h, T, S]
        tgt_len, src_len = scores.shape[-2:]
        # |i - j| in float32 (exact positions), then cast like the old buffer slice was.
        q_pos = torch.arange(tgt_len, device=scores.device, dtype=torch.float32)
        k_pos = torch.arange(src_len, device=scores.device, dtype=torch.float32)
        dist = (q_pos[:, None] - k_pos[None, :]).abs().to(dtype=scores.dtype)  # [T, S]
        x = self.a + dist
        theta = -F.softplus(x)
        bias = -self.m * theta  # [1, h, T, S]
        scores = scores + bias
        if mask is not None:
            # finfo(...).min instead of -1e9, which overflows float16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn_prob = F.softmax(scores, dim=-1)
        if self.dropout is not None:
            attn_prob = self.dropout(attn_prob)
        return torch.matmul(attn_prob, value), attn_prob
class RoPEAttention(MultiHeadedAttention):
    """
    Multi-head attention with rotary position embedding.

    Q and K are rotated after projection and before scaled dot-product
    attention. Like AlibiAttention, position information is injected inside
    attention, so no positional-embedding layer is required.

    self.rope is called without seq_len, so its tables are sliced to the
    query length; query and key are therefore expected to have equal length
    (self-attention).
    """

    def __init__(self, h, d_embed, dropout=0.1, max_seq_len=8192):
        super().__init__(h, d_embed, dropout)
        self.rope = RoPE(self.d_qkv, max_seq_len=max_seq_len)

    def _split_heads(self, linear, x, n_batch):
        # [n_batch, seq_len, d_embed] -> [n_batch, h, seq_len, d_qkv]
        return linear(x).view(n_batch, -1, self.h, self.d_qkv).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        n_batch = query.size(0)
        q = self._split_heads(self.linears[0], query, n_batch)
        k = self._split_heads(self.linears[1], key, n_batch)
        v = self._split_heads(self.linears[2], value, n_batch)
        q, k = self.rope(q, k)
        out, self.attn_prob = self.attention(q, k, v, mask=mask)
        # [n_batch, h, seq_len, d_qkv] -> [n_batch, seq_len, d_embed]
        merged = out.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_qkv)
        return self.linears[-1](merged)
class EnDecoder(nn.Module):
    """
    Encoder-decoder Transformer wrapper.

    Args:
        encoder: callable(embedded_src, src_mask) -> memory.
        decoder: callable(embedded_tar, memory, src_mask, tar_mask) -> hidden states.
        src_embed, tar_embed: callables mapping token ids to embeddings.
        FC: output head applied to the decoder states.
    """

    def __init__(self, encoder, decoder, src_embed, tar_embed, FC):
        super().__init__()
        # Registration order is kept so parameters()/state_dict order is unchanged.
        self.encoder = encoder
        self.src_embed = src_embed
        self.tar_embed = tar_embed
        self.decoder = decoder
        self.FC = FC

    def forward(self, src, tar, src_mask, tar_mask):
        memory = self.encoder(self.src_embed(src), src_mask)
        hidden = self.decoder(self.tar_embed(tar), memory, src_mask, tar_mask)
        # The last time step is dropped (intended for next-token prediction), so
        # the output has one fewer step than `tar`.
        # NOTE(review): if the training loop already feeds tar[:, :-1], this drops
        # a second step - confirm against the caller.
        return self.FC(hidden[:, :-1, :])
class ED_EnLayer(nn.Module):
    """
    One encoder layer: self-attention, then feed-forward, each wrapped in AddNorm
    (residual + dropout + LayerNorm).

    Args:
        d_embed: model width (also read by ED_Encoder to size its final norm).
        attn: attention module called as attn(query, key, value, mask).
        ffn: position-wise feed-forward module.
        dropout: dropout probability for each AddNorm.
    """

    def __init__(self, d_embed, attn, ffn, dropout):
        super().__init__()
        self.d_embed = d_embed
        self.attn = attn
        self.ffn = ffn
        self.attn_norm = AddNorm(d_embed, dropout)
        self.ffn_norm = AddNorm(d_embed, dropout)

    def forward(self, x, mask):
        def self_attention(h):
            # attn takes (q, k, v, mask), so it is wrapped into the
            # one-argument callable that AddNorm expects.
            return self.attn(h, h, h, mask)

        x = self.attn_norm(x, self_attention)
        # The FFN already takes x as its only argument and is passed directly.
        return self.ffn_norm(x, self.ffn)
class ED_DeLayer(nn.Module):
    """
    One decoder layer of the encoder-decoder model.

    Runs masked self-attention, then cross-attention over the encoder memory,
    then feed-forward; each sub-layer is wrapped in AddNorm.

    Args:
        d_embed: model width (also read by ED_Decoder to size its final norm).
        attn: self-attention module, called as attn(x, x, x, tgt_mask).
        src_attn: cross-attention module, called as src_attn(x, memory, memory, src_mask).
        feed_forward: position-wise feed-forward module.
        dropout: dropout probability for each AddNorm.
    """

    def __init__(self, d_embed, attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.d_embed = d_embed
        self.attn = attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.attn_norm = AddNorm(d_embed, dropout)
        self.src_attn_norm = AddNorm(d_embed, dropout)
        self.ff_norm = AddNorm(d_embed, dropout)

    def forward(self, x, memory, src_mask, tgt_mask):
        sublayers = (
            (self.attn_norm, lambda h: self.attn(h, h, h, tgt_mask)),                    # self-attention
            (self.src_attn_norm, lambda h: self.src_attn(h, memory, memory, src_mask)),  # cross-attention
            (self.ff_norm, self.feed_forward),                                           # FFN
        )
        for add_norm, fn in sublayers:
            x = add_norm(x, fn)
        return x
class ED_Encoder(nn.Module):
    """
    Stack of N deep copies of `layer`, followed by a final LayerNorm.

    NOTE(review): if `layer` already ends in a LayerNorm (as ED_EnLayer's
    post-norm AddNorm does), this final norm is largely redundant.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.d_embed)

    def forward(self, x, mask):
        out = x
        for enc_layer in self.layers:
            out = enc_layer(out, mask)
        return self.norm(out)
class ED_Decoder(nn.Module):
    """
    Stack of N deep copies of a decoder `layer`, followed by a final LayerNorm.

    Every layer receives the same encoder memory and masks.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.d_embed)

    def forward(self, x, memory, src_mask, tgt_mask):
        out = x
        for dec_layer in self.layers:
            out = dec_layer(out, memory, src_mask, tgt_mask)
        return self.norm(out)
class DecoderOnly(nn.Module):
    """
    Decoder-only language model wrapper: embed -> decoder stack -> output head.

    Args:
        decoder: callable(embedded, mask) -> hidden states.
        embed: callable mapping token ids to embeddings.
        FC: output head applied to the hidden states.
    """

    def __init__(self, decoder, embed, FC):
        super().__init__()
        # Registration order is kept so parameters()/state_dict order is unchanged.
        self.embed = embed
        self.decoder = decoder
        self.FC = FC

    def forward(self, x, mask):
        # `mask` is typically the causal mask combined with the padding mask.
        return self.FC(self.decoder(self.embed(x), mask))
class DO_Decoder(nn.Module):
    """
    Decoder-only stack: N deep copies of `layer`, followed by a final LayerNorm.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.d_embed)

    def forward(self, x, tgt_mask):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, tgt_mask)
        return self.norm(hidden)
class DO_DeLayer(nn.Module):
    """
    Decoder-only layer: self-attention under `tgt_mask` (there is no
    cross-attention), then feed-forward; each sub-layer is wrapped in AddNorm.

    Args:
        d_embed: model width (also read by DO_Decoder to size its final norm).
        attn: self-attention module, called as attn(x, x, x, tgt_mask).
        feed_forward: position-wise feed-forward module.
        dropout: dropout probability for each AddNorm.
    """

    def __init__(self, d_embed, attn, feed_forward, dropout):
        super().__init__()
        self.d_embed = d_embed
        self.attn = attn  # self-attention only
        self.feed_forward = feed_forward
        self.attn_norm = AddNorm(d_embed, dropout)
        self.ff_norm = AddNorm(d_embed, dropout)

    def forward(self, x, tgt_mask):
        attended = self.attn_norm(x, lambda h: self.attn(h, h, h, tgt_mask))
        return self.ff_norm(attended, self.feed_forward)