diffusion.py
import itertools
from dataclasses import dataclass
import hydra.utils
import lightning as L
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from tqdm import tqdm
from collections import OrderedDict
import dataloader
import metrics
import models
import noise_schedule
import utils
from block_utils import BlockPlan, token2block_from_plan, make_plan_from_lens
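
# Draws one sample per position from `categorical_probs` via the Gumbel-max
# trick: dividing the probabilities by Exp(1) noise and taking the argmax is
# equivalent to sampling from the (normalized) categorical distribution.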
def _sample_categorical(categorical_probs):
gumbel_norm = (1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log())
samples = (categorical_probs / gumbel_norm).argmax(dim=-1)
return samples
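
# Right-pads `x` with singleton dimensions so it broadcasts against
# `reference` (e.g. a per-batch sigma broadcast over sequence/vocab axes).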
def _unsqueeze(x, reference):
return x.view(
* x.shape,
* ((1,) * (len(reference.shape) - len(x.shape))))
@dataclass
class Loss:
loss: torch.FloatTensor
nlls: torch.FloatTensor
token_mask: torch.FloatTensor
class Diffusion(L.LightningModule):
def __init__(
self,
config,
tokenizer: transformers.PreTrainedTokenizer):
super().__init__()
self.save_hyperparameters()
self.config = config
self.tokenizer = tokenizer
self.vocab_size = self.tokenizer.vocab_size
self.sampler = self.config.algo.sampler
self.antithetic_sampling = self.config.training.antithetic_sampling
self.cross_attn = self.config.algo.cross_attn
self.ignore_bos = self.config.algo.ignore_bos
self.mdlm_loss_scale = self.config.algo.mdlm_loss_scale
if (not hasattr(self.tokenizer, 'mask_token')
or self.tokenizer.mask_token is None):
self.mask_index = self.vocab_size
self.vocab_size += 1
else:
self.mask_index = self.tokenizer.mask_token_id
if hasattr(self.config, 'algo'):
self.parameterization = self.config.algo.parameterization
else:
self.parameterization = self.config.parameterization
if hasattr(self.config, 'block_size'):
self.block_size = self.config.block_size
else:
self.block_size = self.config.model.length
# FIXME: load block_sizes
self.block_sizes = None # [32, 64, 128]
self.block_plan: BlockPlan | None = None
self.block_levels = None
self.token2block = None # [Ltot]
self.Bmax = None # max block length
if hasattr(config, "block_sizes") and config.block_sizes is not None:
bs = [int(x) for x in list(config.block_sizes)]
if len(bs) > 0:
self.block_sizes = bs
self.Bmax = max(bs)
Ltot = int(sum(bs))
if int(self.config.model.length) != Ltot:
raise ValueError(
f"Got model.length={int(self.config.model.length)} but sum(block_sizes)={Ltot}."
)
self.block_plan = make_plan_from_lens(bs, device=torch.device("cpu"))
self.token2block = token2block_from_plan(self.block_plan, L=Ltot) # [Ltot]
self.block_size = self.Bmax
print(f"[multiscale] block_sizes={self.block_sizes}, Ltot={Ltot}, Bmax={self.Bmax}")
if self.parameterization == 'ar':
self.block_size = 1
if self.config.algo.backbone == 'dit':
self.backbone = models.dit.DIT(
self.config, vocab_size=self.vocab_size)
elif self.config.algo.backbone == 'dimamba':
self.backbone = models.dimamba.DiMamba(
self.config,
vocab_size=self.vocab_size,
pad_token_id=self.tokenizer.pad_token_id)
elif self.config.algo.backbone == 'hf_dit':
self.backbone = transformers.AutoModelForMaskedLM.from_pretrained(
config.eval.checkpoint_path, trust_remote_code=True)
# regenerate mask if pretrained model uses flex attention mask
# and current model uses sdpa mask
if getattr(self.backbone.config, 'attn_backend', None) == 'flex' and \
self.config.model.attn_backend == 'sdpa':
self.backbone.config.attn_backend = 'sdpa'
for i in self.backbone.backbone.blocks:
i.attn_backend = 'sdpa'
self.backbone.backbone.gen_mask(self.config.model.length, self.block_size, attn_backend='sdpa')
else:
raise ValueError(f'Unknown backbone: {self.config.algo.backbone}')
# FIXME: regenerate attention mask
if self.config.algo.name == 'bd3lm' and self.parameterization != 'ar' and self.config.algo.backbone == 'dit':
bb = getattr(self.backbone, "backbone", None)
target = bb if bb is not None else self.backbone
if (self.block_plan is not None) and hasattr(target, "gen_mask_from_plan"):
target.gen_mask_from_plan(
seqlen=self.config.model.length,
plan=self.block_plan,
attn_backend=self.config.model.attn_backend,
)
elif hasattr(target, "gen_mask"):
target.gen_mask(
self.config.model.length,
self.block_size,
attn_backend=self.config.model.attn_backend,
)
else:
raise AttributeError("Backbone has neither gen_mask_from_plan nor gen_mask")
self.T = self.config.algo.T
self.num_tokens = self.config.model.length
self.noise = noise_schedule.get_noise(self.config)
self.metrics = metrics.Metrics(config)
if self.config.training.ema > 0:
self.ema = models.ema.ExponentialMovingAverage(
self._get_parameters(),
decay=self.config.training.ema)
else:
self.ema = None
self.var_min = self.config.algo.var_min
if self.var_min:
self.register_buffer('sampling_eps_min', torch.tensor(
self.config.training.sampling_eps_min))
self.register_buffer('sampling_eps_max', torch.tensor(
self.config.training.sampling_eps_max))
self.time_conditioning = self.config.algo.time_conditioning
self.neg_infinity = -1000000.0
self.fast_forward_epochs = None
self.fast_forward_batches = None
self._validate_configuration()
def _get_parameters(self):
parameters = [self.backbone.parameters(),
self.noise.parameters()]
return itertools.chain(* parameters)
def on_validation_model_zero_grad(self) -> None:
'''
Small hack to avoid first validation on resume.
This will NOT work if the gradient accumulation step should be performed at this point.
'''
super().on_validation_model_zero_grad()
if self.trainer.ckpt_path is not None and getattr(self, '_restarting_skip_val_flag', True):
self.trainer.sanity_checking = True
self._restarting_skip_val_flag = False
def _validate_configuration(self):
if self.config.mode == 'sample_eval' and \
self.config.sampling.first_hitting:
assert self.config.loader.eval_batch_size == 1
assert self.config.algo.backbone in {
'dit', 'ar', 'hf_dit'}
if self.config.algo.parameterization == 'ar':
assert not self.config.algo.time_conditioning
if self.config.sampling.kv_cache:
assert self.config.algo.name in {'ar', 'bd3lm'}
if self.parameterization in {'sedd'}:
assert self.time_conditioning
if self.config.mode == 'sample_eval':
assert self.config.model.attn_backend != 'flex', 'FlexAttention mask not supported at inference.'
if self.config.model.attn_backend == 'flex':
assert self.config.algo.name == 'bd3lm', 'Custom FlexAttention mask only supported for BD3LM.'
def to(self, *args, **kwargs):
self = super().to(*args, **kwargs)
self.metrics.to(*args, **kwargs)
if hasattr(self.backbone, "block_diff_mask") and self.config.model.attn_backend == 'sdpa':
self.backbone.block_diff_mask = self.backbone.block_diff_mask.to(*args, **kwargs)
elif hasattr(self.backbone, "block_diff_mask") and self.config.model.attn_backend == 'flex':
self.backbone.block_diff_mask = self.backbone.block_diff_mask.to(self.device)
if hasattr(self, 'sampling_eps_min') and torch.is_tensor(self.sampling_eps_min):
self.sampling_eps_min = self.sampling_eps_min.to(*args, **kwargs)
self.sampling_eps_max = self.sampling_eps_max.to(*args, **kwargs)
# FIXME: also move block plan
if self.block_plan is not None:
dev = self.device
self.block_plan = BlockPlan(
lens=self.block_plan.lens.to(dev),
starts=self.block_plan.starts.to(dev),
)
if self.token2block is not None:
self.token2block = self.token2block.to(dev)
return self
def _replace_ckpt_keys(self, checkpoint):
state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()
for k,v in state_dict.items():
new_state_dict[k.replace('_orig_mod.', '')] = v
checkpoint['state_dict'] = new_state_dict
return checkpoint
def on_load_checkpoint(self, checkpoint):
print('Loading checkpoint at', checkpoint['global_step'])
self._restarting_skip_val_flag = True
# for models compiled with `torch.compile`
if '_orig_mod.' in list(checkpoint['state_dict'].keys())[0]:
checkpoint = self._replace_ckpt_keys(checkpoint)
if self.ema:
self.ema.load_state_dict(checkpoint['ema'])
if 'sampling_eps_min' in checkpoint.keys():
self.sampling_eps_min = checkpoint['sampling_eps_min']
self.sampling_eps_max = checkpoint['sampling_eps_max']
# Copied from:
# https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
self.fast_forward_epochs = checkpoint['loops'][
'fit_loop']['epoch_progress']['current']['completed']
self.fast_forward_batches = checkpoint['loops'][
'fit_loop']['epoch_loop.batch_progress'][
'current']['completed']
def on_save_checkpoint(self, checkpoint):
if self.ema:
checkpoint['ema'] = self.ema.state_dict()
if hasattr(self, 'sampling_eps_min'):
checkpoint['sampling_eps_min'] = self.sampling_eps_min
checkpoint['sampling_eps_max'] = self.sampling_eps_max
# Copied from:
# https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
# ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
# behind, so we're using the optimizer's progress.
checkpoint['loops']['fit_loop'][
'epoch_loop.batch_progress']['total'][
'completed'] = checkpoint['loops']['fit_loop'][
'epoch_loop.automatic_optimization.optim_progress'][
'optimizer']['step']['total'][
'completed'] * self.trainer.accumulate_grad_batches
checkpoint['loops']['fit_loop'][
'epoch_loop.batch_progress']['current'][
'completed'] = checkpoint['loops']['fit_loop'][
'epoch_loop.automatic_optimization.optim_progress'][
'optimizer']['step']['current'][
'completed'] * self.trainer.accumulate_grad_batches
# _batches_that_stepped tracks the number of global steps, not the number
# of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here.
checkpoint['loops']['fit_loop'][
'epoch_loop.state_dict'][
'_batches_that_stepped'] = checkpoint['loops']['fit_loop'][
'epoch_loop.automatic_optimization.optim_progress'][
'optimizer']['step']['total']['completed']
if 'sampler' not in checkpoint.keys():
checkpoint['sampler'] = {}
if hasattr(self.trainer.train_dataloader.sampler,
'state_dict'):
sampler_state_dict = self.trainer.\
train_dataloader.sampler.state_dict()
checkpoint['sampler'][
'random_state'] = sampler_state_dict.get(
'random_state', None)
else:
checkpoint['sampler']['random_state'] = None
def on_train_start(self):
if self.ema:
self.ema.move_shadow_params_to_device(self.device)
# Adapted from:
# https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
distributed = (
self.trainer._accelerator_connector.use_distributed_sampler
and self.trainer._accelerator_connector.is_distributed)
if distributed:
sampler_cls = dataloader.FaultTolerantDistributedSampler
else:
sampler_cls = dataloader.RandomFaultTolerantSampler
updated_dls = []
for dl in self.trainer.fit_loop._combined_loader.flattened:
if hasattr(dl.sampler, 'shuffle'):
dl_sampler = sampler_cls(
dl.dataset, shuffle=dl.sampler.shuffle)
else:
dl_sampler = sampler_cls(dl.dataset)
if (distributed
and self.fast_forward_epochs is not None
and self.fast_forward_batches is not None):
dl_sampler.load_state_dict({
'epoch': self.fast_forward_epochs,
'counter': (self.fast_forward_batches
* self.config.loader.batch_size)})
updated_dls.append(
torch.utils.data.DataLoader(
dl.dataset,
batch_size=self.config.loader.batch_size,
num_workers=self.config.loader.num_workers,
pin_memory=self.config.loader.pin_memory,
sampler=dl_sampler,
shuffle=False,
persistent_workers=True,
collate_fn=getattr(dl, "collate_fn", None),))
self.trainer.fit_loop._combined_loader.flattened = updated_dls
def optimizer_step(self, *args, **kwargs):
super().optimizer_step(*args, **kwargs)
if self.ema:
self.ema.update(self._get_parameters())
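
# SUBS parameterization: force zero probability on the mask token and carry
# over unmasked tokens unchanged (their distribution becomes a point mass).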
def _subs_parameterization(self, logits, xt):
# log prob at the mask index = - infinity
logits[:, :, self.mask_index] += self.neg_infinity
# Normalize the logits such that x.exp() is
# a probability distribution over vocab_size.
logits = logits - torch.logsumexp(logits, dim=-1,
keepdim=True)
# Apply updates directly in the logits matrix.
# For the logits of the unmasked tokens, set all values
# to -infinity except for the indices corresponding to
# the unmasked tokens.
unmasked_indices = (xt != self.mask_index)
logits[unmasked_indices] = self.neg_infinity
logits[unmasked_indices, xt[unmasked_indices]] = 0
return logits
def _sedd_parameterization(self, logits, xt, sigma):
esigm1_log = torch.where(
sigma < 0.5,
torch.expm1(sigma),
sigma.exp() - 1).log().to(logits.dtype)
# logits shape
# (batch_size, diffusion_model_input_length, vocab_size)
logits = logits - esigm1_log[:, None, None] - np.log(
logits.shape[-1] - 1)
# The below scatter operation sets the log score
# for the input word to 0.
logits = torch.scatter(logits, -1, xt[..., None],
torch.zeros_like(logits[..., :1]))
return logits
def _process_sigma(self, sigma):
# cause of overfitting for block size 1?
if self.parameterization == 'ar':
return None
assert sigma.ndim == 2
sigma = sigma.mean(-1).squeeze()
if sigma.ndim == 0:
sigma = sigma.unsqueeze(0)
if not self.time_conditioning:
sigma = torch.zeros_like(sigma)
assert sigma.ndim == 1, sigma.shape
return sigma
def forward(self, x, sigma, sample_mode=False, store_kv=False):
"""Returns log score."""
sigma = self._process_sigma(sigma)
with torch.amp.autocast('cuda', dtype=torch.float32):
if self.config.algo.name == 'bd3lm':
logits = self.backbone(x, sigma,
store_kv=store_kv,
sample_mode=sample_mode)
elif self.config.algo.name == 'ar':
if self.config.algo.backbone == 'hf_dit':
logits = self.backbone(x, None)
else:
logits = self.backbone(x, sigma, sample_mode=sample_mode, store_kv=store_kv)
logits[:, :, self.mask_index] = self.neg_infinity
logits = logits.log_softmax(-1)
else:
logits = self.backbone(x, sigma)
if self.cross_attn:
x = x[:, :self.config.model.length]
if self.parameterization == 'subs':
return self._subs_parameterization(logits=logits,
xt=x)
elif self.parameterization == 'sedd':
return self._sedd_parameterization(logits=logits,
xt=x,
sigma=sigma)
return logits
def on_train_epoch_start(self):
self.backbone.train()
self.noise.train()
self.metrics.reset()
assert self.metrics.train_nlls.nll.mean_value == 0
assert self.metrics.train_nlls.nll.weight == 0
def training_step(self, batch, batch_idx):
del batch_idx
losses = self._loss(batch['input_ids'],
batch['attention_mask'])
self.metrics.train_nlls.update(losses.nlls, losses.token_mask)
self.log(name='trainer/loss',
value=losses.loss.item(),
on_step=True,
on_epoch=False,
sync_dist=True)
return losses.loss
def on_validation_epoch_start(self):
self.metrics.reset()
if self.ema:
self.ema.store(itertools.chain(
self.backbone.parameters(),
self.noise.parameters()))
self.ema.copy_to(itertools.chain(
self.backbone.parameters(),
self.noise.parameters()))
self.eval()
self.backbone.eval()
self.noise.eval()
assert self.metrics.valid_nlls.nll.mean_value == 0
assert self.metrics.valid_nlls.nll.weight == 0
self.sampling_eps = self.config.training.sampling_eps
def on_validation_epoch_end(self):
for k, v in self.metrics.valid_nlls.items():
self.log(name=k, value=v.compute(), on_step=False,
on_epoch=True, sync_dist=True)
if self.ema:
self.ema.restore(self._get_parameters())
if self.var_min and not self.trainer.sanity_checking:
self._clipped_schedule_search()
self.log('sampling_eps_min',
self.sampling_eps_min,
on_epoch=True,
on_step=False,
sync_dist=True)
self.log('sampling_eps_max',
self.sampling_eps_max,
on_epoch=True,
on_step=False,
sync_dist=True)
def _check_val_sampling_intvl(self, sampling_eps_min, sampling_eps_max):
"""Checks if the current sampling interval is valid for reporting likelihood."""
if (sampling_eps_min == 1e-3 \
and sampling_eps_max == 1 \
and not (self.block_size == 1 and self.config.training.eval_nll)):
return True # elbo
elif (self.block_size == 1 and sampling_eps_min >= 1):
return True # nll (block size 1)
return False # not a valid elbo (biased estimate)
def validation_step(self, batch, batch_idx):
if self.var_min:
for noise_clip_start in self.metrics.valid_vars.keys():
sampling_eps_min, sampling_eps_max = noise_clip_start
if self._check_val_sampling_intvl(sampling_eps_min, sampling_eps_max):
# compute and record nelbo
losses_clip = self._loss(batch['input_ids'],
batch['attention_mask'],
sampling_eps_min=sampling_eps_min,
sampling_eps_max=sampling_eps_max)
losses = Loss(
nlls=losses_clip.nlls.clone(),
token_mask=losses_clip.token_mask,
loss=losses_clip.loss.clone())
elif len(self.metrics.valid_vars[noise_clip_start]) < 100:
# elbo from clipped schedule (biased estimate)
losses_clip = self._loss(batch['input_ids'],
batch['attention_mask'],
sampling_eps_min=sampling_eps_min,
sampling_eps_max=sampling_eps_max)
if len(self.metrics.valid_vars[noise_clip_start]) < 100:
# only report variance over 100 batches
nlls = losses_clip.nlls
# FIXME: reshape per variable-size block; needs verification
if self.block_plan is not None:
plan = self.block_plan
lens = plan.lens.to(nlls.device)
starts = plan.starts.to(nlls.device)
B, L = nlls.shape
N = int(lens.numel())
per_block = []
eps = 1e-8
for bi in range(N):
s = int(starts[bi].item())
l = int(lens[bi].item())
e = s + l
blk_nll = nlls[:, s:e] # [B, l] - block token nll
blk_m = losses_clip.token_mask[:, s:e].to(nlls.dtype) # [B, l] - mask
denom = blk_m.sum(dim=-1).clamp_min(eps) # [B]
mean_blk = (blk_nll * blk_m).sum(dim=-1) / denom # [B]
per_block.append(mean_blk)
# [B, Nblocks]
per_block = torch.stack(per_block, dim=1)
self.metrics.valid_vars[noise_clip_start].append(per_block)
else:
self.metrics.valid_vars[noise_clip_start].append(
nlls.reshape(
nlls.shape[0], -1, self.block_size).mean(-1))
elif self.block_size == 1:
# nll
losses = self._loss(batch['input_ids'],
batch['attention_mask'],
sampling_eps_min=1,
sampling_eps_max=1)
else:
# nelbo
losses = self._loss(batch['input_ids'],
batch['attention_mask'],
sampling_eps_min=1e-3,
sampling_eps_max=1)
self.metrics.valid_nlls.update(losses.nlls, losses.token_mask)
return losses.loss
def configure_optimizers(self):
# TODO(yair): Lightning currently giving this warning when using `fp16`:
# "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
# Not clear if this is a problem or not.
# See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
optimizer = torch.optim.AdamW(
self._get_parameters(),
lr=self.config.optim.lr,
betas=(self.config.optim.beta1,
self.config.optim.beta2),
eps=self.config.optim.eps,
weight_decay=self.config.optim.weight_decay)
scheduler = hydra.utils.instantiate(
self.config.lr_scheduler, optimizer=optimizer)
scheduler_dict = {'scheduler': scheduler,
'interval': 'step',
'monitor': 'val/loss',
'name': 'trainer/lr'}
return [optimizer], [scheduler_dict]
def _resample_q_xt(
self, x, xt, move_indices, p, block_size, sampling_eps_min, sampling_eps_max):
"""Resamples x_t if the percentage of masked tokens is outside the bounds
defined by sampling_eps_min and sampling_eps_max."""
perc_masked = (xt == self.mask_index).float().sum(-1) / block_size
while (perc_masked < sampling_eps_min).any() or \
(perc_masked > sampling_eps_max).any():
# if a bound is epsilon, don't resample
if sampling_eps_min == 1e-3 and sampling_eps_max != 1:
regen_idx = (perc_masked > sampling_eps_max)
if regen_idx.max() == 0:
break
elif sampling_eps_min != 1e-3 and sampling_eps_max == 1:
regen_idx = (perc_masked < sampling_eps_min)
if regen_idx.max() == 0:
break
elif sampling_eps_min != 1e-3 and sampling_eps_max != 1:
regen_idx = (perc_masked < sampling_eps_min) | (perc_masked > sampling_eps_max)
regen_idx = regen_idx.repeat_interleave(block_size,dim=-1)
move_indices[regen_idx] = (torch.rand(
* x.shape, device=x.device) < p)[regen_idx]
xt = torch.where(move_indices, self.mask_index, x)
xt = xt.reshape(xt.shape[0], -1, block_size)
perc_masked = (xt == self.mask_index).float().sum(-1) / block_size
return xt
# FIXME: resample var block
def _resample_q_xt_plan(self, x, xt, move_indices, p, sampling_eps_min, sampling_eps_max):
plan = self.block_plan
assert plan is not None
lens = plan.lens.to(x.device)
starts = plan.starts.to(x.device)
B = x.shape[0]
max_iters = 50
for _ in range(max_iters):
any_bad = False
for bi in range(int(lens.numel())):
s = int(starts[bi].item())
l = int(lens[bi].item())
e = s + l
xt_blk = xt[:, s:e]
perc = (xt_blk == self.mask_index).float().sum(-1) / float(l) # [B]
regen = (perc < sampling_eps_min) | (perc > sampling_eps_max)
if regen.any():
any_bad = True
p_blk = p[regen, s].unsqueeze(1) # [B_regen, 1]
new_move = (torch.rand((regen.sum().item(), l), device=x.device) <= p_blk) # [B_regen, l]
move_indices[regen, s:e] = new_move
xt[regen, s:e] = torch.where(new_move, self.mask_index, x[regen, s:e])
if not any_bad:
break
return xt
def q_xt(
self, x, p, block_size=None, sampling_eps_min=None, sampling_eps_max=None):
"""Computes the noisy sample xt.
Args:
x: int torch.Tensor with shape (batch_size,
diffusion_model_input_length), input.
p: float torch.Tensor with shape (batch_size, 1).
block_size: int, block size.
sampling_eps_min: float, minimum percentage of masked tokens.
sampling_eps_max: float, maximum percentage of masked tokens.
"""
if block_size is None:
block_size = self.block_size
move_indices = torch.rand(
* x.shape, device=x.device) <= p
xt = torch.where(move_indices, self.mask_index, x)
if block_size == 1 and sampling_eps_min == 1.0:
return torch.full_like(x, self.mask_index)
# no need to resample for bounds 1e-3, 1
if sampling_eps_min is None: sampling_eps_min = 1e-3
if sampling_eps_max is None: sampling_eps_max = 1.0
if self.config.training.resample and \
not (sampling_eps_min == 1e-3 and sampling_eps_max == 1.0):
if self.block_plan is None:
xt = xt.reshape(xt.shape[0], -1, block_size)
xt = self._resample_q_xt(x, xt, move_indices, p, block_size, sampling_eps_min, sampling_eps_max)
xt = xt.reshape(xt.shape[0], -1)
# FIXME: var block xt
else:
xt = self._resample_q_xt_plan(x, xt, move_indices, p, sampling_eps_min, sampling_eps_max)
return xt
def _sample_prior(self, *batch_dims):
return self.mask_index * torch.ones(
* batch_dims, dtype=torch.int64, device=self.device)
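
# Nucleus (top-p) filtering on the probabilities of the block currently being
# generated: keep the smallest set of tokens whose cumulative probability
# reaches `sampling.nucleus_p` (always at least the top token), zero out the
# rest, and renormalize.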
@torch.no_grad()
def _nucleus_sample(self, p_x0):
p = self.config.sampling.nucleus_p
if p == 1.0:
return p_x0
p_x0_ = p_x0[:, -self.block_size:].clone()
sorted_probs, sorted_indices = p_x0_.sort(dim=-1, descending=True)
cum_probs = sorted_probs.cumsum(dim=-1)
nucleus_mask = cum_probs <= p
nucleus_mask[..., 0] = 1
sorted_probs = sorted_probs * nucleus_mask
p_x0_.scatter_(-1, sorted_indices, sorted_probs * nucleus_mask)
p_x0_ /= p_x0_.sum(-1, keepdim=True)
p_x0[:, -self.block_size:] = p_x0_
return p_x0
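
# Single reverse-diffusion step on the last block. The denoiser output p_x0 is
# cached and reused across steps while the block is unchanged; returning
# (None, x_new) tells the caller to recompute it after an update.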
@torch.no_grad()
def _ddpm_caching_update(self, x, t, dt, p_x0=None):
_, move_chance_t = self.noise(t)
_, move_chance_s = self.noise(t - dt)
sigma_t = self._sigma_from_p(move_chance_t)
move_chance_t = move_chance_t[:, None]
move_chance_s = move_chance_s[:, None]
mask_prob = move_chance_s / move_chance_t
if p_x0 is None:
if self.config.sampling.kv_cache:
p_x0 = self.forward(x[:, -self.block_size:],
sigma_t,
sample_mode=True).to(torch.float64)
else:
p_x0 = self.forward(x,
sigma_t,
sample_mode=True).to(torch.float64)
p_x0 = p_x0[:, -self.block_size:]
p_x0 = p_x0.exp()
p_x0 = self._nucleus_sample(p_x0)
if self.config.sampling.first_hitting:
x_block = _sample_categorical(p_x0)
# randomly and uniformly select an index in the block (among masked tokens)
num_masked = (x[:, -self.block_size:] == self.mask_index).sum(-1)
ind = torch.randint(0, num_masked, (x_block.shape[0],))
ind = (x[:, -self.block_size:] == self.mask_index).nonzero()[ind, 1]
mask = (torch.arange(self.block_size, device=x.device) == ind[:, None]).to(x_block.dtype)
x_block = x_block * mask + x[:, -self.block_size:] * (1 - mask)
else:
q_xs = p_x0 * (1 - mask_prob)
q_xs[:, :, self.mask_index] = mask_prob.squeeze(-1)
x_block = _sample_categorical(q_xs)
copy_flag = (x[:, -self.block_size:] != self.mask_index).to(x.dtype)
x_block = copy_flag * x[:, -self.block_size:] + (1 - copy_flag) * x_block
x_new = torch.cat((x[:, :-self.block_size], x_block), dim=-1)
# compute kv cache if all tokens in a block are sampled
if self.config.sampling.kv_cache and self.mask_index not in x_block:
_ = self.forward(x_block, sigma_t, sample_mode=True, store_kv=True)
if not torch.allclose(x_new, x):
return None, x_new
else:
return p_x0, x_new
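
# Autoregressive sampler: generates one token at a time with nucleus sampling,
# optionally reusing the KV cache, with early stopping for variable-length
# generation.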
@torch.no_grad()
def _ar_sampler(self, bsz, context_len=1024):
# reset kvs
if self.config.sampling.kv_cache:
self.backbone.reset_kv_cache()
with torch.amp.autocast('cuda', dtype=torch.float32):
# precompute token buffer
num_pred_tokens = self.num_tokens - 1
x = torch.zeros(
(bsz, num_pred_tokens + 1),
dtype=torch.long,
device=self.device)
x[:, 0] = self.tokenizer.bos_token_id
stop = False
for i in tqdm(range(num_pred_tokens)):
# need to sample a gumbel for each token
# to save memory in variable-length sampling
noise = (torch.distributions.Gumbel(0, 1)
.sample((bsz, self.vocab_size))
.to(self.device))
next_logits = self.forward(
x[:, :i + 1][:, -context_len:],
None,
store_kv=self.config.sampling.kv_cache)[:, -1:].to(torch.float64)
next_logits = next_logits.exp()
next_logits = self._nucleus_sample(next_logits).log()
y = (next_logits[:, -1] + noise).argmax(-1)
# check if we need to resample (or stop sampling for variable-length sampling)
if (i+1) > 256:
stop, x_out = self._check_stop_conds(x[:, :i+1])
if stop:
x = x_out
if (stop and not self.config.sampling.var_length) \
or (stop and x.shape[-1] == 1):
return None
elif stop:
break
x[:, i + 1] = y
return x
@torch.no_grad()
def _sample(
self, seqlen=None, num_steps=None, eps=1e-5, batch_size_per_gpu=None):
"""Generate samples from the model."""
if seqlen is None:
seqlen = self.config.model.length
if batch_size_per_gpu is None:
batch_size_per_gpu = self.config.loader.eval_batch_size
samples = []
if self.parameterization == 'ar':
for _ in range(self.config.sampling.num_sample_batches):
sample_i, num_tries = None, 0
while sample_i is None:
num_tries += 1
sample_i = self._ar_sampler(batch_size_per_gpu)
if num_tries > 10:
raise ValueError('Sampling failed.')
samples.append(sample_i)
self.metrics.gen_nfes.append(self.config.model.length)
samples = torch.cat(samples, dim=0)
return self.tokenizer.batch_decode(samples)
if self.sampler == 'semi_ar':
for _ in range(self.config.sampling.num_sample_batches):
sample_i, num_tries = None, 0
while sample_i is None:
num_tries += 1
sample_i, nfes = self._semi_ar_sampler(
n_samples=batch_size_per_gpu,
num_strides=(seqlen // self.block_size),
num_steps=num_steps,
seqlen=seqlen)
if num_tries > 10:
raise ValueError('Sampling failed.')
samples.append(sample_i)
self.metrics.nfes.update(nfes)
self.metrics.gen_nfes.append(nfes)
else:
nfes = num_steps
for _ in range(self.config.sampling.num_sample_batches):
sample_i, num_tries = None, 0
while sample_i is None:
sample_i = self._analytic_sampler(
n_samples=batch_size_per_gpu,
num_steps=num_steps,
seqlen=seqlen,
eps=eps)
num_tries += 1
if num_tries > 10 and sample_i is None:
raise ValueError('Sampling failed.')
samples.append(sample_i)
self.metrics.nfes.update(nfes)
self.metrics.gen_nfes.append(nfes)
samples = torch.cat(samples, dim=0)
return self.tokenizer.batch_decode(samples)
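
# Maps a masking probability p to the total noise level sigma = -log(1 - p),
# clipped at the schedule's sigma_max.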
def _sigma_from_p(self, p):
return torch.min(- torch.log(1 - p), self.noise.sigma_max)
def restore_model_and_sample(self, num_steps, eps=1e-5, seqlen=None):
"""Generate samples from the model."""
if self.ema:
self.ema.store(self._get_parameters())
self.ema.copy_to(self._get_parameters())
self.backbone.eval()
self.noise.eval()
samples = self._sample(
seqlen=seqlen,
batch_size_per_gpu=self.config.loader.eval_batch_size,
num_steps=num_steps,
eps=eps)
self.metrics.record_generative_perplexity(
samples,
self.config.model.length,
self.config.loader.eval_batch_size,
self.device)
return samples
def get_score(self, x, sigma):
model_output = self.forward(x, sigma).to(torch.float64)
if self.config.sampling.nucleus_p == 1.0:
return model_output.exp()
model_output = model_output - model_output.logsumexp(-1, keepdim=True)
model_output = self._nucleus_sample(model_output.exp())
return model_output
def _staggered_score(self, score, dsigma):
score = score.clone()
extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
score *= dsigma.exp()[:, None]
score[..., self.mask_index] += extra_const
return score
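
# Analytic reverse step: combines the model score with the closed-form
# (staggered) transition over the interval [t - dt, t].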
def _analytic_update(self, x, t, dt):
sigma_t = self._sigma_from_p(self.noise(t)[1])
sigma_s = self._sigma_from_p(self.noise(t - dt)[1])
dsigma = sigma_t - sigma_s
score = self.get_score(x, sigma_t)
stag_score = self._staggered_score(score, dsigma)
probs = stag_score * self._transp_transition(x, dsigma)
return _sample_categorical(probs)
def _denoiser_update(self, x, t):
sigma = self._sigma_from_p(self.noise(t)[1])
score = self.get_score(x, sigma)
stag_score = self._staggered_score(score, sigma)
probs = stag_score * self._transp_transition(x, sigma)
probs[..., self.mask_index] = 0
samples = _sample_categorical(probs)
return samples
def _transp_transition(self, i, sigma):
sigma = _unsqueeze(sigma, reference=i[..., None])
edge = torch.exp(-sigma) * F.one_hot(
i, num_classes=self.vocab_size)
edge += torch.where(i == self.mask_index,
1 - torch.exp(-sigma).squeeze(-1),
0)[..., None]
return edge
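
# Samples one diffusion time per (batch element, block). Antithetic sampling
# stratifies times over the batch-by-block grid to reduce the variance of the
# loss estimate; times are then rescaled to [sampling_eps_min, sampling_eps_max].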
def _sample_t(
self, batch_dims, device, sampling_eps_min, sampling_eps_max, block_size=None):
if self.block_plan is None:
if block_size is None:
block_size = self.block_size
n = batch_dims[-1]
num_blocks = n // block_size
_eps_b = torch.rand((batch_dims[0], num_blocks), device=device)
# antithetic sampling along blocks & batches (for uniform sampling)
if self.antithetic_sampling:
offset_b = torch.arange(batch_dims[0] * num_blocks, device=device) / (batch_dims[0] * num_blocks)
offset_b = offset_b.view(batch_dims[0], num_blocks)
_eps_b = (_eps_b / (batch_dims[0] * num_blocks) + offset_b) % 1
t = _eps_b
if block_size != self.config.model.length:
t = t.repeat_interleave(block_size, dim=-1)
# nll
if sampling_eps_max >= 1 and sampling_eps_min >= 1:
return torch.ones_like(t)
t = t * (sampling_eps_max - sampling_eps_min) + sampling_eps_min
return t
# FIXME: var block
else:
B = batch_dims[0]
n = batch_dims[-1]
plan = self.block_plan
lens = plan.lens.to(device)
if int(lens.sum().item()) != n:
raise ValueError(f"Plan sum {int(lens.sum().item())} != n={n}")
Nblocks = int(lens.numel())
eps_b = torch.rand((B, Nblocks), device=device)
if self.antithetic_sampling:
offset = torch.arange(B * Nblocks, device=device) / (B * Nblocks)
eps_b = (eps_b / (B * Nblocks) + offset.view(B, Nblocks)) % 1
if sampling_eps_max >= 1 and sampling_eps_min >= 1:
t_block = torch.ones_like(eps_b)
else:
t_block = eps_b * (sampling_eps_max - sampling_eps_min) + sampling_eps_min
# [B, Nblocks] -> [B, n]
t = torch.repeat_interleave(t_block, lens, dim=1)
return t
def _maybe_sub_sample(self, x0, attention_mask):
seqlen = x0.shape[1]
if seqlen > self.num_tokens:
assert seqlen == 2 * self.num_tokens
# cropping is needed for text8-crop dataset
# try the same starting point for now
start = np.random.choice(self.num_tokens)
end = start + self.num_tokens
input_tokens = x0[:, start: end]
output_tokens = x0[:, start + 1: end + 1]
new_attention_mask = attention_mask[:, start: end]
# Helps with validation ppl, since the val
# examples will all start and end with BOS/EOS
if self.config.data.insert_train_special:
input_tokens[:, 0] = self.tokenizer.bos_token_id
output_tokens[:, -1] = self.tokenizer.eos_token_id
elif self.parameterization == 'ar':
input_tokens = x0[:, :-1]
output_tokens = x0[:, 1:]
new_attention_mask = attention_mask[:, 1:]
else:
input_tokens = x0
output_tokens = None
new_attention_mask = attention_mask
return input_tokens, output_tokens, new_attention_mask
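
# Diffusion training pass: sample per-block times t, corrupt x0 into xt by
# masking, and run the denoiser on the noisy input (optionally concatenated
# with x0 when cross-attention is enabled).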
def _forward_pass_diffusion(self, x0, t=None, sampling_eps_min=None, sampling_eps_max=None):
if t is None:
t = self._sample_t(x0.shape,
x0.device,
sampling_eps_min,
sampling_eps_max)
loss_scale, p = self.noise(t)
sigma = self._sigma_from_p(p[:,0].unsqueeze(-1))
dsigma = - loss_scale * torch.expm1(sigma) # used for sedd
# the branch below is needed to reproduce MDLM/SEDD numbers with models from Sahoo et al.
# (numerical imprecision when computing probs under the log-linear schedule)
if self.mdlm_loss_scale:
sigma, dsigma = self.noise.total_noise(t), self.noise.rate_noise(t)
p = 1 - torch.exp(-sigma)
loss_scale = - (dsigma / torch.expm1(sigma))
xt = self.q_xt(x0,
p,
sampling_eps_min=sampling_eps_min,
sampling_eps_max=sampling_eps_max)
if sampling_eps_min is not None and sampling_eps_min > 0.5:
loss_scale = - torch.ones_like(loss_scale)
if self.ignore_bos:
xt[:, 0] = x0[:, 0]
x_input = xt
if self.cross_attn:
x_input = torch.cat((xt, x0), dim=-1)
model_output = self.forward(x_input, sigma=sigma)
utils.print_nans(model_output, 'model_output')