import torch
import torchaudio
from functools import partial
from torch.utils.data import DataLoader

SAMPLE_RATE = 16000


def collect_audio_batch(batch, extra_noise=0., maxLen=600000):
    '''Collate a batch given as a list of tuples (audio_path <str>, token list <list of int>),
    e.g. [(file1, txt1), (file2, txt2), ...]
    '''
    def audio_reader(filepath):
        # Load the waveform, resample to the target rate if needed, and flatten to 1-D.
        wav, sample_rate = torchaudio.load(filepath)
        if sample_rate != SAMPLE_RATE:
            wav = torchaudio.transforms.Resample(sample_rate, SAMPLE_RATE)(wav)
        wav = wav.reshape(-1)
        # Truncate overly long utterances to keep memory bounded.
        if wav.shape[-1] >= maxLen:
            print(f'{filepath} has len {wav.shape}, truncate to {maxLen}')
            wav = wav[:maxLen]
        # Optionally perturb the waveform with Gaussian noise.
        wav += extra_noise * torch.randn_like(wav)
        return wav

    # A bucketed batch arrives as [[(file1, txt1), (file2, txt2), ...]]; unwrap it.
    if not isinstance(batch[0], tuple):
        batch = batch[0]

    # Read the batch
    file, audio_feat, audio_len, text = [], [], [], []
    with torch.no_grad():
        for b in batch:
            feat = audio_reader(str(b[0])).numpy()
            file.append(str(b[0]).split('/')[-1].split('.')[0])
            audio_feat.append(feat)
            audio_len.append(len(feat))
            text.append(b[1])
    return torch.tensor(audio_len), audio_feat, text, file
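

# Illustrative sketch (not part of the original pipeline): how the collate function
# would be called on two hypothetical utterances. The wav paths and token lists
# below are made-up placeholders.
#
#   batch = [('sample1.wav', [5, 12, 7]), ('sample2.wav', [3, 9])]
#   audio_len, audio_feat, text, file = collect_audio_batch(batch, extra_noise=0.005)
#   # audio_len: tensor of raw sample counts, audio_feat: list of 1-D numpy waveforms,
#   # text: list of token lists, file: list of basenames without extension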
def create_dataset(name, path, batch_size=1, noise_type=None, noise_snr=None):
    '''Interface for creating all supported datasets.'''
    # Recognize the corpus and lazily import the matching Dataset class.
    if name.lower() == "librispeech":
        from corpus.librispeech import LibriDataset as Dataset
    elif name.lower() == "chime":
        from corpus.CHiME import CHiMEDataset as Dataset
    elif name.lower() == "ted":
        from corpus.ted import TedDataset as Dataset
    elif name.lower() == "commonvoice":
        from corpus.commonvoice import CVDataset as Dataset
    elif name.lower() == "valentini":
        from corpus.valentini import ValDataset as Dataset
    elif name.lower() == "l2arctic":
        from corpus.l2arctic import L2ArcticDataset as Dataset
    else:
        raise NotImplementedError

    loader_bs = batch_size
    # Only the LibriSpeech corpus accepts noise options.
    if name.lower() == "librispeech":
        dataset = Dataset(batch_size, path, noise_type=noise_type, noise_snr=noise_snr)
    else:
        dataset = Dataset(batch_size, path)
    print(f'[INFO] There are {len(dataset)} samples.')
    return dataset, loader_bs
def load_dataset(name='librispeech', path=None, batch_size=1, extra_noise=0., noise_type=None, noise_snr=None, num_workers=4):
    '''Prepare a DataLoader for training/validation.'''
    dataset, loader_bs = create_dataset(name, path, batch_size, noise_type=noise_type, noise_snr=noise_snr)
    # Inject extra Gaussian noise only for clean LibriSpeech; pre-noised corpora are left untouched.
    if name.lower() == "librispeech" and noise_type is None:
        collate_fn = partial(collect_audio_batch, extra_noise=extra_noise)
    else:
        collate_fn = partial(collect_audio_batch, extra_noise=0)
    dataloader = DataLoader(dataset, batch_size=loader_bs, shuffle=False,
                            collate_fn=collate_fn, num_workers=num_workers)
    return dataloader
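

# Minimal usage sketch, assuming a LibriSpeech-style corpus lives at the hypothetical
# path below and the corpus.librispeech module is importable in this environment.
if __name__ == '__main__':
    loader = load_dataset(name='librispeech', path='data/LibriSpeech',  # hypothetical path
                          batch_size=4, extra_noise=0.005)
    audio_len, audio_feat, text, file = next(iter(loader))
    print(audio_len, [f.shape for f in audio_feat], file)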