-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathmodel.py
More file actions
647 lines (531 loc) · 25.6 KB
/
model.py
File metadata and controls
647 lines (531 loc) · 25.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
from __future__ import absolute_import, division, print_function, unicode_literals
import collections
import json
import logging
import math
import os
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import copy
from transformers.modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
from transformers.modeling_gpt2 import *
from transformers.modeling_bert import gelu
from transformers.configuration_gpt2 import GPT2Config
from transformers.file_utils import add_start_docstrings
####################### auxiliary attention blocks #######################
class Unmasked_Attention(Attention):
    """Attention layer whose ``_attn`` applies no lower-triangular (causal) mask.

    Only ``_attn`` is overridden here; every other method is inherited from
    ``Attention``.
    """

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        """Attend from every query position to every key position.

        Args:
            q: query tensor.
            k: key tensor; it is matrix-multiplied with ``q`` as-is, so it must
                already be laid out for ``q @ k``.
            v: value tensor; its last dimension is used for score scaling.
            attention_mask: optional additive mask, added to the raw scores
                before the softmax.
            head_mask: optional multiplicative mask applied to the attention
                probabilities after dropout.

        Returns:
            list: ``[context]``, or ``[context, probabilities]`` when
            ``self.output_attentions`` is truthy.
        """
        scores = torch.matmul(q, k)
        if self.scale:
            scores = scores / math.sqrt(v.size(-1))
        if attention_mask is not None:
            # Additive mask: masked positions carry a large negative value.
            scores = scores + attention_mask

        probs = self.attn_dropout(nn.functional.softmax(scores, dim=-1))
        if head_mask is not None:
            # Zero out whole heads when requested.
            probs = probs * head_mask

        context = torch.matmul(probs, v)
        return [context, probs] if self.output_attentions else [context]
class Unmasked_Block(Block):
    """Transformer block that uses ``Unmasked_Attention`` as its attention layer.

    The constructor builds ``ln_1``, ``attn``, ``ln_2`` and ``mlp``; ``forward``
    is inherited from ``Block``.
    """

    def __init__(self, n_ctx, config, scale=False):
        """Build the sub-modules.

        Args:
            n_ctx: context size forwarded to the attention layer.
            config: model configuration (reads ``n_embd`` and
                ``layer_norm_epsilon``; also forwarded to the sub-modules).
            scale: forwarded to the attention layer.
        """
        # Start the MRO lookup *after* Block so Block.__init__ is skipped and
        # the sub-modules below are the only ones created.
        super(Block, self).__init__()
        hidden_size = config.n_embd
        eps = config.layer_norm_epsilon
        self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
        self.attn = Unmasked_Attention(hidden_size, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = MLP(4 * hidden_size, config)
class AverageSelfAttention(nn.Module):
    """Pool a sequence of hidden states into one vector with a learned attention vector.

    A single learnable vector is dotted with every hidden state, the scores
    are passed through a non-linearity and a softmax over the sequence axis,
    and the hidden states are summed with those weights.
    """

    def __init__(self, attention_size):
        """Create the attention vector.

        Args:
            attention_size: length of the learnable attention vector; must
                match the last dimension of the inputs given to ``forward``.
        """
        super(AverageSelfAttention, self).__init__()
        w = torch.empty(attention_size)
        nn.init.normal_(w, std=0.02)
        self.attention_weights = nn.Parameter(w)
        self.softmax = nn.Softmax(dim=-1)
        self.non_linearity = gelu

    def forward(self, inputs, attention_mask=None):
        """Compute the attention-weighted sum of ``inputs`` along dimension 1.

        Args:
            inputs: 3D tensor ``(batch, len, hidden_size)``.
            attention_mask: optional additive mask broadcastable to
                ``(batch, len)``; it is added to the scores before the softmax.

        Returns:
            tuple: ``(representations, scores)`` where ``representations`` is
            ``(batch, hidden_size)`` and ``scores`` is the ``(batch, len)``
            softmax-normalised attention weights.
        """
        # Step 1 - dot product of the attention vector with each hidden state.
        # scores is a 2D tensor: batch, len
        scores = self.non_linearity(inputs.matmul(self.attention_weights))

        # Step 2 - masking.
        if attention_mask is not None:
            scores = scores + attention_mask

        # Step 3 - weighted sum of hidden states, by the attention scores.
        scores = self.softmax(scores)
        # Broadcasting over the hidden axis replaces the explicit expand_as.
        weighted = inputs * scores.unsqueeze(-1)
        # Summing over the sequence axis already yields (batch, hidden_size).
        # No squeeze here: squeezing dim 1 would drop the hidden axis whenever
        # hidden_size == 1.
        representations = weighted.sum(1)
        return representations, scores
# Pseudo self-attention
class Cond_Attention(Attention):
    """Pseudo self-attention: causal attention with one extra conditioning slot.

    A conditioning input ``z`` is projected by ``c_z`` into an extra key and
    value that are prepended to the keys/values computed from ``x``.
    """

    def __init__(self, nx, n_ctx, config, scale=False):
        """Build the projections and the causal-mask buffer.

        Args:
            nx: input feature size; also used as the attention state size.
            n_ctx: side length of the lower-triangular ``bias`` buffer.
            config: model configuration (reads ``output_attentions``,
                ``n_head``, ``attn_pdrop`` and ``resid_pdrop``).
            scale: when truthy, scores are divided by sqrt of the value size.
        """
        # Skip Attention.__init__ on purpose; everything is rebuilt below.
        super(Attention, self).__init__()
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        lower_triangle = torch.tril(torch.ones(n_ctx, n_ctx))
        self.register_buffer("bias", lower_triangle.view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

        # Extra projection producing the key and value for the conditioning input.
        self.c_z = Conv1D(n_state * 2, nx)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        """Causally-masked attention over keys that include the leading ``z`` slot.

        Args:
            q, k, v: query / key / value tensors; ``k`` is multiplied with
                ``q`` as-is.
            attention_mask: optional additive mask whose last dimension is one
                shorter than the key axis (it does not cover the ``z`` slot).
            head_mask: optional multiplicative mask on the probabilities.

        Returns:
            list: ``[context]``, or ``[context, probabilities]`` when
            ``self.output_attentions`` is truthy.
        """
        scores = torch.matmul(q, k)
        if self.scale:
            scores = scores / math.sqrt(v.size(-1))

        n_query, n_key = scores.size(-2), scores.size(-1)
        causal = self.bias[:, :, n_key - n_query : n_key, :n_key]
        scores = scores * causal - 1e4 * (1 - causal)

        if attention_mask is not None:
            # scores are bsz * n_heads * L * (L+1), the mask is bsz * 1 * 1 * L
            assert attention_mask.size()[-1] == scores.size()[-1] - 1
            # Prepend a zero column so the z slot is never masked out.
            z_column = attention_mask.new_zeros(tuple(attention_mask.size()[:-1]) + (1,))
            attention_mask = torch.cat((z_column, attention_mask), dim=-1)
            scores = scores + attention_mask

        probs = self.attn_dropout(nn.functional.softmax(scores, dim=-1))
        if head_mask is not None:
            probs = probs * head_mask

        context = torch.matmul(probs, v)
        return [context, probs] if self.output_attentions else [context]

    def forward(self, x, z, layer_past=None, attention_mask=None, head_mask=None):
        """Run attention over ``x`` conditioned on ``z``.

        Args:
            x: input hidden states.
            z: conditioning input, projected by ``c_z`` into one key and value.
            layer_past: optional cached ``(key, value)`` pair from earlier steps.
            attention_mask: optional additive mask forwarded to ``_attn``.
            head_mask: optional head mask forwarded to ``_attn``.

        Returns:
            list: ``[a, present]`` plus the attention probabilities when they
            are returned by ``_attn``.
        """
        query, key, value = self.c_attn(x).split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

        if layer_past is not None:
            past_key = layer_past[0].transpose(-2, -1)  # transpose back cf below
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        # The cache is captured before z is prepended, so it never contains the
        # z slot. Keys are transposed to have same shapes for stacking.
        present = torch.stack((key.transpose(-2, -1), value))

        key_z, value_z = self.c_z(z).split(self.split_size, dim=2)
        key = torch.cat((self.split_heads(key_z, k=True), key), dim=-1)
        value = torch.cat((self.split_heads(value_z), value), dim=-2)

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
        merged = self.merge_heads(attn_outputs[0])
        a = self.resid_dropout(self.c_proj(merged))
        return [a, present] + attn_outputs[1:]  # a, present, (attentions)
class Cond_Block(Block):
    """Transformer block whose attention layer is the conditioned ``Cond_Attention``."""

    def __init__(self, n_ctx, config, scale=False):
        """Build layer norms, the conditioned attention and the MLP.

        Args:
            n_ctx: context size forwarded to the attention layer.
            config: model configuration (reads ``n_embd`` and
                ``layer_norm_epsilon``; also forwarded to the sub-modules).
            scale: forwarded to the attention layer.
        """
        # Bypass Block.__init__ so only the sub-modules below are created.
        super(Block, self).__init__()
        hidden_size = config.n_embd
        eps = config.layer_norm_epsilon
        self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
        self.attn = Cond_Attention(hidden_size, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = MLP(4 * hidden_size, config)

    def forward(self, x, z, layer_past=None, attention_mask=None, head_mask=None):
        """Apply attention and MLP sub-layers, each with a residual connection.

        Args:
            x: input hidden states.
            z: conditioning input handed to the attention layer.
            layer_past: optional cached key/value pair for this layer.
            attention_mask: optional additive attention mask.
            head_mask: optional head mask.

        Returns:
            list: ``[x, present]`` plus attention probabilities when the
            attention layer returns them.
        """
        attn_result = self.attn(
            self.ln_1(x),
            z,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        # attn_result: a, present, (attentions)
        x = x + attn_result[0]
        x = x + self.mlp(self.ln_2(x))
        return [x] + attn_result[1:]  # x, present, (attentions)
####################### transformer-based vae #######################
class Encoder(GPT2Model):
    """Unmasked transformer encoder that outputs the mean and log-variance of a latent.

    The token sequence is run through a stack of ``Unmasked_Block`` layers,
    pooled into one vector by ``AverageSelfAttention`` and projected to
    ``mean`` and ``logvar``.
    """

    def __init__(self, config):
        """Build embeddings, the block stack and the latent projections.

        Args:
            config: model configuration (reads ``output_hidden_states``,
                ``output_attentions``, ``output_past``, ``vocab_size``,
                ``n_positions``, ``n_embd``, ``n_ctx``, ``embd_pdrop`` and
                ``layer_norm_epsilon``).
        """
        # Skip GPT2Model.__init__; the sub-modules are rebuilt below.
        super(GPT2Model, self).__init__(config)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.output_past = config.output_past

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

        # manually modify number of layers in encoder to accommodate GPU memory
        n = 6  # config.n_layer
        self.h = nn.ModuleList([Unmasked_Block(config.n_ctx, config, scale=True) for _ in range(n)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.init_weights()

        # added code here
        # These modules are created after init_weights() has run, so that call
        # does not touch them.
        self.averageSelfAttention = AverageSelfAttention(config.n_embd)
        nx = config.n_embd
        nz = config.n_embd
        self.mean = Conv1D(nz, nx)
        self.logvar = Conv1D(nz, nx)

    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
    ):
        """Encode a sequence into latent statistics.

        Args:
            input_ids: token ids; mutually exclusive with ``inputs_embeds``.
            past: optional per-layer cached states; defaults to no cache.
            attention_mask: optional mask, 1.0 for positions to attend and 0.0
                for masked positions. May be omitted, in which case neither
                the blocks nor the pooling layer are masked.
            token_type_ids: optional ids embedded with ``wte`` and added in.
            position_ids: optional position ids; generated when omitted.
            head_mask: optional 1D or 2D head mask (1.0 keeps the head).
            inputs_embeds: pre-computed input embeddings.

        Returns:
            tuple: ``(mean, logvar, last_hidden_state)`` followed, when the
            corresponding config flags are set, by ``presents``,
            ``all_hidden_states`` and ``all_attentions``.

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Attention mask.
        if attention_mask is not None:
            attention_mask = attention_mask.view(-1, input_shape[-1])
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = (
                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
                )  # We can specify head_mask for each layer
            head_mask = head_mask.to(
                dtype=next(self.parameters()).dtype
            )  # switch to float if need + fp16 compatibility
        else:
            # One entry per encoder block. The encoder depth is fixed in
            # __init__ and may differ from config.n_layer.
            head_mask = [None] * len(self.h)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(
                hidden_states, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
            )

            hidden_states, present = outputs[:2]
            if self.output_past:
                presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # added code here
        # Drop the two broadcast axes so the additive mask is (batch, len) for
        # pooling; with no mask given, pool over every position unmasked.
        if attention_mask is not None:
            pooling_mask = attention_mask.squeeze(1).squeeze(1)
        else:
            pooling_mask = None
        representations, _ = self.averageSelfAttention(hidden_states, pooling_mask)
        mean = self.mean(representations)
        logvar = self.logvar(representations)

        outputs = (mean, logvar, hidden_states,)
        if self.output_past:
            outputs = outputs + (presents,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)

        return outputs  # mean, logvar, last hidden state, (presents), (all hidden_states), (attentions)
class Decoder(GPT2Model):
    """Transformer decoder that can be conditioned on a latent vector.

    Two independent injection routes are supported:

    * ``add_input``: a linear projection of the latent is added to every
      position's embedding;
    * ``add_attn``: a linear projection of the latent is handed to every
      ``Cond_Block`` as its ``z`` argument (one shared projection, or one
      slice per layer when ``attn_proj_vary`` is set).
    """

    def __init__(self, config, add_input=False, add_attn=False, attn_proj_vary=False):
        """Build embeddings, optional latent projections and the block stack.

        Args:
            config: model configuration.
            add_input: enable the embedding-level injection.
            add_attn: enable the attention-level injection (uses ``Cond_Block``).
            attn_proj_vary: with ``add_attn``, project the latent to a separate
                slice for each layer instead of one shared vector.
        """
        # Skip GPT2Model.__init__; the sub-modules are rebuilt below.
        super(GPT2Model, self).__init__(config)

        # added code here
        self.add_input = add_input
        self.add_attn = add_attn
        self.attn_proj_vary = attn_proj_vary

        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.output_past = config.output_past

        embed_dim = config.n_embd
        self.wte = nn.Embedding(config.vocab_size, embed_dim)
        self.wpe = nn.Embedding(config.n_positions, embed_dim)
        self.drop = nn.Dropout(config.embd_pdrop)

        if self.add_input:
            self.input_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        if self.add_attn:
            # Either one shared projection or n_layer projections packed side by side.
            proj_width = embed_dim * config.n_layer if self.attn_proj_vary else embed_dim
            self.attn_proj = nn.Linear(embed_dim, proj_width, bias=False)
            block_cls = Cond_Block
        else:
            block_cls = Block
        self.h = nn.ModuleList([block_cls(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

        self.ln_f = nn.LayerNorm(embed_dim, eps=config.layer_norm_epsilon)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        representations=None
    ):
        """Decode, optionally conditioned on ``representations``.

        Args:
            input_ids: token ids; mutually exclusive with ``inputs_embeds``.
            past: optional per-layer cached states; defaults to no cache.
            attention_mask: optional mask, 1.0 for positions to attend and 0.0
                for masked positions.
            token_type_ids: optional ids embedded with ``wte`` and added in.
            position_ids: optional position ids; generated when omitted.
            head_mask: optional 1D or 2D head mask (1.0 keeps the head).
            inputs_embeds: pre-computed input embeddings.
            representations: latent vector; required when ``add_input`` or
                ``add_attn`` is enabled.

        Returns:
            tuple: ``(last_hidden_state,)`` followed, when the corresponding
            config flags are set, by ``presents``, ``all_hidden_states`` and
            ``all_attentions``.

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        else:
            input_shape = inputs_embeds.size()[:-1]
        seq_len = input_shape[-1]

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, seq_len)
        if position_ids is not None:
            position_ids = position_ids.view(-1, seq_len)

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)

        if position_ids is None:
            device = (inputs_embeds if input_ids is None else input_ids).device
            position_ids = torch.arange(past_length, seq_len + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_len)

        # Attention mask.
        if attention_mask is not None:
            # Reshape the 2D mask to [batch_size, 1, 1, to_seq_length] so it
            # broadcasts to [batch_size, num_heads, from_seq_length, to_seq_length].
            attention_mask = attention_mask.view(-1, seq_len).unsqueeze(1).unsqueeze(2)
            # Turn the 1.0 / 0.0 keep-mask into an additive one: 0.0 where we
            # attend, -10000.0 where masked, which after the softmax is
            # effectively the same as removing those positions.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is None:
            head_mask = [None] * self.config.n_layer
        else:
            mask_rank = head_mask.dim()
            if mask_rank == 1:
                head_mask = head_mask[None, None, :, None, None]
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif mask_rank == 2:
                # We can specify head_mask for each layer
                head_mask = head_mask[:, None, :, None, None]
            # switch to float if need + fp16 compatibility
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        token_type_embeds = 0 if token_type_ids is None else self.wte(token_type_ids)
        hidden_states = inputs_embeds + position_embeds + token_type_embeds

        # add code here
        if self.add_input:
            assert (representations is not None)
            # The unsqueezed length-1 axis broadcasts the projected latent over
            # every position.
            hidden_states = hidden_states + self.input_proj(representations).unsqueeze(1)

        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        # add code here
        if self.add_attn:
            assert (representations is not None)
            attn_proj = self.attn_proj(representations).unsqueeze(1)
            if self.attn_proj_vary:
                # Cut the wide projection into one hidden-sized slice per layer.
                attn_proj = attn_proj.split(hidden_states.size(-1), dim=-1)
                assert len(attn_proj) == len(self.h)

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for layer_idx, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states += (hidden_states.view(*output_shape),)

            block_kwargs = dict(
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[layer_idx],
            )
            if self.add_attn:
                z = attn_proj[layer_idx] if self.attn_proj_vary else attn_proj
                outputs = block(hidden_states, z, **block_kwargs)
            else:
                outputs = block(hidden_states, **block_kwargs)

            hidden_states, present = outputs[:2]
            if self.output_past:
                presents += (present,)
            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states).view(*output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states,)
        if self.output_past:
            outputs += (presents,)
        if self.output_hidden_states:
            outputs += (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            outputs += (tuple(t.view(*attention_output_shape) for t in all_attentions),)

        return outputs  # last hidden state, (presents), (all hidden_states), (attentions)
class LM_head_rep(nn.Module):
    """Two-layer head: ``Linear(in_dim, 1024)`` -> leaky ReLU -> ``Linear(1024, out_dim)``."""

    def __init__(self, in_dim=768, out_dim=50257):
        """Create the two linear layers.

        Args:
            in_dim: size of the input feature vector.
            out_dim: size of the output vector.
        """
        super().__init__()
        hidden_dim = 1024
        self.Nu_fc1 = nn.Linear(in_dim, hidden_dim)
        self.Nu_fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, z):
        """Map ``z`` (last dimension ``in_dim``) to a tensor with last dimension ``out_dim``."""
        hidden = F.leaky_relu(self.Nu_fc1(z))
        return self.Nu_fc2(hidden)
class VAEModel(GPT2LMHeadModel):
    """Conditional VAE language model: posterior (and optional prior) encoder plus latent-conditioned decoder.

    The posterior encoder reads ``y_tokens``; the prior is either a second
    encoder over ``x_tokens`` (``learn_prior``) or a fixed zero-mean,
    zero-log-variance Gaussian. A latent ``z`` is fed to the ``Decoder`` and,
    with ``add_softmax``, also added to the output logits.
    """

    def __init__(self, config, add_input=False, add_attn=False, add_softmax=False, attn_proj_vary=False, learn_prior=False):
        """Build the decoder, LM head and encoders.

        Args:
            config: model configuration.
            add_input: forwarded to ``Decoder`` (embedding-level injection).
            add_attn: forwarded to ``Decoder`` (attention-level injection).
            add_softmax: add a projection of the latent to the LM logits.
            attn_proj_vary: forwarded to ``Decoder``.
            learn_prior: build a second encoder that produces the prior.
        """
        # Skip GPT2LMHeadModel.__init__; the sub-modules are rebuilt below.
        super(GPT2LMHeadModel, self).__init__(config)

        # add code here
        self.add_input = add_input
        self.add_attn = add_attn
        self.add_softmax = add_softmax
        self.attn_proj_vary = attn_proj_vary
        self.learn_prior = learn_prior

        self.transformer = Decoder(config, add_input, add_attn, attn_proj_vary)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.encoder = Encoder(config)
        if self.learn_prior:
            self.encoder_prior = Encoder(config)

        if self.add_softmax:
            nz = config.n_embd
            self.lm_head_rep = Conv1D(config.vocab_size, nz)
            # Alternative head kept for reference:
            # self.lm_head_rep = LM_head_rep(nz, config.vocab_size)

    def reparameterize(self, mean, logvar, z=None):
        """Return ``mean + z * std`` with ``std = exp(0.5 * logvar)``.

        Args:
            mean: mean tensor.
            logvar: log-variance tensor, same shape as ``mean``.
            z: optional noise tensor; drawn from a standard normal when omitted.
        """
        std = logvar.mul(0.5).exp()
        if z is None:
            z = torch.randn(std.size(), device=mean.device, dtype=mean.dtype)
        return z.mul(std) + mean

    def kl_loss(self, mean1, logvar1, mean2, logvar2):
        """KL divergence between two diagonal Gaussians, KL(N1 || N2).

        The divergence is summed over every non-batch dimension and averaged
        over the batch.
        """
        exponential = logvar1 - logvar2 - torch.pow(mean1 - mean2, 2) / logvar2.exp() - torch.exp(logvar1 - logvar2) + 1
        result = -0.5 * torch.sum(exponential, tuple(range(1, len(exponential.shape))))
        return result.mean()

    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        x_mask=None,
        x_tokens=None,
        y_mask=None,
        y_tokens=None,
        from_prior=False,
        from_mean=False
    ):
        """Run the VAE and return logits together with the KL term.

        Args:
            input_ids, past, attention_mask, token_type_ids, position_ids,
            head_mask, inputs_embeds: forwarded to the decoder.
            labels: accepted for interface compatibility; not used here.
            x_mask, x_tokens: inputs of the prior encoder (only read when
                ``learn_prior`` is set).
            y_mask, y_tokens: inputs of the posterior encoder.
            from_prior: take the latent from the prior instead of the posterior.
            from_mean: use the distribution mean instead of sampling.

        Returns:
            tuple: ``(lm_logits, *decoder_extras, kl_loss)`` where
            ``decoder_extras`` are the decoder's optional outputs and
            ``kl_loss`` has shape ``(1,)``.
        """
        # latent representation
        posterior_mean, posterior_logvar = self.encoder(input_ids=y_tokens, attention_mask=y_mask)[:2]

        if self.learn_prior:
            prior_mean, prior_logvar = self.encoder_prior(input_ids=x_tokens, attention_mask=x_mask)[:2]
        else:
            # Fixed standard-normal prior (zero mean, zero log-variance). Size
            # it from input_ids when given; otherwise (decoder driven by
            # inputs_embeds) fall back to the posterior's batch size and device
            # instead of dereferencing None.
            if input_ids is not None:
                batch_size, device = input_ids.size(0), input_ids.device
            else:
                batch_size, device = posterior_mean.size(0), posterior_mean.device
            prior_mean = prior_logvar = torch.zeros([batch_size, self.config.n_embd], device=device)
            prior_mean, prior_logvar = prior_mean.to(posterior_mean.dtype), prior_logvar.to(posterior_logvar.dtype)

        if from_prior:
            latent_mean, latent_logvar = prior_mean, prior_logvar
        else:
            latent_mean, latent_logvar = posterior_mean, posterior_logvar

        if from_mean:
            z = latent_mean
        else:
            z = self.reparameterize(latent_mean, latent_logvar)
        assert not torch.isnan(z).any(), 'training get nan z'

        transformer_outputs = self.transformer(input_ids,
                                               past=past,
                                               attention_mask=attention_mask,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               head_mask=head_mask,
                                               inputs_embeds=inputs_embeds,
                                               representations=z)
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)
        if self.add_softmax:
            # The same latent-derived logit offset is added at every position.
            lm_logits_rep = self.lm_head_rep(z)
            lm_logits = lm_logits + lm_logits_rep.unsqueeze(dim=1)
        outputs = (lm_logits,) + transformer_outputs[1:]

        # kl_loss is always computed posterior-vs-prior, regardless of from_prior.
        kl_loss = self.kl_loss(posterior_mean, posterior_logvar, prior_mean, prior_logvar).unsqueeze(0)
        outputs = outputs + (kl_loss,)

        return outputs  # lm_logits, presents, (all hidden_states), (attentions), (kl_loss)