-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpreprocessInput.py
More file actions
1440 lines (1218 loc) · 63.1 KB
/
preprocessInput.py
File metadata and controls
1440 lines (1218 loc) · 63.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
from fairseq.models.roberta import RobertaModel
# from fairseq.models.transformer import TransformerModel
import os
import io
import zipfile
import miditoolkit
import random
import time
import math
import signal
import hashlib
from multiprocessing import Pool, Lock, Manager
import numpy as np
import torch
import torch.nn.functional as F
import sys
from muzic.musicbert.musicbert import *
from music21 import *
from tqdm import tqdm
from enum import Enum
from itertools import chain
from music21 import *
from transformers import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import accelerate # Required by diffusers
from PIL import Image
from datetime import datetime
import pandas as pd
from transformers import CLIPProcessor, CLIPModel
# The diffusers imports above are unconditional, so reaching this line means they succeeded.
diffusers_available = True
# music21 user settings (music21 is star-imported above): MuseScore binary path and
# scratch directory, both given relative to the current working directory.
us = environment.UserSettings()
us['musescoreDirectPNGPath'] = './mscore'
us['directoryScratch'] = './tmp'
bar_max = 256  # bar tokens <0-0> .. <0-255>; F never emits a bar index >= bar_max
pos_resolution = 16  # positions per beat (see time_to_pos in MIDI_to_encoding)
velocity_quant = 4  # velocity bucket width: token = velocity // velocity_quant
tempo_quant = 12  # tempo tokens per doubling of BPM (see b2e / e2b)
min_tempo = 16  # b2e clamps BPM into [min_tempo, max_tempo]
max_tempo = 256
duration_max = 8  # number of duration groups; group i has pos_resolution tokens of 2 ** i positions each
max_ts_denominator = 6  # time-signature denominators 2 ** 0 .. 2 ** 6 (1 .. 64)
max_notes_per_bar = 2  # numerators up to max_notes_per_bar * denominator are representable
beat_note_factor = 4  # positions per bar = numerator * beat_note_factor * pos_resolution // denominator
deduplicate = True  # F: skip pieces whose get_hash digest was already seen
filter_symbolic = False  # MIDI_to_encoding: reject pieces whose onset perplexity exceeds filter_symbolic_ppl
filter_symbolic_ppl = 16
trunc_pos = 2 ** 16  # notes starting at or after this position are dropped
sample_len_max = 1000  # max octuples per output sample
sample_overlap_rate = 4  # sample windows start every sample_len_max / sample_overlap_rate octuples
ts_filter = False  # F: keep only pieces that are entirely in 4/4
pool_num = 24  # not used in the visible part of this file; presumably the worker-pool size
max_inst = 127  # highest program number; max_inst + 1 is used for percussion
max_pitch = 127  # drum pitches are offset by max_pitch + 1
max_velocity = 127
data_zip = None  # archive F reads members from; must be assigned before F runs
output_file = None  # path writer() appends to; must be assigned before writer runs
# Locks used by F / writer: zip reads, output writes, and the dedup dictionary.
lock_file = Lock()
lock_write = Lock()
lock_set = Lock()
# Manager() starts a manager server process; midi_dict maps get_hash digests to
# the first file name seen with that digest (used by F for deduplication).
manager = Manager()
midi_dict = manager.dict()
# Time-signature vocabulary: ts_list[token] = (numerator, denominator); ts_dict is the inverse.
ts_dict = dict()
ts_list = list()
for i in range(0, max_ts_denominator + 1):  # 1 ~ 64
    for j in range(1, ((2 ** i) * max_notes_per_bar) + 1):
        ts_dict[(j, 2 ** i)] = len(ts_dict)
        ts_list.append((j, 2 ** i))
# Duration tables: dur_enc[positions] -> duration token, dur_dec[token] -> positions.
# Duration group i holds pos_resolution tokens, each covering 2 ** i positions.
# NOTE: the loop variables i, j, k remain bound at module scope after these loops.
dur_enc = list()
dur_dec = list()
for i in range(duration_max):
    for j in range(pos_resolution):
        dur_dec.append(len(dur_enc))
        for k in range(2 ** i):
            dur_enc.append(len(dur_dec) - 1)
class timeout:
    """Context manager that raises TimeoutError if its body outlives `seconds`.

    Implemented with SIGALRM, so it works only on POSIX and only in the main
    thread of a process (signal.signal raises ValueError elsewhere). The
    previously installed SIGALRM handler is restored on exit.

    Args:
        seconds: whole-second limit handed to signal.alarm.
        error_message: message of the raised TimeoutError.
    """

    def __init__(self, seconds=1, error_message='Timeout'):
        self.seconds = seconds
        self.error_message = error_message
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # Remember the current handler so __exit__ can put it back.
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, exc_type, value, traceback):
        # Cancel the pending alarm before restoring the handler so it cannot fire in between.
        signal.alarm(0)
        previous = self._previous_handler
        # signal.signal returns None when the old handler was not installed from Python.
        signal.signal(signal.SIGALRM, previous if previous is not None else signal.SIG_DFL)
        return False  # never swallow exceptions raised inside the block
def t2e(x):
    """Encode a (numerator, denominator) time signature as its ts_dict token id.

    Raises AssertionError naming the signature when it is not in the vocabulary.
    """
    known = x in ts_dict
    assert known, 'unsupported time signature: ' + str(x)
    return ts_dict[x]
def e2t(x):
    """Decode a time-signature token id into its (numerator, denominator) pair."""
    signature = ts_list[x]
    return signature
def d2e(x):
    """Quantize a duration measured in positions to its duration token id.

    Durations beyond the table saturate to the last (longest) token. Negative
    durations, which come from a note whose end position precedes its start,
    are clamped to 0. Otherwise Python's negative indexing would pick entries
    from the end of dur_enc and encode them as very long notes.
    """
    position_count = max(x, 0)
    return dur_enc[min(position_count, len(dur_enc) - 1)]
def e2d(x):
    """Decode a duration token id into a duration in positions (saturating at the last entry)."""
    return dur_dec[min(x, len(dur_dec) - 1)]
def v2e(x):
    """Quantize a MIDI velocity into its bucket index (floor division by velocity_quant)."""
    bucket, _remainder = divmod(x, velocity_quant)
    return bucket
def e2v(x):
    """Map a velocity bucket back to a representative velocity (the bucket's midpoint)."""
    half_bucket = velocity_quant // 2
    return x * velocity_quant + half_bucket
def b2e(x):
    """Encode a tempo in BPM as a log-scale token.

    The tempo is clamped to [min_tempo, max_tempo]. The token is
    round(log2(tempo / min_tempo) * tempo_quant).
    """
    clamped = min(max(x, min_tempo), max_tempo)
    octaves_above_min = math.log2(clamped / min_tempo)
    return round(octaves_above_min * tempo_quant)
def e2b(x):
    """Decode a tempo token back to BPM (inverse of b2e, up to rounding)."""
    return min_tempo * 2 ** (x / tempo_quant)
def time_signature_reduce(numerator, denominator):
    """Bring a time signature into the representable vocabulary.

    1. While the denominator exceeds 2 ** max_ts_denominator and both terms
       are even, halve both. The bar length stays the same and the denominator
       gets smaller.
    2. While the bar is longer than max_notes_per_bar * denominator beats,
       divide the numerator by its smallest factor greater than 1.

    Returns:
        The reduced (numerator, denominator) pair.
    """
    denominator_limit = 2 ** max_ts_denominator
    while denominator > denominator_limit and denominator % 2 == 0 and numerator % 2 == 0:
        denominator //= 2
        numerator //= 2
    while numerator > max_notes_per_bar * denominator:
        # smallest divisor >= 2 (1 if none exists, which leaves the value unchanged)
        smallest_factor = next(
            (d for d in range(2, numerator + 1) if numerator % d == 0), 1)
        numerator //= smallest_factor
    return numerator, denominator
def writer(file_name, output_str_list):
    """Append each string in `output_str_list` as its own line to the global output_file.

    `file_name` is unused here; it is reserved so the function can be patched.
    """
    with open(output_file, 'a') as out:
        out.writelines(line + '\n' for line in output_str_list)
def gen_dictionary(file_name):
    """Write the fairseq-style vocabulary file for the octuple token space.

    Every line is "<field-value> 0". Fields are ordered 0..7 as bar, position,
    program, pitch, duration, velocity, time signature and tempo, and each
    field's values are contiguous.
    """
    num = 0
    field_sizes = [
        bar_max,                                                # 0: bar
        beat_note_factor * max_notes_per_bar * pos_resolution,  # 1: position within bar
        max_inst + 1 + 1,                                       # 2: program (max_inst + 1 for percussion)
        2 * max_pitch + 1 + 1,                                  # 3: pitch (max_pitch + 1 ~ 2 * max_pitch + 1 for percussion)
        duration_max * pos_resolution,                          # 4: duration
        v2e(max_velocity) + 1,                                  # 5: velocity
        len(ts_list),                                           # 6: time signature
        b2e(max_tempo) + 1,                                     # 7: tempo
    ]
    with open(file_name, 'w') as f:
        for field, size in enumerate(field_sizes):
            for value in range(size):
                print('<{}-{}>'.format(field, value), num, file=f)
def MIDI_to_encoding(midi_obj):
    """Convert a miditoolkit MidiFile into a sorted list of octuple tuples.

    Each note becomes (bar, position-in-bar, program, pitch, duration token,
    velocity token, time-signature token, tempo token). Drum notes use program
    max_inst + 1 and pitch + max_pitch + 1. Notes starting at or after
    trunc_pos are dropped. Positions without a time-signature or tempo change
    fall back to 4/4 and 120 BPM.

    Returns an empty list when there are no (kept) notes. When filter_symbolic
    is set, raises AssertionError if the onset perplexity exceeds
    filter_symbolic_ppl.
    """
    def time_to_pos(t):
        # ticks -> positions (pos_resolution positions per beat)
        return round(t * pos_resolution / midi_obj.ticks_per_beat)
    notes_start_pos = [time_to_pos(j.start)
                       for i in midi_obj.instruments for j in i.notes]
    if len(notes_start_pos) == 0:
        return list()
    max_pos = min(max(notes_start_pos) + 1, trunc_pos)
    pos_to_info = [[None for _ in range(4)] for _ in range(
        max_pos)]  # (Measure, TimeSig, Pos, Tempo)
    tsc = midi_obj.time_signature_changes
    tpc = midi_obj.tempo_changes
    # Each time-signature change covers positions up to the next change (or max_pos).
    for i in range(len(tsc)):
        for j in range(time_to_pos(tsc[i].time), time_to_pos(tsc[i + 1].time) if i < len(tsc) - 1 else max_pos):
            if j < len(pos_to_info):
                pos_to_info[j][1] = t2e(time_signature_reduce(
                    tsc[i].numerator, tsc[i].denominator))
    # Same for tempo changes.
    for i in range(len(tpc)):
        for j in range(time_to_pos(tpc[i].time), time_to_pos(tpc[i + 1].time) if i < len(tpc) - 1 else max_pos):
            if j < len(pos_to_info):
                pos_to_info[j][3] = b2e(tpc[i].tempo)
    for j in range(len(pos_to_info)):
        if pos_to_info[j][1] is None:
            # MIDI default time signature
            pos_to_info[j][1] = t2e(time_signature_reduce(4, 4))
        if pos_to_info[j][3] is None:
            pos_to_info[j][3] = b2e(120.0)  # MIDI default tempo (BPM)
    # Walk the positions and assign bar index (slot 0) and position-in-bar (slot 2).
    cnt = 0
    bar = 0
    measure_length = None
    for j in range(len(pos_to_info)):
        ts = e2t(pos_to_info[j][1])
        if cnt == 0:
            # The bar length is fixed by the time signature at the bar's first position,
            # so a change in the middle of a bar takes effect at the next bar.
            measure_length = ts[0] * beat_note_factor * pos_resolution // ts[1]
        pos_to_info[j][0] = bar
        pos_to_info[j][2] = cnt
        cnt += 1
        if cnt >= measure_length:
            # NOTE(review): cnt grows by 1 from 0, so while measure_length >= 1 this
            # assertion appears unable to fail.
            assert cnt == measure_length, 'invalid time signature change: pos = {}'.format(
                j)
            cnt -= measure_length
            bar += 1
    encoding = []
    # histogram of onset positions modulo pos_resolution (within-beat phase)
    start_distribution = [0] * pos_resolution
    for inst in midi_obj.instruments:
        for note in inst.notes:
            if time_to_pos(note.start) >= trunc_pos:
                continue
            start_distribution[time_to_pos(note.start) % pos_resolution] += 1
            info = pos_to_info[time_to_pos(note.start)]
            encoding.append((info[0], info[2], max_inst + 1 if inst.is_drum else inst.program, note.pitch + max_pitch +
                             1 if inst.is_drum else note.pitch, d2e(time_to_pos(note.end) - time_to_pos(note.start)), v2e(note.velocity), info[1], info[3]))
    if len(encoding) == 0:
        return list()
    # Perplexity (2 ** entropy) of the within-beat onset distribution.
    tot = sum(start_distribution)
    start_ppl = 2 ** sum((0 if x == 0 else -(x / tot) *
                          math.log2((x / tot)) for x in start_distribution))
    # filter unaligned music
    if filter_symbolic:
        assert start_ppl <= filter_symbolic_ppl, 'filtered out by the symbolic filter: ppl = {:.2f}'.format(
            start_ppl)
    encoding.sort()
    return encoding
def encoding_to_MIDI(encoding):
    """Rebuild a miditoolkit MidiFile from octuple tuples (inverse of MIDI_to_encoding).

    - Each bar uses its most frequent time-signature token. Bars with no notes
      inherit the previous bar's token (4/4 for bar 0).
    - Each position uses the rounded mean of the tempo tokens at that
      position. Empty positions inherit the previous value (120 BPM at
      position 0).
    - Program 128 becomes a drum track (program 0, is_drum=True) and its
      pitches are shifted down by 128.
    - Zero-length notes are stretched to 1 tick, and empty tracks are removed.

    `encoding` must be non-empty: max() over an empty sequence raises ValueError.
    """
    # TODO: filter out non-valid notes and error handling
    bar_to_timesig = [list()
                      for _ in range(max(map(lambda x: x[0], encoding)) + 1)]
    for i in encoding:
        bar_to_timesig[i[0]].append(i[6])
    # most frequent time-signature token per bar (None for bars without notes)
    bar_to_timesig = [max(set(i), key=i.count) if len(
        i) > 0 else None for i in bar_to_timesig]
    for i in range(len(bar_to_timesig)):
        if bar_to_timesig[i] is None:
            bar_to_timesig[i] = t2e(time_signature_reduce(
                4, 4)) if i == 0 else bar_to_timesig[i - 1]
    # absolute start position of every bar
    bar_to_pos = [None] * len(bar_to_timesig)
    cur_pos = 0
    for i in range(len(bar_to_pos)):
        bar_to_pos[i] = cur_pos
        ts = e2t(bar_to_timesig[i])
        measure_length = ts[0] * beat_note_factor * pos_resolution // ts[1]
        cur_pos += measure_length
    pos_to_tempo = [list() for _ in range(
        cur_pos + max(map(lambda x: x[1], encoding)))]
    for i in encoding:
        pos_to_tempo[bar_to_pos[i[0]] + i[1]].append(i[7])
    pos_to_tempo = [round(sum(i) / len(i)) if len(i) >
                    0 else None for i in pos_to_tempo]
    for i in range(len(pos_to_tempo)):
        if pos_to_tempo[i] is None:
            pos_to_tempo[i] = b2e(120.0) if i == 0 else pos_to_tempo[i - 1]
    midi_obj = miditoolkit.midi.parser.MidiFile()
    def get_tick(bar, pos):
        # (bar, position-in-bar) -> absolute tick
        return (bar_to_pos[bar] + pos) * midi_obj.ticks_per_beat // pos_resolution
    # one track per program 0..127 plus index 128 for drums; empty ones are dropped below
    midi_obj.instruments = [miditoolkit.containers.Instrument(program=(
        0 if i == 128 else i), is_drum=(i == 128), name=str(i)) for i in range(128 + 1)]
    for i in encoding:
        start = get_tick(i[0], i[1])
        program = i[2]
        pitch = (i[3] - 128 if program == 128 else i[3])
        duration = get_tick(0, e2d(i[4]))
        if duration == 0:
            duration = 1
        end = start + duration
        velocity = e2v(i[5])
        midi_obj.instruments[program].notes.append(miditoolkit.containers.Note(
            start=start, end=end, pitch=pitch, velocity=velocity))
    midi_obj.instruments = [
        i for i in midi_obj.instruments if len(i.notes) > 0]
    # emit a time-signature event only where the per-bar value changes
    cur_ts = None
    for i in range(len(bar_to_timesig)):
        new_ts = bar_to_timesig[i]
        if new_ts != cur_ts:
            numerator, denominator = e2t(new_ts)
            midi_obj.time_signature_changes.append(miditoolkit.containers.TimeSignature(
                numerator=numerator, denominator=denominator, time=get_tick(i, 0)))
            cur_ts = new_ts
    # emit a tempo event only where the per-position value changes
    cur_tp = None
    for i in range(len(pos_to_tempo)):
        new_tp = pos_to_tempo[i]
        if new_tp != cur_tp:
            tempo = e2b(new_tp)
            midi_obj.tempo_changes.append(
                miditoolkit.containers.TempoChange(tempo=tempo, time=get_tick(0, i)))
            cur_tp = new_tp
    return midi_obj
def get_hash(encoding):
    """Return an MD5 hex digest of the piece's (program, pitch) sequence.

    Only fields 2 and 3 of each octuple are hashed. Pieces with the same
    program/pitch sequence therefore hash equal regardless of timing, duration
    or velocity, which is the matching rule used for deduplication.
    """
    # add i[4] and i[5] for stricter match
    key = tuple((octuple[2], octuple[3]) for octuple in encoding)
    digest = hashlib.md5(str(key).encode('ascii'))
    return digest.hexdigest()
def _encoding_to_sample_strings(e):
    """Cut an octuple encoding into overlapping token strings with random bar offsets.

    Windows of up to sample_len_max octuples start every
    max(round(sample_len_max / sample_overlap_rate), 1) octuples. The first
    window start is randomly shifted left. Each window gets a random bar
    offset that keeps its bar indices inside [0, bar_max) when that is
    possible. Each string is framed by 8 '<s>' and 7 '</s>' tokens.
    """
    tokens_per_note = 8
    output_str_list = []
    sample_step = max(round(sample_len_max / sample_overlap_rate), 1)
    for p in range(0 - random.randint(0, sample_len_max - 1), len(e), sample_step):
        first = max(p, 0)
        last = min(p + sample_len_max, len(e)) - 1
        bar_index_list = [e[i][0]
                          for i in range(first, last + 1) if e[i][0] is not None]
        bar_index_min = 0
        bar_index_max = 0
        if len(bar_index_list) > 0:
            bar_index_min = min(bar_index_list)
            bar_index_max = max(bar_index_list)
        offset_lower_bound = -bar_index_min
        offset_upper_bound = bar_max - 1 - bar_index_max
        # to make bar index distribute in [0, bar_max)
        bar_index_offset = random.randint(
            offset_lower_bound, offset_upper_bound) if offset_lower_bound <= offset_upper_bound else offset_lower_bound
        e_segment = []
        for i in e[first: last + 1]:
            if i[0] is None or i[0] + bar_index_offset < bar_max:
                e_segment.append(i)
            else:
                break
        output_words = (['<s>'] * tokens_per_note) \
            + [('<{}-{}>'.format(j, k if j > 0 else k + bar_index_offset) if k is not None else '<unk>') for i in e_segment for j, k in enumerate(i)] \
            + (['</s>'] * (tokens_per_note - 1))  # tokens_per_note - 1 for append_eos functionality of binarizer in fairseq
        output_str_list.append(' '.join(output_words))
    return output_str_list


def F(file_name):
    """Encode one MIDI member of `data_zip` and append its samples to output_file.

    Steps: read (up to 10 attempts), parse (600 s SIGALRM limit plus sanity
    checks), MIDI_to_encoding, optional 4/4 filter (ts_filter), optional
    deduplication via get_hash / midi_dict, sampling, then writer. Progress and
    errors are reported with print.

    Note: this module-level name shadows the `torch.nn.functional as F` import.

    Returns:
        True on success. None when the file is skipped (read or parse failure,
        blank, time-signature filtered, duplicate). False on an encode, write
        or processing error.
    """
    try_times = 10
    midi_file = None
    for _ in range(try_times):
        try:
            with lock_file:
                with data_zip.open(file_name) as f:
                    # this may fail due to unknown bug
                    midi_file = io.BytesIO(f.read())
            break  # read succeeded; do not read the same member again
        except BaseException as err:
            try_times -= 1
            time.sleep(1)  # back off outside the lock so other workers can proceed
            if try_times == 0:
                print('ERROR(READ): ' + file_name +
                      ' ' + str(err) + '\n', end='')
                return None
    try:
        with timeout(seconds=600):
            midi_obj = miditoolkit.midi.parser.MidiFile(file=midi_file)
        # check abnormal values in parse result
        assert all(0 <= j.start < 2 ** 31 and 0 <= j.end < 2 ** 31
                   for i in midi_obj.instruments for j in i.notes), 'bad note time'
        assert all(0 < j.numerator < 2 ** 31 and 0 < j.denominator < 2 ** 31
                   for j in midi_obj.time_signature_changes), 'bad time signature value'
        assert 0 < midi_obj.ticks_per_beat < 2 ** 31, 'bad ticks per beat'
    except BaseException as err:
        print('ERROR(PARSE): ' + file_name + ' ' + str(err) + '\n', end='')
        return None
    midi_notes_count = sum(len(inst.notes) for inst in midi_obj.instruments)
    if midi_notes_count == 0:
        print('ERROR(BLANK): ' + file_name + '\n', end='')
        return None
    try:
        e = MIDI_to_encoding(midi_obj)
        if len(e) == 0:
            print('ERROR(BLANK): ' + file_name + '\n', end='')
            return None
        if ts_filter:
            allowed_ts = t2e(time_signature_reduce(4, 4))
            if not all(i[6] == allowed_ts for i in e):
                print('ERROR(TSFILT): ' + file_name + '\n', end='')
                return None
        if deduplicate:
            duplicated = False
            dup_file_name = ''
            midi_hash = '0' * 32
            try:
                midi_hash = get_hash(e)
            except BaseException:
                # Best effort: keep the placeholder hash. The exception is
                # deliberately not bound to `e`; `except ... as e` would delete
                # the encoding when the handler exits.
                pass
            with lock_set:
                # check-and-insert under the lock so two workers cannot both claim a hash
                if midi_hash in midi_dict:
                    dup_file_name = midi_dict[midi_hash]
                    duplicated = True
                else:
                    midi_dict[midi_hash] = file_name
            if duplicated:
                print('ERROR(DUPLICATED): ' + midi_hash + ' ' +
                      file_name + ' == ' + dup_file_name + '\n', end='')
                return None
        tokens_per_note = 8
        output_str_list = _encoding_to_sample_strings(e)
        # no empty: every sample needs more than the 15 <s>/</s> framing tokens
        if not all(len(i.split()) > tokens_per_note * 2 - 1 for i in output_str_list):
            print('ERROR(ENCODE): ' + file_name + ' ' + str(e) + '\n', end='')
            return False
        try:
            with lock_write:
                writer(file_name, output_str_list)
        except BaseException as err:
            print('ERROR(WRITE): ' + file_name + ' ' + str(err) + '\n', end='')
            return False
        print('SUCCESS: ' + file_name + '\n', end='')
        return True
    except BaseException as err:
        print('ERROR(PROCESS): ' + file_name + ' ' + str(err) + '\n', end='')
        return False
def G(file_name):
    """Run F and turn any exception that escapes it into a logged False result."""
    result = False
    try:
        result = F(file_name)
    except BaseException:
        print('ERROR(UNCAUGHT): ' + file_name + '\n', end='')
    return result
def str_to_encoding(s):
    """Parse a '<field-value>' token string back into a list of 8-tuples of ints.

    Tokens containing the letter 's' ('<s>', '</s>') are skipped. Each
    remaining token '<f-v>' contributes int(v). The value count must be a
    multiple of 8 (AssertionError otherwise).
    """
    tokens_per_note = 8
    values = [int(token[3:-1]) for token in s.split() if 's' not in token]
    assert len(values) % tokens_per_note == 0
    return [tuple(values[start:start + tokens_per_note])
            for start in range(0, len(values), tokens_per_note)]
def encoding_to_str(e, bar_max=bar_max):
    """Serialise octuples into a space-separated token string.

    Uses at most the first sample_len_max octuples and skips those whose bar
    index is >= bar_max. The result is framed by 8 '<s>' and 7 '</s>' tokens.
    """
    tokens_per_note = 8
    bar_index_offset = 0
    body = []
    for octuple in e[:sample_len_max]:
        if octuple[0] + bar_index_offset >= bar_max:
            continue
        for field, value in enumerate(octuple):
            shown = value + bar_index_offset if field == 0 else value
            body.append('<{}-{}>'.format(field, shown))
    # 8 - 1 for append_eos functionality of binarizer in fairseq
    words = ['<s>'] * tokens_per_note + body + ['</s>'] * (tokens_per_note - 1)
    return ' '.join(words)
# --- Global model setup (runs at import time) ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print('loading model and data')
roberta_base = MusicBERTModel.from_pretrained(
    '.',
    checkpoint_file='muzic/musicbert/checkpoints/checkpoint_last_musicbert_base_w_genre_head.pt',
    # user_dir='musicbert'  # activate the MusicBERT plugin with this keyword
)
# Print the encoder and LM-head modules for inspection.
print(roberta_base.model.encoder.sentence_encoder)
print(roberta_base.model.encoder.lm_head)
# Follow the detected device; an unconditional .cuda() fails on CPU-only machines.
roberta_base.to(device)
roberta_base.eval()
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
# Load the CLIP model and processor
clip_model_id = "openai/clip-vit-base-patch32"  # You can choose a different CLIP model if needed
clip_model = CLIPModel.from_pretrained(clip_model_id).to(device)
clip_processor = CLIPProcessor.from_pretrained(clip_model_id)
print(f"CLIP model ({clip_model_id}) loaded successfully on {device}.")
# The text BERT model uses the same device; the separate global name is kept
# in case later code refers to it.
text_bert_device = device
bert_model.to(text_bert_device)
bert_model.eval()  # Set to evaluation mode
print(f"Text BERT model loaded successfully on {text_bert_device}.")
text_bert_loaded = True
# Optional Stable Diffusion pipeline (disabled):
# model_id = "stabilityai/stable-diffusion-2-1-base"  # Or try 2.1 base
# pipe = StableDiffusionPipeline.from_pretrained(
#     model_id,
#     torch_dtype=torch.float16 if device.type == 'cuda' else torch.float32,  # Use float16 on GPU
#     # Add use_auth_token=True or your HF token if needed for gated models
# )
# # Use a potentially faster/better scheduler
# pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
# pipe.return_dict = True
# pipe.vae.register_to_config(return_dict=True)
# pipe = pipe.to(device)
# # If low GPU memory, enable attention slicing
# # pipe.enable_attention_slicing()
# sd_pipe = pipe  # Assign to global variable
# stable_diffusion_loaded = True
# print(f"Stable Diffusion pipeline ({model_id}) loaded successfully.")
def parse_midi_file(sample_midi_path: str):
    """Load a MIDI file and derive a short name for it from its path.

    Returns:
        (midi_obj, midi_name): the miditoolkit MidiFile, and the file's base
        name with only its final extension removed. For example
        "dir/song.v2.mid" gives "song.v2"; split('.')[0] would have cut this
        to "song".
    """
    midi_obj = miditoolkit.midi.parser.MidiFile(sample_midi_path)
    # os.path handles the platform's path separator and strips only the last extension.
    midi_name = os.path.splitext(os.path.basename(sample_midi_path))[0]
    return midi_obj, midi_name
def filter_tracks(midi_obj: miditoolkit.midi.parser.MidiFile, root_folder: str = '.',
                  cache_folder: str = "musicbert_cache", midi_name: str = "final",
                  track_ids: "list | None" = None):
    """Optionally keep only selected tracks, then save the MIDI into the cache folder.

    Args:
        midi_obj: parsed MidiFile.
        root_folder, cache_folder: output goes to '{root_folder}/{cache_folder}/',
            which is created if missing.
        midi_name: prefix of the output file '{midi_name}_track_filtered.mid'.
        track_ids: indices into midi_obj.instruments to keep. None or empty
            keeps every track. A None default replaces the shared mutable `[]`.

    Returns:
        The MidiFile that was written: a new object when filtering happened,
        otherwise `midi_obj` itself.
    """
    if track_ids is not None and len(track_ids) > 0:
        new_midi_obj = miditoolkit.midi.parser.MidiFile()
        new_midi_obj.instruments = [midi_obj.instruments[i] for i in track_ids]
        new_midi_obj.ticks_per_beat = midi_obj.ticks_per_beat
        new_midi_obj.time_signature_changes = midi_obj.time_signature_changes
        new_midi_obj.tempo_changes = midi_obj.tempo_changes
        midi_obj = new_midi_obj
    out_dir = f'{root_folder}/{cache_folder}'
    os.makedirs(out_dir, exist_ok=True)
    out_path = f'{out_dir}/{midi_name}_track_filtered.mid'
    midi_obj.dump(out_path)
    print(f"> Parsed MIDI: {midi_name}")
    print(f"> Saved Parsed MIDI as : {out_path}")
    return midi_obj
def cache_midi_tracks(midi_obj: miditoolkit.midi.parser.MidiFile, midi_name, root_folder='.',
                      cache_folder="musicbert_cache", verbose=False):
    """Save every track of `midi_obj` as its own single-track MIDI file.

    Files are named '{midi_name}_track_{idx}_prog_{program}.mid' under
    '{root_folder}/{cache_folder}'. Each file copies the source's
    ticks_per_beat, time-signature changes and tempo changes.
    """
    out_dir = f'{root_folder}/{cache_folder}'
    for idx, track in enumerate(midi_obj.instruments):
        single_track = miditoolkit.midi.parser.MidiFile()
        single_track.instruments = [track]
        single_track.ticks_per_beat = midi_obj.ticks_per_beat
        single_track.time_signature_changes = midi_obj.time_signature_changes
        single_track.tempo_changes = midi_obj.tempo_changes
        single_track.dump(f'{out_dir}/{midi_name}_track_{idx}_prog_{track.program}.mid')
    if verbose:
        programs = [track.program for track in midi_obj.instruments]
        print(f'The input MIDI has {len(midi_obj.instruments)} tracks with program IDs {programs} respectively.')
def reverse_label_dict(label_dict: "fairseq.data.dictionary.Dictionary"):
    """Invert a vocabulary's `indices` mapping, giving token id -> token string.

    The annotation is a string so defining this function does not need the
    name `fairseq` to be bound. The explicit `from fairseq.models.roberta
    import ...` at the top of the file does not bind it; only a wildcard
    import might.
    """
    return {token_id: token for token, token_id in label_dict.indices.items()}
def decode_w_label_dict(label_dict: "fairseq.data.dictionary.Dictionary", octuple_midi_enc: torch.Tensor,
                        skip_masked_tokens = False):
    """Decode a tensor of token ids into a space-separated token string.

    Args:
        label_dict: vocabulary exposing `indices` (token -> id), e.g. a fairseq
            Dictionary. The annotation is a string so the definition does not
            need `fairseq` bound at module level.
        octuple_midi_enc: tensor of token ids. `.tolist()` must give a flat
            list of ids (presumably a 1-D tensor; confirm at call sites).
        skip_masked_tokens: when True, drop every 8-token group (aligned at
            indices 0, 8, 16, ...) that contains a '<mask>' token.

    Returns:
        The decoded tokens joined by single spaces.
    """
    id_to_token = reverse_label_dict(label_dict)
    tokens = [id_to_token[token_id] for token_id in octuple_midi_enc.tolist()]
    if skip_masked_tokens:
        # Tokenise on whitespace so group alignment matches the decoded string.
        tokens = " ".join(tokens).split()
        # Single pass: keep only the octuples that contain no '<mask>'.
        kept = []
        for start in range(0, len(tokens), 8):
            octuple = tokens[start:start + 8]
            if '<mask>' not in octuple:
                kept.extend(octuple)
        tokens = kept
    return " ".join(tokens)
def get_bar_idx(octuple_encoding, bar):
    """Return the index of the first octuple whose bar field equals `bar`.

    The last octuple's bar is treated as the total, so the encoding is
    presumably sorted by bar. When `bar` exceeds that total, a message is
    printed and None is returned. Raises ValueError if `bar` is within range
    but has no octuple.
    """
    last_bar = octuple_encoding[-1][0]
    if bar > last_bar:
        print('starting bar greater than total no. of bars')
        return None
    bar_column = [octuple[0] for octuple in octuple_encoding]
    return bar_column.index(bar)
def shift_bar_to_front(octuple_encoding):
    """Rebase bar indices so that the first octuple lands in bar 0.

    The first octuple's bar is subtracted from every octuple's bar, and each
    element is replaced by a new tuple. The list is modified in place and also
    returned.
    """
    first_bar = octuple_encoding[0][0]
    for position in range(len(octuple_encoding)):
        octuple = octuple_encoding[position]
        octuple_encoding[position] = (octuple[0] - first_bar,) + tuple(octuple[1:])
    return octuple_encoding
def get_min_bar_idx_from_oct(octuple_midi_str_aslist = ("<s> "*8).split() + (" </s>"*8).split()
                             , min_bar_mask: int = 0):
    """Return the token index where the first octuple of the first non-empty bar >= `min_bar_mask` starts.

    `octuple_midi_str_aslist` is a list of token strings framed by 8 '<s>' and
    8 '</s>' tokens, so index -16 is the last octuple's bar token; that bar
    is taken as the total. If bar `min_bar_mask` has no note, the next bar
    that has one is used.

    Raises:
        Exception: if the bar searched for exceeds the total, or if the index
            found is not octuple-aligned (not a multiple of 8).
    """
    max_bars = int(octuple_midi_str_aslist[-16][3:-1])
    bar = min_bar_mask
    while True:
        # Explicit checks instead of `assert`, which is stripped under python -O.
        if bar > max_bars:
            raise Exception(f"The input MIDI does not have {bar} bars, it has {max_bars} bars")
        try:
            # '<0-bar>' is present exactly when some note lies in that bar
            min_idx = octuple_midi_str_aslist.index(f'<0-{bar}>')
            break
        except ValueError:
            bar += 1  # empty bar: try the next one
    if min_idx % 8 != 0:
        raise Exception("Fatal backend error!: min_idx not a multiple of 8")
    print(f'Minimum index having {bar} bars is {min_idx} belonging to octuple with index {min_idx // 8} ')
    return min_idx
def mask_instrument_notes_program(program: int, octuplemidi_token_encoding: torch.Tensor,
                                  label_dict: "fairseq.data.dictionary.Dictionary", percentage_mask = 100,
                                  replacement_program: int = None, mask_attribs = (0, 1, 3, 4, 5, 6, 7),
                                  min_bar_mask = 0, seed = 42):
    """Mask notes of one instrument so that mask prediction re-generates (remixes) them.

    Only octuples from the first note of bar >= `min_bar_mask` onward are candidates.

    Args:
        program: instrument ID (https://jazz-soft.net/demo/GeneralMidi.html)
            whose notes may be masked.
        octuplemidi_token_encoding: 1-D tensor of token ids framed by <s>/</s>,
            like torch.Tensor([0, 0, 0, 0, 0, 0, 0, 0, ......., 2, 2, 2, 2, 2, 2, 2, 2]).
            It is not modified; a clone is masked and returned.
        label_dict: vocabulary with `index(token)` and `indices`, e.g. a fairseq Dictionary.
        percentage_mask: percentage (0-100) of that instrument's octuples to mask.
        replacement_program: if given, the program token of every masked octuple
            is replaced by this instrument, so mask prediction produces notes for it.
        mask_attribs: octuple fields (0-7) to mask in each chosen octuple.
            For example (0, 1, 3, 4, 5, 6, 7) masks everything except the
            program, and (3,) masks only the pitch.
        min_bar_mask: first bar whose notes may be masked.
        seed: passed to np.random.seed (NumPy's global RNG) before sampling.

    Returns:
        (masked token tensor, sorted octuple indices that were masked, counted
        from the start of the full input)

    Raises:
        Exception: if no candidate octuple has `program` (or from get_min_bar_idx_from_oct).
        IndexError: if percentage_mask is outside [0, 100].
    """
    np.random.seed(seed)
    octuplemidi_token_encoding = octuplemidi_token_encoding.clone()
    rev_label_dict = reverse_label_dict(label_dict)
    octuple_midi_str_aslist = [rev_label_dict[x] for x in octuplemidi_token_encoding.tolist()]
    # Token index of the first candidate octuple (a multiple of 8).
    min_idx = get_min_bar_idx_from_oct(octuple_midi_str_aslist, min_bar_mask)
    print(min_idx)
    first_oct = min_idx // 8
    # Basic slicing returns a view, so the in-place fills below write into the clone.
    encoding_tail = octuplemidi_token_encoding[min_idx:]
    program_token = f'<2-{program}>'
    instrument_octuple_indices = [index // 8
                                  for index, value in enumerate(octuple_midi_str_aslist[min_idx:])
                                  if value == program_token]
    if len(instrument_octuple_indices) == 0:
        raise Exception(f"No notes found with program = {program}")
    n_to_mask = int(len(instrument_octuple_indices) * (percentage_mask / 100))
    print(f'Found {len(instrument_octuple_indices)} octuples with program = {program}')
    print(f'Choosing {n_to_mask} octuples to mask....')
    if not (0 <= percentage_mask <= 100):
        raise IndexError(f"percentage_mask must be within [0, 100], got {percentage_mask}")
    masked_octs = sorted(np.random.choice(a=instrument_octuple_indices,
                                          size=n_to_mask,
                                          replace=False))
    # Octuple indices into the full input, not into the tail view.
    masked_octs_orig = [first_oct + x for x in masked_octs]
    print(f'Masking octuple numbers: {masked_octs_orig}')
    mask_idx = label_dict.index('<mask>')
    replacement_program_idx = None
    if replacement_program is not None:
        replacement_program_idx = label_dict.index(f'<2-{replacement_program}>')
    for masked_oct in masked_octs:
        base = int(masked_oct) * 8
        # Build positions explicitly; `np.int64 + list` only worked through NumPy broadcasting.
        positions = torch.tensor([base + attrib for attrib in mask_attribs], dtype=torch.long)
        encoding_tail.index_fill_(0, positions, mask_idx)
        if replacement_program_idx is not None:
            encoding_tail.index_fill_(0, torch.tensor([base + 2], dtype=torch.long), replacement_program_idx)
    return octuplemidi_token_encoding, masked_octs_orig
# First and last token of each octuple field in the MusicBERT vocabulary, in the
# order and ranges gen_dictionary writes them.
BAR_START = "<0-0>"
BAR_END = "<0-255>"
POS_START = "<1-0>"
POS_END = "<1-127>"
INS_START = "<2-0>"
# gen_dictionary emits <2-0> .. <2-128>; 128 is the percussion program, so it belongs in the range.
INS_END = "<2-128>"
PITCH_START = "<3-0>"
PITCH_END = "<3-255>"
DUR_START = "<4-0>"
DUR_END = "<4-127>"
VEL_START = "<5-0>"
VEL_END = "<5-31>"
SIG_START = "<6-0>"
SIG_END = "<6-253>"
TEMPO_START = "<7-0>"
TEMPO_END = "<7-48>"
SPECIAL_TOKENS = ['<mask>', '<s>', '<pad>', '</s>', '<unk>']


def _token_range(label_dict, first_token, last_token):
    """Return the half-open id range (id(first_token), id(last_token) + 1).

    Assumes each field's tokens have contiguous ids, as gen_dictionary writes them.
    """
    return label_dict.index(first_token), label_dict.index(last_token) + 1


def bar_range(label_dict):
    return _token_range(label_dict, BAR_START, BAR_END)


def pos_range(label_dict):
    return _token_range(label_dict, POS_START, POS_END)


def ins_range(label_dict):
    return _token_range(label_dict, INS_START, INS_END)


def pitch_range(label_dict):
    return _token_range(label_dict, PITCH_START, PITCH_END)


def dur_range(label_dict):
    return _token_range(label_dict, DUR_START, DUR_END)


def vel_range(label_dict):
    return _token_range(label_dict, VEL_START, VEL_END)


def sig_range(label_dict):
    return _token_range(label_dict, SIG_START, SIG_END)


def tempo_range(label_dict):
    return _token_range(label_dict, TEMPO_START, TEMPO_END)
def top_k_top_p(logits_batch, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Apply top-k and/or nucleus (top-p) filtering to logits, one row at a time.

    Args:
        logits_batch: tensor of shape (vocab,) or (batch, vocab). A 1-D input
            is treated as a batch of one. The input is not modified.
        top_k: if > 0, keep only the k largest logits of each row (k is
            clamped to the vocabulary size).
        top_p: if > 0.0, keep the smallest prefix of highest-probability
            tokens whose cumulative probability exceeds top_p. The most
            likely token is always kept.
        filter_value: value written into the removed positions.

    Returns:
        A filtered copy of shape (batch, vocab).
    """
    filtered = logits_batch.clone()
    if filtered.dim() == 1:
        filtered = filtered.unsqueeze(0)
    assert filtered.dim() == 2
    for row in filtered:  # each row is a view, so writes land in `filtered`
        top_k = min(top_k, row.size(-1))  # Safety check
        if top_k > 0:
            kth_largest = torch.topk(row, top_k)[0][..., -1, None]
            row[row < kth_largest] = filter_value
        if top_p > 0.0:
            ordered_logits, order = torch.sort(row, descending=True)
            probabilities = torch.nn.functional.softmax(ordered_logits, dim=-1)
            drop_sorted = torch.cumsum(probabilities, dim=-1) > top_p
            # Shift right so the token that crosses the threshold survives too.
            drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
            drop_sorted[..., 0] = 0
            row[order[drop_sorted]] = filter_value
    return filtered
# Octuple tokens cycle through fields 0..7 (bar, position, instrument, pitch,
# duration, velocity, time signature, tempo), so the next field follows from the
# previous token's field. Per-field temperatures cover e.g. pitch, duration,
# velocity and instrument.
def switch_temperature(prev_index: int, label_dict, temperature_dict):
    """Pick the sampling temperature for the octuple field that follows `prev_index`.

    Args:
        prev_index: previously predicted token id. `.item()` is called on it,
            so it is presumably a one-element tensor despite the `int`
            annotation.
        label_dict: vocabulary mapping token strings to ids (see reverse_label_dict).
        temperature_dict: mapping from field number (0-7) to temperature.

    Returns:
        temperature_dict[(field of previous token + 1) % 8]. The field is read
        from the token's second character, so special tokens such as '<s>'
        raise ValueError.
    """
    token_id = prev_index.item()
    token = reverse_label_dict(label_dict)[token_id]
    next_field = (int(token[1]) + 1) % 8
    return temperature_dict[next_field]
def filter_invalid_indexes(logits, prev_index, label_dict, filter_value=-float('Inf')):
    """Mask every logit that may not follow the previously predicted token.

    Octuple tokens are produced field by field in a fixed cyclic order
    (bar -> position -> instrument -> pitch -> duration -> velocity ->
    timesig -> tempo -> bar ...). Special tokens are always masked. If the
    previous token belongs to field F (0-7), the vocabulary ranges of all
    fields except F+1 (mod 8) are masked as well, so only the next field
    stays sampleable. If the field character of the previous token is not
    one of '0'-'7', only the special tokens are masked.

    Args:
        logits: logits over the vocabulary. The function works on
            ``logits.clone()``, so the caller's tensor is not modified.
        prev_index: previously predicted token (an object exposing ``.item()``).
        label_dict: dictionary mapping string octuple encodings to indices.
        filter_value: value written into masked positions (default ``-inf``).

    Returns:
        The filtered copy of `logits`.
    """
    masked = logits.clone()
    previous = prev_index.item()
    encoding = reverse_label_dict(label_dict)[previous]

    # Special tokens are never allowed as a prediction.
    for special in SPECIAL_TOKENS:
        masked[label_dict.index(special)] = filter_value

    field_digits = ('0', '1', '2', '3', '4', '5', '6', '7')
    current_field = encoding[1]
    if current_field in field_digits:
        # Vocabulary-range helpers listed in octuple field order 0..7.
        field_ranges = (bar_range, pos_range, ins_range, pitch_range,
                        dur_range, vel_range, sig_range, tempo_range)
        allowed = (field_digits.index(current_field) + 1) % 8
        for field_no, range_of in enumerate(field_ranges):
            if field_no == allowed:
                continue
            masked[list(range(*range_of(label_dict)))] = filter_value
    return masked
class PRED_MODE(Enum):
    """Mask-prediction strategy accepted by `predict_all_masks`.

    VANILLA predicts every <mask> token in the input sequence.
    OCTUPLE_MODE predicts only the `mask_attribs` fields of the octuples
    listed in `masked_octuples`.
    """

    # Predict every <mask> token found in the input encoding.
    VANILLA = 1
    # Predict only the selected fields of the selected octuples.
    OCTUPLE_MODE = 2
#Helper function for OCTUPLE_MODE and MULTI_OCTUPLE_MODE
# https://www.geeksforgeeks.org/break-list-chunks-size-n-python/
def split_multi_oct(list_a, chunk_size):
    """Yield consecutive groups of `chunk_size` sub-sequences, each flattened into one list.

    The last group may be shorter when ``len(list_a)`` is not a multiple of
    `chunk_size`. A `chunk_size` of 0 raises ValueError (from ``range``).

    Args:
        list_a: sliceable sequence whose items are themselves iterables.
        chunk_size: number of items of `list_a` merged into each yielded list.

    Yields:
        A flat list holding the concatenated contents of one group.
    """
    for start in range(0, len(list_a), chunk_size):
        group = list_a[start:start + chunk_size]
        yield list(chain.from_iterable(group))
#Predict missing masks in sequence from left to right
'''
NOTE: Predicts ONLY the masked octuples provided in `masked_octuples` if `prediction_mode` is NOT Vanilla
Else if prediction_mode is Vanilla, it predicts all the masks in the input `octuplemidi_token_encoding`
'''
#octuplemidi_token_encoding: 1-D tensor of the form torch.Tensor([0,0,0,0,0,0,0,0, ..........., 2,2,2,2,2,2,2,2]), where 0 is roberta_label_dict.bos_index and 2 is roberta_label_dict.eos_index
#prediction_mode: selects the mask-prediction strategy (and therefore its speed), see PRED_MODE
#mask_attribs: field indices (0-7) selecting which of (`bar`, `position`, `instrument`, `pitch`, `duration`, `velocity`, `timesig`, `tempo`) are masked
#masked_octuples: list of the octuple indices in `octuplemidi_token_encoding` that are masked; the element index of the `bar` field of octuple `i` is `i * 8`
def predict_all_masks(roberta_model, roberta_label_dict, temperature_dict, octuplemidi_token_encoding:torch.Tensor, masked_octuples:list = None,
prediction_mode:PRED_MODE = PRED_MODE.VANILLA, mask_attribs:list = [3,4,5] ,num_multi_octuples:int = None,
temperature = 1.0, top_k=30, top_p=0.6,
verbose = False):
mask_idx = 1236
octuplemidi_token_encoding = octuplemidi_token_encoding.clone()
try:
assert octuplemidi_token_encoding.dim() == 1
except:
raise Exception('Please input single dimensional octuple sequence')
try:
bos_idx = roberta_label_dict.bos_index
eos_idx = roberta_label_dict.eos_index
tens_type = octuplemidi_token_encoding.dtype
assert torch.equal(octuplemidi_token_encoding[:8], torch.Tensor([bos_idx]*8).type(tens_type)) and \
torch.equal(octuplemidi_token_encoding[-8:], torch.Tensor([eos_idx]*8).type(tens_type))
except:
print('Start:', octuplemidi_token_encoding[:8] )
print(torch.Tensor([bos_idx]*8))
print('End:', octuplemidi_token_encoding[-8:])
print(torch.Tensor([bos_idx]*8))
raise Exception('`octuplemidi_token_encoding` either does not have 8 <s> tokens or 8 </s> tokens at beginning and end')
#---------------------------------------------------------------
# Altering input mask list based on `prediction_mode`
#---------------------------------------------------------------
mask_indices = None
# If `masked_octuples` not provided, then the `prediction_mode` MUST be Vanilla
if masked_octuples == None:
try:
assert prediction_mode == PRED_MODE.VANILLA
except:
#Since the current faster implementations involves the premise that in all the masked notes, same fields of each octuple is masked,
#For example, we are not considering that in the sequence one octuple has just `duration` masked and another has just `pitch` masked
raise Exception("Error: Please choose `prediction_mode` as Vanilla since `masked_octuples` is not provided, to use faster modes provide `mask_indices`")
mask_indices = [i for i, x in enumerate(octuplemidi_token_encoding.tolist()) if x == mask_idx]
elif prediction_mode == PRED_MODE.VANILLA:
print('Warning: Ignoring `masked_octuples`, `mask_attribs` & `num_multi_octuples` as `prediction_mode` is set as Vanilla')
mask_indices = [i for i, x in enumerate(octuplemidi_token_encoding.tolist()) if x == mask_idx]
elif prediction_mode == PRED_MODE.OCTUPLE_MODE:
if num_multi_octuples is not None:
print('Warning: Ignoring `num_multi_octuples` as `prediction_mode` is set as Octuple mode (not Multi-octuple mode)')
mask_indices = [ [x*8 + y for y in mask_attribs] for x in masked_octuples]
else:
raise Exception("Invalid `prediction_mode`")
try:
assert len(mask_indices) > 0
except AssertionError:
raise Exception('Please input sentence tokens with at least one mask token')
try:
assert all( torch.all(octuplemidi_token_encoding[octuple_midi_mask_elem] == mask_idx) for octuple_midi_mask_elem in mask_indices )
except:
print([octuplemidi_token_encoding[octuple_midi_mask_elem] == mask_idx for octuple_midi_mask_elem in mask_indices])
raise Exception('Fatal error: At least one element of `mask_indices` is not <mask> (1236)')
#--------------------------------------------------------------
# Inputting masked indices to model using `prediction_strategy`
#--------------------------------------------------------------
if prediction_mode == PRED_MODE.VANILLA:
pass
elif prediction_mode == PRED_MODE.OCTUPLE_MODE:
#Checking if `mask_attribs` is fine
try:
mask_attribs_len = len(mask_attribs)
assert mask_attribs_len > 0 and \
len(set(mask_attribs)) == mask_attribs_len and \
all( (x >= 0 and x < 8) for x in mask_attribs)
except:
raise Exception("`mask_attribs` not appropriate")